본문 바로가기
RL with python/Python example code

[RL-Python] Multi-armed bandit - softmax

by achrxme 2023. 8. 21.
import numpy as np
from scipy import stats
import random
import matplotlib.pyplot as plt


def get_reward(prob, n=10):
    """Sample a reward for one arm pull.

    Runs ``n`` independent Bernoulli trials, each succeeding with
    probability ``prob``, and returns the number of successes (0..n).
    """
    return sum(1 for _ in range(n) if random.random() < prob)


def softmax(av, tau=1.12):
    """Turn action-value estimates into a selection probability distribution.

    Args:
        av: 1-D array of action-value (mean reward) estimates.
        tau: temperature; smaller tau makes selection greedier,
            larger tau makes it more uniform.

    Returns:
        Array of the same shape as ``av`` with entries summing to 1.
    """
    scaled = av / tau
    # Subtract the max before exponentiating for numerical stability:
    # np.exp overflows for large inputs, and the shift cancels out in
    # the normalization, leaving the result mathematically unchanged.
    exp_vals = np.exp(scaled - np.max(scaled))
    return exp_vals / np.sum(exp_vals)

def update_record(record, action, r):
    """Incrementally update the running statistics for one bandit arm.

    ``record[:, 0]`` holds the play count per arm and ``record[:, 1]``
    the mean reward per arm.  The new mean is obtained by recovering the
    reward sum so far (mean * count), adding the new reward ``r``, and
    dividing by the new count.  Mutates ``record`` in place and returns it.
    """
    count = record[action, 0]
    mean = record[action, 1]
    record[action, 1] = (count * mean + r) / (count + 1)
    record[action, 0] = count + 1
    return record


n = 10                     # number of bandit arms
record = np.zeros((n, 2))  # per-arm [play count, mean reward]
probs = np.random.rand(n)  # hidden success probability of each arm

# One figure for the learning curve.  (The original created two figures
# back to back, leaving the first one empty and leaked; the unused
# epsilon-greedy leftover `eps = 0.2` is also removed.)
fig, ax = plt.subplots(1, 1)
ax.set_xlabel("Plays")
ax.set_ylabel("Avg Reward")
fig.set_size_inches(9, 5)

rewards = [0]
for i in range(500):
    # Select an arm by sampling from the softmax distribution over the
    # current mean-reward estimates (instead of epsilon-greedy).
    p = softmax(record[:, 1], tau=0.7)
    choice = np.random.choice(np.arange(n), p=p)
    r = get_reward(probs[choice])
    record = update_record(record, choice, r)
    # Running average of the reward over all plays so far.
    mean_reward = ((i + 1) * rewards[-1] + r) / (i + 2)
    rewards.append(mean_reward)

ax.scatter(np.arange(len(rewards)), rewards)
plt.show()