import numpy as np
from scipy import stats
import random
import matplotlib.pyplot as plt
def get_reward(prob, n=10):
    """Simulate one bandit pull: run n Bernoulli trials with success
    probability `prob` and return the number of successes (0..n)."""
    return sum(1 for _ in range(n) if random.random() < prob)
def get_best_arm(record):
    """Greedy choice: return the index of the arm whose running mean
    reward (column 1 of `record`) is currently the largest."""
    best = record[:, 1].argmax()
    return best
def update_record(record, action, r):
    """Fold reward `r` into the running statistics of the chosen arm.

    record[action] holds (play count, mean reward). Since mean * count is
    the total reward so far, the new mean is (mean * count + r) / (count + 1).
    Mutates `record` in place and returns it.
    """
    count, mean = record[action, 0], record[action, 1]
    record[action, 1] = (mean * count + r) / (count + 1)
    record[action, 0] = count + 1
    return record
# Epsilon-greedy simulation over n bandit arms.
n = 10                      # number of arms
record = np.zeros((n, 2))   # per-arm stats: column 0 = play count, column 1 = mean reward
probs = np.random.rand(n)   # hidden success probability of each arm
eps = 0.2                   # exploration rate

fig, ax = plt.subplots(1, 1)
ax.set_xlabel("Plays")
ax.set_ylabel("Avg Reward")
fig.set_size_inches(9, 5)

rewards = [0]
for i in range(5000):
    if random.random() > eps:
        choice = get_best_arm(record)   # exploit: greedy arm by mean reward
    else:
        # Explore: pick a uniformly random arm. Was hard-coded randint(10),
        # which would silently break if n were changed; use n instead.
        choice = np.random.randint(n)
    r = get_reward(probs[choice])       # reward = successes in 10 Bernoulli trials
    record = update_record(record, choice, r)  # fold r into (count, mean) for this arm
    # Incremental running average of every reward observed so far.
    mean_reward = ((i + 1) * rewards[-1] + r) / (i + 2)
    rewards.append(mean_reward)
ax.scatter(np.arange(len(rewards)), rewards)
plt.show()