-
Notifications
You must be signed in to change notification settings - Fork 763
/
config.py
68 lines (53 loc) · 1.23 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
class AgentConfig(object):
    """Agent / training hyper-parameters, mixed into DQNConfig.

    Several step counts and ``memory_size`` are written as multiples of
    ``scale`` so the overall schedule can be resized from one place.
    """
    # Base multiplier for the scaled step counts and sizes below.
    scale = 10000
    display = False

    max_step = scale * 5000
    memory_size = scale * 100

    batch_size = 32
    random_start = 30
    cnn_format = 'NCHW'
    discount = 0.99
    target_q_update_step = scale

    learning_rate = 2.5e-4
    # NOTE(review): identical to learning_rate; if the decayed rate is floored
    # at this minimum (as the name suggests), it never actually decreases.
    learning_rate_minimum = 2.5e-4
    learning_rate_decay = 0.96
    learning_rate_decay_step = scale * 5

    # Presumably an epsilon schedule going from ep_start to ep_end over
    # ep_end_t steps -- confirm against the agent code.
    ep_end = 0.1
    ep_start = 1.0
    ep_end_t = memory_size

    history_length = 4
    train_frequency = 4
    # Float (50000.0), unlike the other step counts; kept as originally defined.
    learn_start = scale * 5.0

    min_delta, max_delta = -1, 1
    double_q = dueling = False

    _test_step = scale * 5
    _save_step = 10 * _test_step
class EnvironmentConfig(object):
    """Environment settings, mixed into DQNConfig."""
    env_name = 'Breakout-v0'

    # Frame size in pixels (width, then height).
    screen_width = screen_height = 84

    # Upper and lower reward bounds.
    max_reward, min_reward = 1.0, -1.0
class DQNConfig(AgentConfig, EnvironmentConfig):
    """Full DQN configuration: AgentConfig and EnvironmentConfig combined."""
    # Model-variant name; empty in this shared base class.
    model = ''
class M1(DQNConfig):
    """Config variant that get_config() returns when the model flag is 'm1'."""
    backend = 'tf'  # presumably selects the TensorFlow backend -- confirm
    env_type = 'detail'
    action_repeat = 1  # NOTE(review): presumably frames per chosen action; confirm
def _flag_values(flags):
    """Return the parsed command-line flags as a plain ``{name: value}`` dict."""
    # absl-based FLAGS (what newer TensorFlow flag modules wrap) expose
    # flag_values_dict(); there the private '__flags' entry maps names to Flag
    # objects rather than values, so reading it would copy Flag objects.
    if hasattr(flags, 'flag_values_dict'):
        return flags.flag_values_dict()
    # Older FLAGS implementations stored the raw values under '__flags'.
    return flags.__dict__['__flags']


def get_config(FLAGS):
    """Select the config class named by ``FLAGS.model`` and apply flag overrides.

    Args:
        FLAGS: parsed command-line flags; must provide ``model`` ('m1' or 'm2').

    Returns:
        The selected config *class* itself (not an instance). Every flag whose
        name matches an existing attribute of that class overwrites it, and the
        ``gpu`` flag additionally sets ``cnn_format``.

    Raises:
        ValueError: if ``FLAGS.model`` is not a known model name, or the class
            it maps to is not defined in this module.
    """
    # Resolve the class by name so a missing class (e.g. 'm2' while no M2 is
    # defined) fails with a clear error instead of NameError/UnboundLocalError.
    class_names = {'m1': 'M1', 'm2': 'M2'}
    class_name = class_names.get(FLAGS.model)
    config = globals().get(class_name) if class_name is not None else None
    if config is None:
        raise ValueError(
            'Unknown or undefined model %r; expected one of %s with a matching '
            'config class in this module' % (FLAGS.model, sorted(class_names)))

    for key, value in _flag_values(FLAGS).items():
        if key == 'gpu':
            # Channels-first layout with a GPU, channels-last otherwise.
            config.cnn_format = 'NCHW' if value else 'NHWC'
        # NOTE: this mutates the class object, so overrides persist across calls.
        if hasattr(config, key):
            setattr(config, key, value)
    return config