본문 바로가기
RL with python/Python example code

[RL-PyTorch] Basic Deep Q-learning

by achrxme 2023. 8. 22.

 

import numpy as np
import torch
from Gridworld import Gridworld
from IPython.display import clear_output
import random
from matplotlib import pylab as plt

l1 = 64
l2 = 150
l3 = 100
l4 = 4

model = torch.nn.Sequential(
    torch.nn.Linear(l1, l2),
    torch.nn.ReLU(),
    torch.nn.Linear(l2, l3),
    torch.nn.ReLU(),
    torch.nn.Linear(l3, l4)
)
loss_fn = torch.nn.MSELoss()
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

gamma = 0.9
epsilon = 1.0

action_set = {
    0: 'u',
    1: 'd',
    2: 'l',
    3: 'r',
}

epochs = 1000
losses = []  # 이후 추세 그래프를 그리기 위한 list
for i in range(epochs):  # main training loop
    game = Gridworld(size=4, mode='static')  # 각 training 마다 새 게임을 시작
    # 게임 intance를 생성하고, 학습에 사용하기 쉽도록 eshape(1, 64) -> 약간의 noise 추가 (ReLU의 dead neuron 방지)
    state1_ = game.board.render_np().reshape(1, 64) + np.random.rand(1, 64) / 10.0
    state1 = torch.from_numpy(state1_).float()
    status = 1  # 아직 게임이 진행 중 인지를 나타냄

    # Play game in here
    while (status == 1):
        #  Q NN을 실행 (순전파)
        qval = model(state1)  # 하나의 state에 대해 가능한 action들에 대한 reward(Q value)를 모두 예측
        qval_ = qval.data.numpy()

        # Epsilon greedy
        if (random.random() < epsilon):
            action_ = np.random.randint(0, 4)
        else:
            action_ = np.argmax(qval_)

        # 선택한 action 수행
        action = action_set[action_]
        game.makeMove(action)  # 선택한 action을 수행

        # action을 수행 하고, 게임의 새 state와 reward를 얻
        state2_ = game.board.render_np().reshape(1, 64) + np.random.rand(1, 64) / 10.0  # 선택한 action에 따른 새 state를 얻음
        state2 = torch.from_numpy(state2_).float()
        reward = game.reward()

        with torch.no_grad():  # 원래는 자동으로 계산 그래프를 만드는데, 굳이 필요 없으면 안 만들게 함
            newQ = model(state2.reshape(1, 64))  # 새 state에 대한 q value를 예측 (dim : 4)
        # 현재까지 최고의 Q value
        maxQ = torch.max(newQ)  # 새로 예측한 Q value

        # Q-learning target 계산
        if reward == -1:
            Y = reward + (gamma * maxQ)
        else:
            Y = reward

        # NN model training
        # *detach : 계산 그래프에서 노드를 떼어냄
        # -> If not -> Y에 훈련 가능한 매개 변수들과 개별적인 계산 그래프들이 연관 됨 -> model training 시 X와 Y둘 다에 대한 역전파가 일어남
        # -> 우리가 원하는 것은 X만 갱신하는 것
        Y = torch.Tensor([Y]).detach()  # Q-learning target
        X = qval.squeeze()[action_]  # qval : action 수행 전에 예측한 q value임
        loss = loss_fn(X, Y)  # Q-learning target과 qval의 loss를 계산
        # -> state에 따라 계산되는 qval을 NN으로 학습하여 더 정확한 예측이 가능하게 하려는 것
        print(i, loss.item())
        clear_output(wait=True)
        optimizer.zero_grad()
        loss.backward()
        losses.append(loss.item())
        optimizer.step()
        state1 = state2
        if reward != -1:  # reward == -1 이면 게임이 끝난 것
            status = 0
    if epsilon > 0.1:  # 매 episode 마다 epsilon 감소
        epsilon -= (1 / epochs)

plt.figure(figsize=(10,7))
plt.plot(losses)
plt.xlabel("Epochs",fontsize=22)
plt.ylabel("Loss",fontsize=22)

plt.show()

def test_model(model, mode='static', display=True):
    i = 0
    test_game = Gridworld(mode=mode)
    state_ = test_game.board.render_np().reshape(1, 64) + np.random.rand(1, 64) / 10.0
    state = torch.from_numpy(state_).float()
    if display:
        print("Initial State:")
        print(test_game.display())
    status = 1
    while (status == 1):  # A
        qval = model(state)
        qval_ = qval.data.numpy()
        action_ = np.argmax(qval_)  # B
        action = action_set[action_]
        if display:
            print('Move #: %s; Taking action: %s' % (i, action))
        test_game.makeMove(action)
        state_ = test_game.board.render_np().reshape(1, 64) + np.random.rand(1, 64) / 10.0
        state = torch.from_numpy(state_).float()
        if display:
            print(test_game.display())
        reward = test_game.reward()
        if reward != -1:
            if reward > 0:
                status = 2
                if display:
                    print("Game won! Reward: %s" % (reward,))
            else:
                status = 0
                if display:
                    print("Game LOST. Reward: %s" % (reward,))
        i += 1
        if (i > 15):
            if display:
                print("Game lost; too many moves.")
            break

    win = True if status == 2 else False
    return win


test_model(model, 'static')