#!/usr/bin/python3
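"""Train a GAIL agent on CartPole-v0.

A PPO policy is rewarded by a discriminator that learns to distinguish expert
state-action pairs (read from trajectory/observations.csv and
trajectory/actions.csv) from the agent's own state-action pairs; the
environment reward is logged for monitoring but never used to update the
policy.

Usage:
    python3 run_gail.py [--logdir DIR] [--savedir DIR] [--gamma GAMMA] [--iteration N]
"""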
import argparse
import gym
import numpy as np
import tensorflow as tf
from network_models.policy_net import Policy_net
from network_models.discriminator import Discriminator
from algo.ppo import PPOTrain


def argparser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--logdir', help='log directory', default='log/train/gail')
    parser.add_argument('--savedir', help='save directory', default='trained_models/gail')
    parser.add_argument('--gamma', help='discount factor', type=float, default=0.95)
    parser.add_argument('--iteration', help='number of training iterations', type=int, default=int(1e4))
    return parser.parse_args()


def main(args):
    env = gym.make('CartPole-v0')
    env.seed(0)
    ob_space = env.observation_space
    Policy = Policy_net('policy', env)
    Old_Policy = Policy_net('old_policy', env)
    PPO = PPOTrain(Policy, Old_Policy, gamma=args.gamma)
    D = Discriminator(env)

    # expert demonstrations (state-action pairs) loaded from CSV
    expert_observations = np.genfromtxt('trajectory/observations.csv')
    expert_actions = np.genfromtxt('trajectory/actions.csv', dtype=np.int32)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        writer = tf.summary.FileWriter(args.logdir, sess.graph)
        sess.run(tf.global_variables_initializer())

        obs = env.reset()
        success_num = 0

        for iteration in range(args.iteration):
            observations = []
            actions = []
            # environment rewards are NOT used to update the policy
            # (only for logging and the stopping criterion)
            rewards = []
            v_preds = []
            run_policy_steps = 0

            # roll out one episode with the current policy
            while True:
                run_policy_steps += 1
                obs = np.stack([obs]).astype(dtype=np.float32)  # prepare to feed placeholder Policy.obs
                act, v_pred = Policy.act(obs=obs, stochastic=True)

                act = np.asscalar(act)
                v_pred = np.asscalar(v_pred)

                next_obs, reward, done, info = env.step(act)

                observations.append(obs)
                actions.append(act)
                rewards.append(reward)
                v_preds.append(v_pred)

                if done:
                    next_obs = np.stack([next_obs]).astype(dtype=np.float32)  # prepare to feed placeholder Policy.obs
                    _, v_pred = Policy.act(obs=next_obs, stochastic=True)
                    v_preds_next = v_preds[1:] + [np.asscalar(v_pred)]
                    obs = env.reset()
                    break
                else:
                    obs = next_obs

            writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='episode_length', simple_value=run_policy_steps)]),
                               iteration)
            writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='episode_reward', simple_value=sum(rewards))]),
                               iteration)

            # CartPole-v0 is considered solved at an episode return of 195;
            # save and stop after 100 consecutive successful episodes
            if sum(rewards) >= 195:
                success_num += 1
                if success_num >= 100:
                    saver.save(sess, args.savedir + '/model.ckpt')
                    print('Clear!! Model saved.')
                    break
            else:
                success_num = 0

            # convert list to numpy array for feeding tf.placeholder
            observations = np.reshape(observations, newshape=[-1] + list(ob_space.shape))
            actions = np.array(actions).astype(dtype=np.int32)

            # train discriminator to separate expert transitions from agent transitions
            for i in range(2):
                D.train(expert_s=expert_observations,
                        expert_a=expert_actions,
                        agent_s=observations,
                        agent_a=actions)

            # the output of this discriminator is the reward signal for the policy
            d_rewards = D.get_rewards(agent_s=observations, agent_a=actions)
            d_rewards = np.reshape(d_rewards, newshape=[-1]).astype(dtype=np.float32)

            # generalized advantage estimation on the discriminator rewards
            gaes = PPO.get_gaes(rewards=d_rewards, v_preds=v_preds, v_preds_next=v_preds_next)
            gaes = np.array(gaes).astype(dtype=np.float32)
            # gaes = (gaes - gaes.mean()) / gaes.std()
            v_preds_next = np.array(v_preds_next).astype(dtype=np.float32)

            # train policy with PPO on minibatches sampled from the episode
            inp = [observations, actions, gaes, d_rewards, v_preds_next]
            PPO.assign_policy_parameters()
            for epoch in range(6):
                sample_indices = np.random.randint(low=0, high=observations.shape[0],
                                                   size=32)  # indices are in [low, high)
                sampled_inp = [np.take(a=a, indices=sample_indices, axis=0) for a in inp]  # sample training data
                PPO.train(obs=sampled_inp[0],
                          actions=sampled_inp[1],
                          gaes=sampled_inp[2],
                          rewards=sampled_inp[3],
                          v_preds_next=sampled_inp[4])

            summary = PPO.get_summary(obs=inp[0],
                                      actions=inp[1],
                                      gaes=inp[2],
                                      rewards=inp[3],
                                      v_preds_next=inp[4])

            writer.add_summary(summary, iteration)

        writer.close()


if __name__ == '__main__':
    args = argparser()
    main(args)