Á¤¼ºÈÆ
    9Àå) ActorCritic (ch9_ActorCritic.py)
RL_SERHO_chap9_ActorCritic_1.png [34 KB]   RL_SERHO_chap9_ActorCritic_2.png [41 KB]  
  https://github.com/seungeunrho/RLfrombasics/blob/master/ch9_ActorCritic.py





1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
 
#Hyperparameters
learning_rate = 0.0002
gamma         = 0.98
n_rollout     = 10
 
class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.data = []
        
        self.fc1 = nn.Linear(4,256)
        self.fc_pi = nn.Linear(256,2)
        self.fc_v = nn.Linear(256,1)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        
    def pi(self, x, softmax_dim = 0):
        x = F.relu(self.fc1(x))
        x = self.fc_pi(x)
        prob = F.softmax(x, dim=softmax_dim)
        return prob
    
    def v(self, x):
        x = F.relu(self.fc1(x))
        v = self.fc_v(x)
        return v
    
    def put_data(self, transition):
        self.data.append(transition)
        
    def make_batch(self):
        s_lst, a_lst, r_lst, s_prime_lst, done_lst = [], [], [], [], []
        for transition in self.data:
            s,a,r,s_prime,done = transition
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r/100.0])
            s_prime_lst.append(s_prime)
            done_mask = 0.0 if done else 1.0
            done_lst.append([done_mask])
        
        s_batch, a_batch, r_batch, s_prime_batch, done_batch = torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), 
                                                               torch.tensor(r_lst, dtype=torch.float), torch.tensor(s_prime_lst, dtype=torch.float), 
                                                               torch.tensor(done_lst, dtype=torch.float)
        self.data = []
        return s_batch, a_batch, r_batch, s_prime_batch, done_batch
  
    def train_net(self):
        s, a, r, s_prime, done = self.make_batch()
        td_target = r + gamma * self.v(s_prime) * done
        delta = td_target - self.v(s)
        
        pi = self.pi(s, softmax_dim=1)
        pi_a = pi.gather(1,a)
        loss = -torch.log(pi_a) * delta.detach() + F.smooth_l1_loss(self.v(s), td_target.detach())
 
        self.optimizer.zero_grad()
        loss.mean().backward()
        self.optimizer.step()         
      
def main():  
    env = gym.make('CartPole-v1')
    model = ActorCritic()    
    print_interval = 20
    score = 0.0
 
    for n_epi in range(10000):
        done = False
        s = env.reset()
        while not done:
            for t in range(n_rollout):
                prob = model.pi(torch.from_numpy(s).float())
                m = Categorical(prob)
                a = m.sample().item()
                s_prime, r, done, info = env.step(a)
                model.put_data((s,a,r,s_prime,done))
                
                s = s_prime
                score += r
                
                if done:
                    break                     
            
            model.train_net()
            
        if n_epi%print_interval==0 and n_epi!=0:
            print("# of episode :{}, avg score : {:.1f}".format(n_epi, score/print_interval))
            score = 0.0
    env.close()
 
if __name__ == '__main__':
    main()
cs

  µî·ÏÀÏ : 2021-08-16 [20:36] Á¶È¸ : 436 ´Ù¿î : 258   
 
¡â ÀÌÀü±ÛÆÄÀ̽ã°ú Äɶ󽺷Π¹è¿ì´Â °­È­ÇнÀÀÌ 5Àå) ÅÙ¼­Ç÷Π2.0°ú Äɶó½º
¡ä ´ÙÀ½±Û9Àå) Advantage ActorCritic ½Ç½À (ÆÄÀ̽ã°ú Äɶ󽺷Π¹è¿ì´Â °­È­ÇнÀ 6Àå A2C)
°­È­ÇнÀ ÀÌ·Ð ¹× ½Ç½À(MD) ½Ç½À
¹øÈ£ ¨Ï Á¦ ¸ñ À̸§
l¹Ù´ÚºÎÅÍ ¹è¿ì´Â °­È­ ÇнÀ ÄÚµå (github)
l°­È­ÇнÀ/½ÉÃþ°­È­ÇнÀ Ư°­ (github)
lÆÄÀ̽ã°ú Äɶ󽺷Π¹è¿ì´Â °­È­ÇнÀ (github)
25 lÆÄÀ̽ã°ú Äɶ󽺷Π¹è¿ì´Â °­È­ÇнÀ (github) Á¤¼ºÈÆ
24 ¦¦❶ 7Àå) ¾ÆŸ¸® ºê·¹ÀÌÅ© ¾Æ¿ô (A3C) Á¤¼ºÈÆ
23 ¦¦❶ 7Àå) ¾ÆŸ¸® ºê·¹ÀÌÅ© ¾Æ¿ô (DQN) Á¤¼ºÈÆ
22 l°­È­ÇнÀ/½ÉÃþ°­È­ÇнÀ Ư°­ (github) Á¤¼ºÈÆ
21 ¦¦❶ 13Àå) ½º³×ÀÌÅ© °ÔÀÓ ¸¶½ºÅÍ µÇ±â Á¤¼ºÈÆ
20 ¦¦❶ 10Àå) ÀÚÀ²ÁÖÇàÂ÷¸¦ À§ÇÑ AI Á¤¼ºÈÆ
19 l¹Ù´ÚºÎÅÍ ¹è¿ì´Â °­È­ ÇнÀ ÄÚµå (github) Á¤¼ºÈÆ
18 ¦¦❶ ÆÄÀ̽ã°ú Äɶ󽺷Π¹è¿ì´Â °­È­ÇнÀÀÌ 5Àå) ÅÙ¼­Ç÷Π2.0°ú ÄÉ¶ó½º Á¤¼ºÈÆ
17    ¦¦❷ ÆÄÀ̽ã°ú Äɶ󽺷Π¹è¿ì´Â °­È­ÇнÀÀÌ 5Àå) ÅÙ¼­Ç÷Π2.0°ú ÄÉ¶ó½º Á¤¼ºÈÆ
16 ¦¦❶ l9Àå) ActorCritic (ch9_ActorCritic.py) Á¤¼ºÈÆ
15    ¦¦❷ 9Àå) Advantage ActorCritic ½Ç½À (ÆÄÀ̽ã°ú Äɶ󽺷Π¹è¿ì´Â °­È­ÇнÀ 6Àå A2C) Á¤¼ºÈÆ
14       ¦¦❸ 9Àå) ¿¬¼ÓÀû ¾×ÅÍ-Å©¸®Æ½ ½Ç½À (ÆÄÀ̽ã°ú Äɶ󽺷Π¹è¿ì´Â °­È­ÇнÀ 6Àå) Á¤¼ºÈÆ
13          ¦¦❹ ¿¬¼ÓÀû ¾×ÅÍ-Å©¸®Æ½ ½ÇÇà ȯ°æ ¹× °á°ú Á¤¼ºÈÆ
12 ¦¦❶ l9Àå) REINFORCE (ch9_REINFORCE.py) Á¤¼ºÈÆ
11 ¦¦❶ l8Àå) DQN (ch8_DQN.py) Á¤¼ºÈÆ

[1][2]