-
Notifications
You must be signed in to change notification settings - Fork 69
/
train_pixels.py
146 lines (120 loc) · 5.04 KB
/
train_pixels.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
import random
import numpy as np
import tqdm
from absl import app, flags
from ml_collections import config_flags
from tensorboardX import SummaryWriter
from jaxrl.agents import DrQLearner
from jaxrl.datasets import ReplayBuffer
from jaxrl.evaluation import evaluate
from jaxrl.utils import make_env
FLAGS = flags.FLAGS

flags.DEFINE_string('env_name', 'cheetah-run', 'Environment name.')
flags.DEFINE_string(
    'save_dir', './tmp/',
    'Output dir for TensorBoard logs, videos and evaluation returns.')
flags.DEFINE_integer('seed', 42, 'Random seed.')
flags.DEFINE_integer('eval_episodes', 10,
                     'Number of episodes used for evaluation.')
# NOTE: the training loop counts agent steps (environment steps divided by
# the action repeat); the two intervals below and `start_training` are
# compared against that counter.
flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 5000, 'Eval interval.')
flags.DEFINE_integer('batch_size', 512, 'Mini batch size.')
flags.DEFINE_integer('max_steps', int(5e5), 'Number of environment steps.')
flags.DEFINE_integer(
    'start_training', int(1e3),
    'Number of agent steps (environment steps divided by action_repeat) '
    'before training starts.')
flags.DEFINE_integer(
    'action_repeat', None,
    'Action repeat, if None, uses 2 or PlaNet default values.')
flags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.')
flags.DEFINE_boolean('save_video', False, 'Save videos during evaluation.')
config_flags.DEFINE_config_file(
    'config',
    'configs/drq_default.py',
    'File path to the training hyperparameter configuration.',
    lock_config=False)

# Per-task action repeat values taken from PlaNet. They are used when
# --action_repeat is not given; tasks not listed here fall back to 2.
PLANET_ACTION_REPEAT = {
    'cartpole-swingup': 8,
    'reacher-easy': 4,
    'cheetah-run': 4,
    'finger-spin': 2,
    'ball_in_cup-catch': 4,
    'walker-walk': 2
}
def main(_):
    """Train a DrQ agent from pixel observations, evaluating periodically.

    All settings come from ``FLAGS``. TensorBoard summaries are written to
    ``<save_dir>/tb/<seed>``. When ``--save_video`` is set, the folders
    ``<save_dir>/video/{train,eval}`` are passed to ``make_env``. After every
    evaluation, the accumulated ``(timesteps, return)`` rows are rewritten to
    ``<save_dir>/<seed>.txt``.

    The loop counter ``i`` runs up to ``max_steps // action_repeat``, so it
    counts agent steps, not environment steps. ``start_training``,
    ``log_interval`` and ``eval_interval`` are all compared against ``i``.

    Args:
        _: Positional command-line arguments forwarded by ``app.run``
            (unused).
    """
    # np.savetxt below writes straight into save_dir; make sure it exists.
    os.makedirs(FLAGS.save_dir, exist_ok=True)
    summary_writer = SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', str(FLAGS.seed)))
    try:
        if FLAGS.save_video:
            video_train_folder = os.path.join(FLAGS.save_dir, 'video',
                                              'train')
            video_eval_folder = os.path.join(FLAGS.save_dir, 'video', 'eval')
        else:
            video_train_folder = None
            video_eval_folder = None

        if FLAGS.action_repeat is not None:
            action_repeat = FLAGS.action_repeat
        else:
            # PlaNet's per-task value, or 2 for tasks not in the table.
            action_repeat = PLANET_ACTION_REPEAT.get(FLAGS.env_name, 2)

        # Copy the config so environment-only keys can be popped off; the
        # remaining entries are forwarded to DrQLearner as keyword arguments.
        kwargs = dict(FLAGS.config)
        gray_scale = kwargs.pop('gray_scale')
        image_size = kwargs.pop('image_size')

        def make_pixel_env(seed, video_folder):
            """Build a pixel-observation env with this run's settings."""
            return make_env(FLAGS.env_name,
                            seed,
                            video_folder,
                            action_repeat=action_repeat,
                            image_size=image_size,
                            frame_stack=3,
                            from_pixels=True,
                            gray_scale=gray_scale)

        env = make_pixel_env(FLAGS.seed, video_train_folder)
        # Offset seed so the eval env is seeded differently from training.
        eval_env = make_pixel_env(FLAGS.seed + 42, video_eval_folder)

        np.random.seed(FLAGS.seed)
        random.seed(FLAGS.seed)

        # Drop the keys that must not be forwarded to DrQLearner.
        kwargs.pop('algo')
        replay_buffer_size = kwargs.pop('replay_buffer_size')
        agent = DrQLearner(FLAGS.seed,
                           env.observation_space.sample()[np.newaxis],
                           env.action_space.sample()[np.newaxis], **kwargs)

        # A falsy replay_buffer_size falls back to one slot per agent step.
        replay_buffer = ReplayBuffer(
            env.observation_space, env.action_space, replay_buffer_size
            or FLAGS.max_steps // action_repeat)

        eval_returns = []
        observation, done = env.reset(), False
        for i in tqdm.tqdm(range(1, FLAGS.max_steps // action_repeat + 1),
                           smoothing=0.1,
                           disable=not FLAGS.tqdm):
            if i < FLAGS.start_training:
                # Warm-up phase: sample actions uniformly from the space.
                action = env.action_space.sample()
            else:
                action = agent.sample_actions(observation)
            next_observation, reward, done, info = env.step(action)

            # Only a non-truncated terminal transition gets mask 0; a
            # time-limit truncation keeps mask 1.
            if not done or 'TimeLimit.truncated' in info:
                mask = 1.0
            else:
                mask = 0.0

            replay_buffer.insert(observation, action, reward, mask,
                                 float(done), next_observation)
            observation = next_observation

            if done:
                observation, done = env.reset(), False
                for k, v in info['episode'].items():
                    summary_writer.add_scalar(f'training/{k}', v,
                                              info['total']['timesteps'])

            if i >= FLAGS.start_training:
                batch = replay_buffer.sample(FLAGS.batch_size)
                update_info = agent.update(batch)

                # Nested here because update_info only exists once updates
                # have started.
                if i % FLAGS.log_interval == 0:
                    for k, v in update_info.items():
                        summary_writer.add_scalar(f'training/{k}', v, i)
                    summary_writer.flush()

            if i % FLAGS.eval_interval == 0:
                eval_stats = evaluate(agent, eval_env, FLAGS.eval_episodes)

                for k, v in eval_stats.items():
                    summary_writer.add_scalar(f'evaluation/average_{k}s', v,
                                              info['total']['timesteps'])
                summary_writer.flush()

                eval_returns.append(
                    (info['total']['timesteps'], eval_stats['return']))
                # Rewrite the whole history each time so the file is always
                # complete up to the latest evaluation.
                np.savetxt(os.path.join(FLAGS.save_dir, f'{FLAGS.seed}.txt'),
                           eval_returns,
                           fmt=['%d', '%.1f'])
    finally:
        # The loop flushes only at log/eval intervals; closing the writer
        # persists any pending events, including when training raises.
        summary_writer.close()
if __name__ == '__main__':
    # Script entry point: absl's app.run parses the command-line flags and
    # then calls main with the remaining argv.
    app.run(main)