"""One-step (TD(0)) advantage actor-critic agent for the CartPole-v1 task.

The module defines a small two-headed network (`ActorCritic`) that shares one
hidden layer between a policy head and a state-value head, plus a `main`
training loop that collects short rollouts and updates the network after each.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

# Hyperparameters
learning_rate = 0.0002  # Adam step size
gamma = 0.98            # discount factor applied to the bootstrapped value
n_rollout = 10          # maximum environment steps collected per update


class ActorCritic(nn.Module):
    """Policy and value network with a shared hidden layer.

    The network takes 4-dimensional observations and produces probabilities
    over 2 actions (`pi`) and a scalar state value (`v`). Transitions are
    buffered with `put_data` and consumed by `train_net`.
    """

    def __init__(self):
        super().__init__()
        # Rollout buffer of (s, a, r, s_prime, done) tuples.
        self.data = []

        self.fc1 = nn.Linear(4, 256)
        self.fc_pi = nn.Linear(256, 2)
        self.fc_v = nn.Linear(256, 1)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def pi(self, x, softmax_dim=0):
        """Return action probabilities for `x`.

        Args:
            x: observation tensor whose last dimension is 4.
            softmax_dim: dimension to normalise over; use 0 for a single
                observation and 1 for a batch of observations.
        """
        x = F.relu(self.fc1(x))
        x = self.fc_pi(x)
        return F.softmax(x, dim=softmax_dim)

    def v(self, x):
        """Return the estimated state value for `x` (last dimension 1)."""
        x = F.relu(self.fc1(x))
        return self.fc_v(x)

    def put_data(self, transition):
        """Append one `(s, a, r, s_prime, done)` transition to the buffer."""
        self.data.append(transition)

    def make_batch(self):
        """Convert the buffered transitions to tensors and clear the buffer.

        Returns:
            A tuple `(s, a, r, s_prime, done_mask)`. Rewards are divided by
            100, and `done_mask` is 0.0 where the transition ended the episode
            and 1.0 otherwise, so it can multiply the bootstrapped value.
        """
        s_lst, a_lst, r_lst, s_prime_lst, done_lst = [], [], [], [], []
        for s, a, r, s_prime, done in self.data:
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r / 100.0])
            s_prime_lst.append(s_prime)
            done_lst.append([0.0 if done else 1.0])

        # Go through a single NumPy array: building a tensor straight from a
        # list of arrays is the slow path that PyTorch warns about.
        s_batch = torch.from_numpy(np.asarray(s_lst, dtype=np.float32))
        s_prime_batch = torch.from_numpy(
            np.asarray(s_prime_lst, dtype=np.float32))
        a_batch = torch.tensor(a_lst)
        r_batch = torch.tensor(r_lst, dtype=torch.float)
        done_batch = torch.tensor(done_lst, dtype=torch.float)

        self.data = []
        return s_batch, a_batch, r_batch, s_prime_batch, done_batch

    def train_net(self):
        """Run one gradient step on the buffered transitions.

        The loss is the policy-gradient term weighted by the detached TD
        error plus a smooth-L1 value loss against the detached TD target.
        Does nothing when the buffer is empty.
        """
        if not self.data:
            return

        s, a, r, s_prime, done = self.make_batch()

        # One critic pass over `s` serves both the advantage and the value
        # loss; the advantage is detached, so gradients are unaffected.
        v_s = self.v(s)
        td_target = r + gamma * self.v(s_prime) * done
        delta = td_target - v_s

        pi = self.pi(s, softmax_dim=1)
        pi_a = pi.gather(1, a)
        loss = (-torch.log(pi_a) * delta.detach()
                + F.smooth_l1_loss(v_s, td_target.detach()))

        self.optimizer.zero_grad()
        loss.mean().backward()
        self.optimizer.step()


def _reset_env(env):
    """Reset `env` and return only the initial observation.

    Accepts both conventions: older Gym returns the observation alone, while
    Gym >= 0.26 and Gymnasium return an `(observation, info)` pair.
    """
    out = env.reset()
    if isinstance(out, tuple) and len(out) == 2 and isinstance(out[1], dict):
        return out[0]
    return out


def _step_env(env, action):
    """Step `env` and return `(s_prime, r, done, info)` for either Gym API.

    The 5-tuple form `(obs, reward, terminated, truncated, info)` is folded
    into a single `done` flag; the 4-tuple form is passed through.
    """
    out = env.step(action)
    if len(out) == 5:
        s_prime, r, terminated, truncated, info = out
        return s_prime, r, bool(terminated or truncated), info
    s_prime, r, done, info = out
    return s_prime, r, done, info


def main(n_episodes=10000, print_interval=20):
    """Train an `ActorCritic` agent on CartPole-v1 and print average scores.

    Args:
        n_episodes: number of episodes to run.
        print_interval: report (and reset) the running score every this many
            episodes.
    """
    # Imported here rather than at module top so that ActorCritic can be
    # imported without the environment package installed.
    import gym

    env = gym.make('CartPole-v1')
    model = ActorCritic()
    score = 0.0

    try:
        for n_epi in range(n_episodes):
            done = False
            s = _reset_env(env)
            while not done:
                # Collect up to n_rollout steps, then update the network.
                for _ in range(n_rollout):
                    prob = model.pi(torch.from_numpy(s).float())
                    a = Categorical(prob).sample().item()
                    s_prime, r, done, info = _step_env(env, a)
                    model.put_data((s, a, r, s_prime, done))

                    s = s_prime
                    score += r

                    if done:
                        break

                model.train_net()

            # NOTE: the first report covers episodes 0..print_interval, i.e.
            # one more episode than later windows.
            if n_epi % print_interval == 0 and n_epi != 0:
                print("# of episode :{}, avg score : {:.1f}".format(
                    n_epi, score / print_interval))
                score = 0.0
    finally:
        # Release the environment even if training is interrupted.
        env.close()


if __name__ == '__main__':
    main()