debug.py
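"""Debug entry point: builds a scaled-down MuZero configuration for
CartPole-v1 and launches self-play and training through Ray."""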
from muzero.config import MuZeroConfig
from muzero.main import Muzero
from muzero.common import KnownBounds
import ray


def make_atari_config() -> MuZeroConfig:
    # Temperature schedule for action selection: explore broadly early on,
    # then act more greedily as training progresses.
    def visit_softmax_temperature(num_moves, training_steps):
        if training_steps < 5e3:
            return 1.0
        elif training_steps < 1e4:
            return 0.5
        else:
            return 0.25
    return MuZeroConfig(
        gym_env_name='CartPole-v1',
        action_space_size=2,
        value_support_size=20,
        reward_support_size=20,   # Keep this fairly low since we don't need granularity here
        selfplay_iterations=1000, # TODO: implement None for continuous play
        max_moves=500,
        discount=0.997,
        use_TD_values=True,
        dirichlet_alpha=0.25,
        num_simulations=50,       # >20 usually works best
        batch_size=128,           # 1024
        td_steps=25,              # 10
        num_actors=1,             # 350
        lr_init=0.05,             # 0.05
        lr_decay_steps=350e3,
        checkpoint_interval=10,
        visit_softmax_temperature_fn=visit_softmax_temperature,
        # known_bounds=KnownBounds(min=0, max=500),
        num_train_gpus=0)

# Start Ray and launch MuZero training with the debug configuration;
# local_mode=True would run Ray tasks serially for easier debugging.
ray.init(local_mode=False)
config = make_atari_config()
mz = Muzero(config)
mz.run()