본문 바로가기
카테고리 없음

[RL-PyTorch] Distributed Advantage Actor-Critic (DA2C)

by achrxme 2023. 9. 1.

[Actor-Critic]

loss = $ -log(\pi(a \mid S))\cdot (R-V_{\pi}(S))$

- 일반적인 -log 형태의 loss function과 비슷함

- 차이 : $R-V_{\pi}(S)$

  -> 해당 step에서 받은 reward($R$) 에서, 해당 state에서 받을 수 있다고 기대되는 state value 값을(state-action value가 아닌 state value) baseline으로 뺀 다음 loss 를 계산 -> 해당 state 에서 특정 action이 얼마나 효과적이었는지 상대적으로 평가 가능

 

Actor : policy network -> action을 선택

Critic : State value network -> actor가 선택한 action을 평가함

 

import torch
from torch import nn
from torch import optim
import numpy as np
from torch.nn import functional as F
import gym
import torch.multiprocessing as mp  # Python 내장 multiprocessing library 와 동일


class ActorCritic(nn.Module):  # Actor와 Critic을 하나의 신경망으로 결합한 모형값 -> 이중 출력 신경
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.l1 = nn.Linear(4, 25)
        self.l2 = nn.Linear(25, 50)
        self.actor_lin1 = nn.Linear(50, 2)  # Actor -> 좌/우 Action의 확률 분포
        self.l3 = nn.Linear(50, 25)
        self.critic_lin1 = nn.Linear(25, 1)  # Critic -> state value 예측

    def forward(self, x):
        x = F.normalize(x, dim=0)
        y = F.relu(self.l1(x))
        y = F.relu(self.l2(y))
        actor = F.log_softmax(self.actor_lin1(y), dim=0)  # Actor -> log(softmax(...)) 적용
        c = F.relu(self.l3(y.detach()))  # Detach -> y 노드를 그래프에서 떼어냄
        # -> Critic의 loss는 역전파 되지 않음 -> Actor와 Critic이 서로 반대 방향으로 갱신하는 경우 방지
        critic = torch.tanh(self.critic_lin1(c))  # Critic -> tanh : -1~1 사이 출력 (= cartpole reward와 동일함)
        return actor, critic


def run_episode(worker_env, worker_model):
    state = torch.from_numpy(worker_env.env.state).float()  # numpy 배열 형태의 환경 상태를 tensor로 변환
    values, logprobs, rewards = [], [], []
    done = False
    j = 0
    while (done == False):
        j += 1
        policy, value = worker_model(state)  # state value와 (critic) / action들에 대한 log 확률(actor)을 계산
        values.append(value)
        logits = policy.view(-1)
        action_dist = torch.distributions.Categorical(logits=logits)    # actor의 동작 log 확률을 이용해 범주형 확률분포를 만듬
        action = action_dist.sample()   # action 추출
        logprob_ = policy.view(-1)[action]
        logprobs.append(logprob_)
        state_, _, done, __, info = worker_env.step(action.detach().numpy())
        state = torch.from_numpy(state_).float()
        if done:
            reward = -10
            worker_env.reset()
        else:
            reward = 1.0
        rewards.append(reward)
    return values, logprobs, rewards


def update_params(worker_opt, values, logprobs, rewards, clc=0.1, gamma=0.95):
    rewards = torch.Tensor(rewards).flip(dims=(0,)).view(-1)
    # 주어진 배열들(rewards, logprobs, values)를 역순으로 정렬하고 (-> 가장 최근 동작을 가장 중요시하기 위함)
    # .view(-1)를 호출하여 1차원 배열 형태로 평평하게 만듬
    logprobs = torch.stack(logprobs).flip(dims=(0,)).view(-1)
    values = torch.stack(values).flip(dims=(0,)).view(-1)
    Returns = []
    ret_ = torch.Tensor([0])
    # reward들을 원래 수집한 것과 역순으로 훑으면서 각 reward로 return을 계산하여 returns 배열에 추가
    for r in range(rewards.shape[0]):
        ret_ = rewards[r] + gamma * ret_
        Returns.append(ret_)
    Returns = torch.stack(Returns).view(-1)
    Returns = F.normalize(Returns, dim=0)
    actor_loss = -1 * logprobs * (Returns - values.detach())  # Actor loss 계산 (baseline 사용)
    # -> critic loss가 역전파 되지 않도록 values.detach()
    critic_loss = torch.pow(values - Returns, 2)  # Critic loss 계산 (just 제곱오차)
    loss = actor_loss.sum() + clc * critic_loss.sum()  # 총 loss 계산 -> clc = 0.1 -> actor가 critic보다 빠르게 학습하게 하기 위
    loss.backward()
    worker_opt.step()
    return actor_loss, critic_loss, len(rewards)  # training 진척 과정을 관리하는 용도로 사용됨


def worker(t, worker_model, counter, params):
    worker_env = gym.make("CartPole-v1", render_mode="human")
    worker_env.reset()
    worker_opt = optim.Adam(lr=1e-4, params=worker_model.parameters())  # 각 프로세스는 격리되어 에피소드를 실행하지만, 하나의 모형을 모든 프로세스가 공유함
    worker_opt.zero_grad()
    for i in range(params['epochs']):
        worker_opt.zero_grad()
        values, logprobs, rewards = run_episode(worker_env, worker_model)  # 에피소드를 실행해서 데이터를 수집
        actor_loss, critic_loss, eplen = update_params(worker_opt, values, logprobs,
                                                      rewards)  # run_episode에서 수집한 데이터로 매개변수들을 한 단계 갱신
        counter.value = counter.value + 1   # counter : 모든 프로세스가 공유하는 전역 카운터


MasterNode = ActorCritic()  # Processor들이 공유할 global actor-critic model를
MasterNode.share_memory()  # Process들이 모형의 매개변수들을 각자 복사하는 것이 아니라, 그대로 공유하게 만듬
processes = []
params = {
    'epochs': 1000,
    'n_workers': 7,
}
counter = mp.Value('i', 0)  # multiprocessing 내장 공유 객체를 global 공유 counter로 사용, 'i' : 공유 객체의 데이터 형식이 정수라는 뜻
for i in range(params['n_workers']):
    # 각 worker는 그냥 자신의 일정에 따라 독립적으로 에피소드를 실행하여 모형 매개변수들을 갱신 -> asynchronous 갱신
    p = mp.Process(target=worker, args=(i, MasterNode, counter, params))  # 새 프로세스를 띄워 worker 함수를 실행
    p.start()
    processes.append(p)
for p in processes:  # 모든 프로세스가 작업을 마치길 기다림
    p.join()
for p in processes:  # 전역 카운터의 값과 첫 프로세스의 종료 코드(오류가 없다면 0) 출력
    p.terminate()