-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
76 lines (60 loc) · 2.33 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import warnings
warnings.filterwarnings("ignore")
from jax_sandbox.common import *
from jax_sandbox.imitation import *
from jax_sandbox.actor_critic import *
from jax_sandbox.policy_gradient import *
from jax_sandbox.value_based_methods import *
import hydra
import envs.dmc as dmc
from utils import *
# Task names routed to ``envs.dmc``; every other task name is handed to
# ``make_gym_env``. A tuple keeps the same ``in`` semantics as a list.
_DMC_TASKS = (
    'walker_walk',
    'cheetah_run',
    'humanoid_walk',
    'finger_turn_hard',
    'cartpole_swingup',
    'hopper_hop',
    'quadruped_walk',
    'reacher_hard',
)


def make_env(env_name, seed):
    """Build the environment named ``env_name``, seeded with ``seed``.

    Names listed in ``_DMC_TASKS`` are created with ``dmc.make``; any other
    name falls back to ``make_gym_env``.
    """
    if env_name not in _DMC_TASKS:
        return make_gym_env(env_name, seed)
    return dmc.make(env_name, seed=seed)
def get_observation_action_spec(env):
    """Return ``(obs_shape, action_shape)`` for a gym-style or spec-style env.

    An env with an ``observation_space`` attribute is treated as gym-style.
    Anything else is expected to expose ``observation_spec()`` and
    ``action_spec()`` accessors.

    Returns:
        obs_shape: the observation shape.
        action_shape: for gym-style envs, ``action_space.shape[0]`` when the
            action space is a ``gym.spaces.Box``, otherwise ``action_space.n``.
            For spec-style envs, ``action_spec().shape`` unchanged.
    """
    if not hasattr(env, 'observation_space'):
        # NOTE(review): this branch returns the full action shape (a tuple),
        # whereas the gym branch below returns a single int -- confirm the
        # learners accept both forms.
        return env.observation_spec().shape, env.action_spec().shape

    observation_shape = env.observation_space.shape
    action_space = env.action_space
    # Box spaces report their dimensionality through ``shape``; any other
    # space is presumably discrete and carries its action count in ``n``.
    action_dim = (
        action_space.shape[0]
        if isinstance(action_space, gym.spaces.Box)
        else action_space.n
    )
    return observation_shape, action_dim
# NOTE(review): not referenced in the visible code, and none of these names
# is handled by Workspace's learner dispatch below.
OFFLINE_ALGOS = ['cql', 'td3_bc', 'milo']


class Workspace:
    """Set up the train/eval environments and the learner chosen by ``cfg.alg``.

    ``setup`` runs before the learner is built so that the derived config
    fields (``obs_shape``, ``action_shape``, ``continuous``, ``img_input``)
    are already filled in when the learner constructor receives ``cfg``.

    Raises:
        ValueError: if ``cfg.alg`` names no supported learner.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.setup()
        print('done setting up')
        # Lazy factories: a learner's module is looked up only when it is the
        # one selected, exactly as with the original if/elif chain.
        learner_factories = {
            'bc': lambda: bc.BC(cfg),
            'gail': lambda: gail.GAIL(cfg),
            'reinforce': lambda: reinforce.REINFORCE(cfg),
            'ddpg': lambda: ddpg.DDPG(cfg),
            'sac': lambda: sac.SAC(cfg),
            'dqn': lambda: dqn.DQN(cfg),
        }
        factory = learner_factories.get(cfg.alg)
        if factory is None:
            raise ValueError('RL algorithm not implemented yet.')
        self.learner = factory()

    def setup(self):
        """Create the environments and fill in the derived config fields."""
        # NOTE(review): Hydra hands the task function a struct-mode config, in
        # which assigning undeclared keys fails -- confirm the config declares
        # obs_shape/action_shape/continuous/img_input (e.g. as `???`).
        self.train_env = make_env(self.cfg.task, self.cfg.seed)
        self.eval_env = make_env(self.cfg.task, self.cfg.seed)
        self.cfg.obs_shape, self.cfg.action_shape = get_observation_action_spec(self.train_env)
        # Bug fix: the result of ``is_discrete`` was assigned directly to the
        # ``continuous`` flag, inverting its meaning.
        self.cfg.continuous = not is_discrete(self.cfg.task)
        # Image input is inferred from a rank-3 observation shape.
        self.cfg.img_input = len(self.cfg.obs_shape) == 3
        # TODO: dataset/dataloader setup (placeholder in the original).
@hydra.main(config_path="cfgs", config_name="config")
def main(cfg):
    """Entry point: Hydra composes ``cfg`` from the ``cfgs`` directory.

    The visible code only constructs the Workspace (environments + learner);
    no training or evaluation loop is started here.
    """
    workspace = Workspace(cfg)  # noqa: F841 -- constructed for its setup side effects


if __name__ == '__main__':
    # The Hydra decorator supplies ``cfg`` itself, so main() takes no arguments here.
    main()