
[RL-PyTorch] N Step Actor-Critic

achrxme 2023. 9. 2. 22:53

Monte Carlo: learns only at the end of each episode

Temporal Difference (TD, fully online learning): learns at every step -> bias can grow when bootstrapping is applied

N-step TD: learns every N steps -> even with bootstrapping, less biased than one-step TD

*Bootstrapping: making a prediction based on another prediction -> it is best to gather as much data as possible before making that prediction (see the sketch below)
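
To make the difference between these targets concrete, here is a minimal sketch (the rewards, the bootstrap value, and gamma below are made-up numbers, not taken from the training code): the Monte Carlo return seeds the backward discounted sum with 0, while the N-step bootstrapped return seeds it with the critic's value estimate of the last state, which is exactly what update_params does further down.

import torch

def discounted_returns(rewards, bootstrap, gamma=0.95):
    # Walk backwards through the rewards, seeding the running return with
    # `bootstrap` (0 for Monte Carlo, V(s_N) for the N-step bootstrapped target).
    ret = bootstrap
    returns = []
    for r in reversed(rewards):
        ret = r + gamma * ret
        returns.append(ret)
    return torch.tensor(list(reversed(returns)))

rewards = [1.0, 1.0, 1.0]                # toy 3-step trajectory
print(discounted_returns(rewards, 0.0))  # MC target: no bootstrapping
print(discounted_returns(rewards, 0.8))  # N-step target: bootstrap from V(s_N) ~ 0.8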

 

# Multi-step Temporal Difference (N-step Actor-Critic)

import torch
from torch import nn
from torch import optim
import numpy as np
from torch.nn import functional as F
import gym
import torch.multiprocessing as mp


class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.l1 = nn.Linear(4, 25)
        self.l2 = nn.Linear(25, 50)
        self.actor_lin1 = nn.Linear(50, 2)
        self.l3 = nn.Linear(50, 25)
        self.critic_lin1 = nn.Linear(25, 1)

    def forward(self, x):
        x = F.normalize(x, dim=0)
        y = F.relu(self.l1(x))
        y = F.relu(self.l2(y))
        actor = F.log_softmax(self.actor_lin1(y), dim=0)
        c = F.relu(self.l3(y.detach()))
        critic = torch.tanh(self.critic_lin1(c))
        return actor, critic


def run_episode(worker_env, worker_model, N_steps=10):
    raw_state = np.array(worker_env.env.state)
    state = torch.from_numpy(raw_state).float()
    values, logprobs, rewards = [],[],[]
    done = False
    j=0
    G = torch.Tensor([0]) # running return G, initialized to 0
    while j < N_steps and done == False: # update the parameters every N steps, not only at the end of an episode
        j+=1
        policy, value = worker_model(state)
        values.append(value)
        logits = policy.view(-1)
        action_dist = torch.distributions.Categorical(logits=logits)
        action = action_dist.sample()
        logprob_ = policy.view(-1)[action]
        logprobs.append(logprob_)
        state_, _, done, __, info = worker_env.step(action.detach().numpy()) # new gym API: (obs, reward, terminated, truncated, info); truncation is ignored here
        state = torch.from_numpy(state_).float()
        if done:
            reward = -10
            worker_env.reset()
        else: # if the episode has not ended, use the critic's value of the last state as the return
            reward = 1.0
            G = value.detach() # state value from the current critic (worker_model), detached so no gradient flows through it
                               # -> a prediction based on a prediction: bootstrapping
        rewards.append(reward)
    return values, logprobs, rewards, G


def update_params(worker_opt, values, logprobs, rewards, G, clc=0.1, gamma=0.95):
    rewards = torch.Tensor(rewards).flip(dims=(0,)).view(-1)
    logprobs = torch.stack(logprobs).flip(dims=(0,)).view(-1)
    values = torch.stack(values).flip(dims=(0,)).view(-1)
    Returns = []
    ret_ = G # seed the backward sum with G (cf. MC: ret_ = torch.Tensor([0])) -> this is where bootstrapping happens
    for r in range(rewards.shape[0]):
        ret_ = rewards[r] + gamma * ret_
        Returns.append(ret_)
    Returns = torch.stack(Returns).view(-1)
    Returns = F.normalize(Returns, dim=0)
    actor_loss = -1 * logprobs * (Returns - values.detach())
    critic_loss = torch.pow(values - Returns, 2)
    loss = actor_loss.sum() + clc * critic_loss.sum()
    loss.backward()
    worker_opt.step()
    return actor_loss, critic_loss, len(rewards)


def worker(t, worker_model, counter, params):
    # worker_env = gym.make("CartPole-v1", render_mode="human")
    worker_env = gym.make("CartPole-v1")
    worker_env.reset()
    worker_opt = optim.Adam(lr=1e-4, params=worker_model.parameters()) # each worker has its own optimizer, but the parameters are shared
    worker_opt.zero_grad()
    for i in range(params['epochs']):
        worker_opt.zero_grad()
        values, logprobs, rewards, G = run_episode(worker_env, worker_model)
        actor_loss, critic_loss, eplen = update_params(worker_opt, values, logprobs, rewards, G)
        print("epoch ", i)
        counter.value = counter.value + 1


MasterNode = ActorCritic()
MasterNode.share_memory() # put the parameters in shared memory so every worker updates the same weights
processes = []
params = {
    'epochs': 5000,
    'n_workers': 12,
}

counter = mp.Value('i', 0) # shared integer counter across worker processes
for i in range(params['n_workers']):
    p = mp.Process(target=worker, args=(i, MasterNode, counter, params))
    p.start()
    processes.append(p)
for p in processes:
    p.join()
for p in processes:
    p.terminate()

# Demo: run the trained shared policy with rendering
env = gym.make("CartPole-v1", render_mode="human")
env.reset()

step = 0
for i in range(10000):
    state_ = np.array(env.env.state)
    state = torch.from_numpy(state_).float()
    logits,value = MasterNode(state)
    action_dist = torch.distributions.Categorical(logits=logits)
    action = action_dist.sample()
    state2, reward, done, _, info = env.step(action.detach().numpy())
    if done:
        print("Lost after", step, "steps")
        step = 0
        env.reset()
    env.render()
    step = step + 1
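
One practical caveat: the process-launching code above sits at module level, which is fine when torch.multiprocessing forks the workers (Linux) but breaks on platforms that spawn them (Windows, and macOS with recent Python defaults), since every spawned process re-imports the module. A minimal sketch of the safer layout, using the same names as above:

if __name__ == "__main__":
    MasterNode = ActorCritic()
    MasterNode.share_memory()
    params = {'epochs': 5000, 'n_workers': 12}
    counter = mp.Value('i', 0)

    processes = []
    for i in range(params['n_workers']):
        p = mp.Process(target=worker, args=(i, MasterNode, counter, params))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()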