본문 바로가기
RL with python/Python example code

[RL-Python] Multi-armed bandit - Epsilon Greedy

by achrxme 2023. 8. 21.

 

import numpy as np
from scipy import stats
import random
import matplotlib.pyplot as plt


def get_reward(prob, n=10):
    """Simulate a single pull of a bandit arm.

    Runs ``n`` independent trials, each drawing ``random.random()``, and
    counts how many of those draws fall strictly below ``prob``.

    Args:
        prob: per-trial success threshold (for a value in [0, 1] this acts as
            a Bernoulli success probability).
        n: number of trials that make up one pull.

    Returns:
        The number of successful trials, an int between 0 and ``n``.
    """
    return sum(1 for _ in range(n) if random.random() < prob)


def get_best_arm(record):
    """Return the index of the arm whose recorded mean reward is highest.

    Column 1 of ``record`` holds each arm's mean reward. When several arms
    share the maximum, the first (lowest) index is returned, following
    numpy's argmax convention.
    """
    mean_rewards = record[:, 1]
    return mean_rewards.argmax(axis=0)


def update_record(record, action, r):
    """Fold reward ``r`` into the running statistics of arm ``action``.

    Row ``action`` of ``record`` holds ``(play count, mean reward)``.
    Because mean * count recovers the sum of all past rewards for that arm,
    the new mean is (sum + r) / (count + 1), and the count grows by one.

    ``record`` is modified in place and is also returned for convenience.
    """
    plays = record[action, 0]
    reward_sum = plays * record[action, 1] + r  # total reward including r
    plays += 1
    record[action, 0] = plays
    record[action, 1] = reward_sum / plays
    return record


# Experiment configuration.
n = 10  # number of bandit arms
record = np.zeros((n, 2))  # per-arm (play count, mean reward)
probs = np.random.rand(n)  # hidden success threshold of each arm
eps = 0.2  # probability of exploring a random arm instead of exploiting
num_plays = 5000

fig, ax = plt.subplots(1, 1)
ax.set_xlabel("Plays")
ax.set_ylabel("Avg Reward")

fig.set_size_inches(9, 5)

# rewards[k] is the running mean reward after k plays. rewards[0] = 0 only
# anchors the start of the curve (no plays yet); it is not a real sample.
rewards = [0]
for i in range(num_plays):

    if random.random() > eps:
        choice = get_best_arm(record)  # Exploitation -> greedy
    else:
        # Exploration: pick any of the n arms uniformly. This uses n rather
        # than a hard-coded 10 so it stays correct if the arm count changes.
        choice = np.random.randint(n)

    # Bandit reward: with get_reward's default of 10 trials, this counts how
    # many random draws fall *below* probs[choice].
    r = get_reward(probs[choice])

    # record holds (play count, mean reward) for each arm.
    record = update_record(record, choice, r)

    # Incremental mean over the i + 1 real plays so far. Dividing by i + 2
    # would count the rewards[0] placeholder as a phantom zero reward.
    mean_reward = (i * rewards[-1] + r) / (i + 1)
    rewards.append(mean_reward)
ax.scatter(np.arange(len(rewards)), rewards)

plt.show()