[RL-PyTorch] N Step Actor-Critic
achrxme
2023. 9. 2. 22:53
Monte Carlo: learn at the end of each episode.
Temporal Difference (TD, fully online learning): learn at every step -> with bootstrapping, the bias can become large.
N-step TD: learn every N steps -> even with bootstrapping, the bias is smaller than one-step TD.
*Bootstrapping: making a prediction based on another prediction -> it is desirable to gather as much data as possible before predicting.
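As a quick illustration (a minimal sketch with made-up numbers), the N-step return replaces the tail of the Monte Carlo return with the critic's estimate: G_t = r_{t+1} + gamma*r_{t+2} + ... + gamma^(N-1)*r_{t+N} + gamma^N * V(s_{t+N}). The snippet below computes it backward, the same way update_params does further down; the rewards and the bootstrap value are purely illustrative.

# Minimal sketch of the N-step bootstrapped return (illustrative values only).
gamma = 0.95
rewards = [1.0, 1.0, 1.0]      # r_{t+1}, r_{t+2}, r_{t+3}
bootstrap_value = 0.8          # critic's estimate V(s_{t+3}), i.e. the bootstrapping term

ret = bootstrap_value
returns = []
for r in reversed(rewards):    # work backward from the bootstrap value
    ret = r + gamma * ret
    returns.append(ret)
returns.reverse()              # returns[0] is the N-step return for the first state
print(returns)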
# Multi-step Temporal Difference (N-step Actor-Critic)
import torch
from torch import nn
from torch import optim
import numpy as np
from torch.nn import functional as F
import gym
import torch.multiprocessing as mp
class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.l1 = nn.Linear(4, 25)
        self.l2 = nn.Linear(25, 50)
        self.actor_lin1 = nn.Linear(50, 2)   # actor head: 2 action logits for CartPole
        self.l3 = nn.Linear(50, 25)
        self.critic_lin1 = nn.Linear(25, 1)  # critic head: scalar state value

    def forward(self, x):
        x = F.normalize(x, dim=0)
        y = F.relu(self.l1(x))
        y = F.relu(self.l2(y))
        actor = F.log_softmax(self.actor_lin1(y), dim=0)
        c = F.relu(self.l3(y.detach()))      # detach: critic gradients do not flow into the shared layers
        critic = torch.tanh(self.critic_lin1(c))  # state-value estimate, squashed to (-1, 1)
        return actor, critic
def run_episode(worker_env, worker_model, N_steps=10):
    raw_state = np.array(worker_env.env.state)
    state = torch.from_numpy(raw_state).float()
    values, logprobs, rewards = [], [], []
    done = False
    j = 0
    G = torch.Tensor([0])  # holds the bootstrapped return; initialized to 0
    while (j < N_steps and done == False):  # update parameters every N steps, not only at the end of an episode
        j += 1
        policy, value = worker_model(state)
        values.append(value)
        logits = policy.view(-1)
        action_dist = torch.distributions.Categorical(logits=logits)
        action = action_dist.sample()
        logprob_ = policy.view(-1)[action]
        logprobs.append(logprob_)
        state_, _, done, trunc, info = worker_env.step(action.detach().numpy())
        done = done or trunc  # gym >= 0.26 reports termination and truncation separately
        state = torch.from_numpy(state_).float()
        if done:
            reward = -10
            worker_env.reset()
        else:  # if the episode has not ended, use the value of the last state as the return
            reward = 1.0
            G = value.detach()  # state value from the current critic (worker_model), taken without updating it
                                # -> a prediction based on a prediction: bootstrapping
        rewards.append(reward)
    return values, logprobs, rewards, G
def update_params(worker_opt, values, logprobs, rewards, G, clc=0.1, gamma=0.95):
    rewards = torch.Tensor(rewards).flip(dims=(0,)).view(-1)
    logprobs = torch.stack(logprobs).flip(dims=(0,)).view(-1)
    values = torch.stack(values).flip(dims=(0,)).view(-1)
    Returns = []
    ret_ = G  # start from G directly; cf. MC would use ret_ = torch.Tensor([0]) -> this is where bootstrapping happens
    for r in range(rewards.shape[0]):
        ret_ = rewards[r] + gamma * ret_
        Returns.append(ret_)
    Returns = torch.stack(Returns).view(-1)
    Returns = F.normalize(Returns, dim=0)
    actor_loss = -1 * logprobs * (Returns - values.detach())  # advantage = return minus the critic's baseline
    critic_loss = torch.pow(values - Returns, 2)
    loss = actor_loss.sum() + clc * critic_loss.sum()  # clc scales down the critic loss
    loss.backward()
    worker_opt.step()
    return actor_loss, critic_loss, len(rewards)
def worker(t, worker_model, counter, params):
    # each worker has its own environment and optimizer; the model parameters are shared across processes
    # worker_env = gym.make("CartPole-v1", render_mode="human")
    worker_env = gym.make("CartPole-v1")
    worker_env.reset()
    worker_opt = optim.Adam(lr=1e-4, params=worker_model.parameters())
    worker_opt.zero_grad()
    for i in range(params['epochs']):
        worker_opt.zero_grad()
        values, logprobs, rewards, G = run_episode(worker_env, worker_model)
        actor_loss, critic_loss, eplen = update_params(worker_opt, values, logprobs, rewards, G)
        print("epoch ", i)
        counter.value = counter.value + 1
MasterNode = ActorCritic()
MasterNode.share_memory()  # put the parameters in shared memory so all workers update the same model
processes = []
params = {
    'epochs': 5000,
    'n_workers': 12,
}
counter = mp.Value('i', 0)  # shared counter of completed updates
for i in range(params['n_workers']):
    p = mp.Process(target=worker, args=(i, MasterNode, counter, params))
    p.start()
    processes.append(p)
for p in processes:
    p.join()
for p in processes:
    p.terminate()
env = gym.make("CartPole-v1", render_mode="human")
env.reset()
step = 0
for i in range(10000):
    state_ = np.array(env.env.state)
    state = torch.from_numpy(state_).float()
    logits, value = MasterNode(state)
    action_dist = torch.distributions.Categorical(logits=logits)
    action = action_dist.sample()
    state2, reward, done, trunc, info = env.step(action.detach().numpy())
    if done or trunc:  # treat truncation (500-step limit) the same as termination
        print("Lost in ", step)
        step = 0
        env.reset()
        state_ = np.array(env.env.state)
        state = torch.from_numpy(state_).float()
    env.render()
    step = step + 1
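A practical note: a torch.multiprocessing launch like the one above works as a plain script with the fork start method. On platforms that default to spawn (Windows, and macOS with recent Python versions), child processes re-import the module, so the training launch and the evaluation loop need to sit under an if __name__ == "__main__": guard. A minimal sketch of that restructuring, assuming ActorCritic and worker are defined as above:

# Sketch only: same launch logic as above, wrapped in a __main__ guard so it also
# works with the "spawn" start method. Assumes ActorCritic and worker from the listing above.
import torch.multiprocessing as mp

def main():
    master_node = ActorCritic()
    master_node.share_memory()
    counter = mp.Value('i', 0)
    params = {'epochs': 5000, 'n_workers': 12}
    processes = []
    for i in range(params['n_workers']):
        p = mp.Process(target=worker, args=(i, master_node, counter, params))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

if __name__ == "__main__":
    main()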