Á¤¼ºÈÆ
    8Àå) DQN (ch8_DQN.py)
RL_SERHO_chap8_DQN_1.png [50 KB]   RL_SERHO_chap8_DQN_2.png [45 KB]  
  https://github.com/seungeunrho/RLfrombasics/blob/master/ch8_DQN.py





1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gym
import collections
import random
 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
#Hyperparameters
learning_rate = 0.0005
gamma         = 0.98
buffer_limit  = 50000
batch_size    = 32
 
class ReplayBuffer():
    def __init__(self):
        self.buffer = collections.deque(maxlen=buffer_limit)
    
    def put(self, transition):
        self.buffer.append(transition)
    
    def sample(self, n):
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], []
        
        for transition in mini_batch:
            s, a, r, s_prime, done_mask = transition
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            done_mask_lst.append([done_mask])
 
        return torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), 
               torch.tensor(r_lst), torch.tensor(s_prime_lst, dtype=torch.float), 
               torch.tensor(done_mask_lst)
    
    def size(self):
        return len(self.buffer)
 
class Qnet(nn.Module):
    def __init__(self):
        super(Qnet, self).__init__()
        self.fc1 = nn.Linear(4128)
        self.fc2 = nn.Linear(128128)
        self.fc3 = nn.Linear(1282)
 
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
      
    def sample_action(self, obs, epsilon):
        out = self.forward(obs)
        coin = random.random()
        if coin < epsilon:
            return random.randint(0,1)
        else : 
            return out.argmax().item()
            
def train(q, q_target, memory, optimizer):
    for i in range(10):
        s,a,r,s_prime,done_mask = memory.sample(batch_size)
 
        q_out = q(s)
        q_a = q_out.gather(1,a)
        max_q_prime = q_target(s_prime).max(1)[0].unsqueeze(1)
        target = r + gamma * max_q_prime * done_mask
        loss = F.smooth_l1_loss(q_a, target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
def main():
    env = gym.make('CartPole-v1')
    q = Qnet()
    q_target = Qnet()
    q_target.load_state_dict(q.state_dict())
    memory = ReplayBuffer()
 
    print_interval = 20
    score = 0.0  
    optimizer = optim.Adam(q.parameters(), lr=learning_rate)
 
    for n_epi in range(10000):
        epsilon = max(0.010.08 - 0.01*(n_epi/200)) #Linear annealing from 8% to 1%
        s = env.reset()
        done = False
 
        while not done:
            a = q.sample_action(torch.from_numpy(s).float(), epsilon)      
            s_prime, r, done, info = env.step(a)
            done_mask = 0.0 if done else 1.0
            memory.put((s,a,r/100.0,s_prime, done_mask))
            s = s_prime
 
            score += r
            if done:
                break
            
        if memory.size()>2000:
            train(q, q_target, memory, optimizer)
 
        if n_epi%print_interval==0 and n_epi!=0:
            q_target.load_state_dict(q.state_dict())
            print("n_episode :{}, score : {:.1f}, n_buffer : {}, eps : {:.1f}%".format(
                                                            n_epi, score/print_interval, memory.size(), epsilon*100))
            score = 0.0
    env.close()
 
if __name__ == '__main__':
    main()
cs

  µî·ÏÀÏ : 2021-08-16 [20:32] Á¶È¸ : 851 ´Ù¿î : 459   
 
¡â ÀÌÀü±Û9Àå) REINFORCE (ch9_REINFORCE.py)
¡ä ´ÙÀ½±Û8Àå) DQN ½Ç½À (ÆÄÀ̽ã°ú Äɶ󽺷Π¹è¿ì´Â °­È­ÇнÀ 6Àå DQN)
°­È­ÇнÀ ÀÌ·Ð ¹× ½Ç½À(MD) ½Ç½À
¹øÈ£ ¨Ï Á¦ ¸ñ À̸§
l¹Ù´ÚºÎÅÍ ¹è¿ì´Â °­È­ ÇнÀ ÄÚµå (github)
l°­È­ÇнÀ/½ÉÃþ°­È­ÇнÀ Ư°­ (github)
lÆÄÀ̽ã°ú Äɶ󽺷Π¹è¿ì´Â °­È­ÇнÀ (github)
25 lÆÄÀ̽ã°ú Äɶ󽺷Π¹è¿ì´Â °­È­ÇнÀ (github) Á¤¼ºÈÆ
24 ¦¦❶ 7Àå) ¾ÆŸ¸® ºê·¹ÀÌÅ© ¾Æ¿ô (A3C) Á¤¼ºÈÆ
23 ¦¦❶ 7Àå) ¾ÆŸ¸® ºê·¹ÀÌÅ© ¾Æ¿ô (DQN) Á¤¼ºÈÆ
22 l°­È­ÇнÀ/½ÉÃþ°­È­ÇнÀ Ư°­ (github) Á¤¼ºÈÆ
21 ¦¦❶ 13Àå) ½º³×ÀÌÅ© °ÔÀÓ ¸¶½ºÅÍ µÇ±â Á¤¼ºÈÆ
20 ¦¦❶ 10Àå) ÀÚÀ²ÁÖÇàÂ÷¸¦ À§ÇÑ AI Á¤¼ºÈÆ
19 l¹Ù´ÚºÎÅÍ ¹è¿ì´Â °­È­ ÇнÀ ÄÚµå (github) Á¤¼ºÈÆ
18 ¦¦❶ ÆÄÀ̽ã°ú Äɶ󽺷Π¹è¿ì´Â °­È­ÇнÀÀÌ 5Àå) ÅÙ¼­Ç÷Π2.0°ú ÄÉ¶ó½º Á¤¼ºÈÆ
17    ¦¦❷ ÆÄÀ̽ã°ú Äɶ󽺷Π¹è¿ì´Â °­È­ÇнÀÀÌ 5Àå) ÅÙ¼­Ç÷Π2.0°ú ÄÉ¶ó½º Á¤¼ºÈÆ
16 ¦¦❶ l9Àå) ActorCritic (ch9_ActorCritic.py) Á¤¼ºÈÆ
15    ¦¦❷ 9Àå) Advantage ActorCritic ½Ç½À (ÆÄÀ̽ã°ú Äɶ󽺷Π¹è¿ì´Â °­È­ÇнÀ 6Àå A2C) Á¤¼ºÈÆ
14       ¦¦❸ 9Àå) ¿¬¼ÓÀû ¾×ÅÍ-Å©¸®Æ½ ½Ç½À (ÆÄÀ̽ã°ú Äɶ󽺷Π¹è¿ì´Â °­È­ÇнÀ 6Àå) Á¤¼ºÈÆ
13          ¦¦❹ ¿¬¼ÓÀû ¾×ÅÍ-Å©¸®Æ½ ½ÇÇà ȯ°æ ¹× °á°ú Á¤¼ºÈÆ
12 ¦¦❶ l9Àå) REINFORCE (ch9_REINFORCE.py) Á¤¼ºÈÆ
11 ¦¦❶ l8Àå) DQN (ch8_DQN.py) Á¤¼ºÈÆ

[1][2]