train.py (forked from PaddlePaddle/PARL)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import gym
import numpy as np
import parl
from parl.utils import logger, ReplayMemory
from cartpole_model import CartpoleModel
from cartpole_agent import CartpoleAgent
from parl.algorithms import DQN

LEARN_FREQ = 5  # update the network every LEARN_FREQ environment steps
MEMORY_SIZE = 200000  # replay memory capacity
MEMORY_WARMUP_SIZE = 200  # transitions to collect before learning starts
BATCH_SIZE = 64  # mini-batch size sampled from the replay memory
LEARNING_RATE = 0.0005
GAMMA = 0.99  # discount factor
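

# How training proceeds below: run_train_episode collects transitions with an
# epsilon-greedy policy (agent.sample) and, once the replay memory holds more
# than MEMORY_WARMUP_SIZE transitions, performs a DQN update on a random batch
# every LEARN_FREQ steps. run_evaluate_episodes runs the greedy policy
# (agent.predict) without exploration to measure progress.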


# train an episode
def run_train_episode(agent, env, rpm):
    total_reward = 0
    obs = env.reset()
    step = 0
    while True:
        step += 1
        action = agent.sample(obs)
        next_obs, reward, done, _ = env.step(action)
        rpm.append(obs, action, reward, next_obs, done)

        # train model
        if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0):
            # s, a, r, s', done
            (batch_obs, batch_action, batch_reward, batch_next_obs,
             batch_done) = rpm.sample_batch(BATCH_SIZE)
            train_loss = agent.learn(batch_obs, batch_action, batch_reward,
                                     batch_next_obs, batch_done)

        total_reward += reward
        obs = next_obs
        if done:
            break
    return total_reward


# evaluate 5 episodes
def run_evaluate_episodes(agent, env, eval_episodes=5, render=False):
    eval_reward = []
    for i in range(eval_episodes):
        obs = env.reset()
        episode_reward = 0
        while True:
            action = agent.predict(obs)
            obs, reward, done, _ = env.step(action)
            episode_reward += reward
            if render:
                env.render()
            if done:
                break
        eval_reward.append(episode_reward)
    return np.mean(eval_reward)


def main():
    env = gym.make('CartPole-v0')
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n
    logger.info('obs_dim {}, act_dim {}'.format(obs_dim, act_dim))

    # set action_shape = 0 for discrete-action environments
    rpm = ReplayMemory(MEMORY_SIZE, obs_dim, 0)

    # build an agent
    model = CartpoleModel(obs_dim=obs_dim, act_dim=act_dim)
    alg = DQN(model, gamma=GAMMA, lr=LEARNING_RATE)
    agent = CartpoleAgent(
        alg, act_dim=act_dim, e_greed=0.1, e_greed_decrement=1e-6)

    # warm up the replay memory before learning starts
    while len(rpm) < MEMORY_WARMUP_SIZE:
        run_train_episode(agent, env, rpm)

    max_episode = 800

    # start training
    episode = 0
    while episode < max_episode:
        # train part
        for i in range(50):
            total_reward = run_train_episode(agent, env, rpm)
            episode += 1

        # test part
        eval_reward = run_evaluate_episodes(agent, env, render=False)
        logger.info('episode:{} e_greed:{} Test reward:{}'.format(
            episode, agent.e_greed, eval_reward))

    # save the parameters to ./model.ckpt
    save_path = './model.ckpt'
    agent.save(save_path)


if __name__ == '__main__':
    main()
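

# To run this example (assuming the PARL library, gym, and the sibling
# cartpole_model.py / cartpole_agent.py files from the PARL CartPole example
# are available on the Python path):
#
#     python train.py
#
# CartPole-v0 caps episodes at 200 steps with a reward of 1 per step, so an
# evaluation reward close to 200 means the agent has solved the task.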