-
Notifications
You must be signed in to change notification settings - Fork 2
/
train_PPO_trXL_pmap.py
97 lines (67 loc) · 2.01 KB
/
train_PPO_trXL_pmap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import time
from trainer_PPO_trXL_pmap import make_train,ActorCriticTransformer
import jax
import os
import jax.numpy as jnp
from flax.jax_utils import replicate, unreplicate
# Run configuration. The dict is passed as-is to make_train() and is also
# written to disk next to the results at the end of the script. The precise
# meaning of every key is defined by trainer_PPO_trXL_pmap, not here.
# NOTE(review): the grouping comments below reflect the apparent purpose of
# the keys judging by their names only -- confirm against the trainer module.
config = {
    # Presumably PPO / optimiser settings.
    "LR": 2e-4,
    "NUM_ENVS": 1024,
    "NUM_STEPS": 128,
    "TOTAL_TIMESTEPS": 1e9,  # float literal, kept exactly as in the original
    "UPDATE_EPOCHS": 4,
    "NUM_MINIBATCHES": 8,
    "GAMMA": 0.999,
    "GAE_LAMBDA": 0.8,
    "CLIP_EPS": 0.2,
    "ENT_COEF": 0.002,
    "VF_COEF": 0.5,
    "MAX_GRAD_NORM": 1.0,
    "ACTIVATION": "relu",
    # Also used below to build the output directory name.
    "ENV_NAME": "craftax",
    "ANNEAL_LR": True,
    # Presumably transformer architecture / memory settings.
    "qkv_features": 256,
    "EMBED_SIZE": 256,
    "num_heads": 8,
    "num_layers": 2,
    "hidden_layers": 256,
    "WINDOW_MEM": 128,
    "WINDOW_GRAD": 64,
    "gating": True,
    "gating_bias": 2.0,
    # Read further down as the PRNG seed and as part of the output file names.
    "seed": 0,
}
# Seed of this run. When launched as a SLURM array job the task id can be
# used instead:
#   seed = int(os.environ["SLURM_ARRAY_TASK_ID"])
seed = config["seed"]

# Directory that receives every artefact written at the end of the script
# (return plot, parameters, config and metrics).
prefix = "results_craftax/" + config["ENV_NAME"]

# Create the output directory (including missing parents).
# exist_ok=True makes the call idempotent and removes the check-then-create
# race that occurs when several runs start at the same time.
# Only OSError is caught: the failure stays non-fatal, as in the original
# script, but KeyboardInterrupt and programming errors are no longer
# swallowed, and the reason for the failure is reported.
try:
    os.makedirs(prefix, exist_ok=True)
except OSError as err:
    print("directory creation " + prefix + " failed (" + str(err) + ")")
# ---------------------------------------------------------------------------
# Ahead-of-time compilation of the training function
# ---------------------------------------------------------------------------
print("Start compiling")
compile_start = time.time()

# Derive two keys from the seed: the second one goes to make_train(), the
# first one is later expanded into one key per local device.
root_key = jax.random.PRNGKey(seed)
rng, init_key = jax.random.split(root_key)
train_fn, train_state = make_train(config, init_key)

local_devices = jax.local_devices()
print(local_devices)

# One copy of the train state and one PRNG key per local device.
train_states = replicate(train_state, local_devices)
rng = jax.random.split(rng, len(local_devices))

# Lower and compile explicitly so that compilation time is measured separately
# from the run itself.
# NOTE(review): train_fn is presumably a pmap-wrapped function -- confirm in
# trainer_PPO_trXL_pmap.
compiled_train = train_fn.lower(rng, train_states).compile()
print(f"compilation took {time.time() - compile_start}")

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
print("Start training")
train_start = time.time()
out = compiled_train(rng, train_states)
# Block on the returned metric so that the elapsed time below is measured
# after the result is actually available, not just after the call returned.
returns = out["metrics"]["returned_episode_returns"].block_until_ready()
print(f"training took {time.time() - train_start}")

# Collapse the per-device outputs to a single copy. The manual equivalent
# would be: jax.tree_util.tree_map(lambda x: x[0], out)
out = unreplicate(out)
import matplotlib.pyplot as plt

# Plot the episode returns reported by training and store the figure in the
# output directory as "return_<seed>".
episode_returns = out["metrics"]["returned_episode_returns"]
plt.plot(episode_returns)
plt.xlabel("Updates")
plt.ylabel("Return")
plt.savefig(f"{prefix}/return_{seed}")
plt.clf()

# Persist parameters, config and metrics as "<seed>_params", "<seed>_config"
# and "<seed>_metrics" (in that order).
# NOTE(review): these payloads look like nested containers rather than plain
# arrays, so they are presumably stored as pickled object arrays and would
# need allow_pickle=True when loaded -- confirm.
artefacts = (
    ("_params", out["runner_state"][0].params),
    ("_config", config),
    ("_metrics", out["metrics"]),
)
for suffix, payload in artefacts:
    jnp.save(f"{prefix}/{seed}{suffix}", payload)