-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathenv.py
executable file
·67 lines (53 loc) · 1.82 KB
/
env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
import numpy as np
import torch
import random
import gym
import warnings
# Process-wide side effect of importing this module: insert an "ignore" filter
# for DeprecationWarning at the front of warnings.filters.
warnings.simplefilter("ignore", category=DeprecationWarning)
def set_seed(seed):
    """Seed the global RNGs of Python, NumPy and PyTorch (CPU and all CUDA devices).

    The generators are seeded in this order, all with the same value:
    ``random``, ``np.random``, ``torch`` (CPU), ``torch.cuda`` (all devices).

    Args:
        seed: Value forwarded unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
class DefaultDictWrapper(gym.Wrapper):
    """Gym wrapper that makes ``info`` lookups forgiving and latches success.

    Every ``info`` dict returned by ``step`` is copied into a
    ``defaultdict(float)``, so reading a key the wrapped env did not provide
    yields ``0.0`` instead of raising ``KeyError``.

    ``info["success"]`` is overwritten with a float flag that stays ``1.0``
    for every later step once any step has reported a truthy ``success``;
    the flag is cleared again by ``reset``.
    """

    def __init__(self, env):
        super().__init__(env)
        self.env = env
        # True once any step since the last reset() reported success.
        self.success = False

    @property
    def unwrapped(self):
        """Delegate ``unwrapped`` to the wrapped environment."""
        return self.env.unwrapped

    def reset(self, **kwargs):
        """Clear the success latch, then forward ``reset`` to the wrapped env."""
        self.success = False
        return self.env.reset(**kwargs)

    def step(self, action):
        """Step the wrapped env and post-process its ``info`` dict.

        Expects the wrapped env's ``step`` to return a 4-tuple
        ``(obs, reward, done, info)`` and returns the same 4-tuple, with
        ``info`` replaced by a ``defaultdict(float)`` copy whose
        ``"success"`` entry holds the latched flag as a float.
        """
        obs, reward, done, raw_info = self.env.step(action)
        info = defaultdict(float, raw_info)
        if not self.success:
            # A missing "success" key reads as 0.0 via the defaultdict.
            self.success = bool(info["success"])
        info["success"] = float(self.success)
        return obs, reward, done, info
def make_env(cfg):
    """
    Make environment for experiments.

    ``cfg.task`` must have the form ``"<domain>-<rest>"``. The domain prefix
    selects the backend factory:

    * ``"mw"``      -> ``tasks.metaworld.make_metaworld_env`` (Meta-World)
    * ``"adroit"``  -> ``tasks.adroit.make_adroit_env`` (Adroit)
    * anything else -> ``tasks.dmcontrol.make_dmcontrol_env`` (DMControl)

    The selected factory is called with ``cfg`` and its result is wrapped in
    ``DefaultDictWrapper``.

    Args:
        cfg: Config object with a ``task`` string attribute. It is mutated in
            place: ``domain``, ``obs_shape``, ``action_shape`` and
            ``action_dim`` are set from the created environment.

    Returns:
        The ``DefaultDictWrapper``-wrapped environment.

    Raises:
        ValueError: If ``cfg.task`` contains no ``"-"`` separator.
    """
    domain, sep, _ = cfg.task.partition("-")
    if not sep:
        # Fail with a descriptive message instead of a bare
        # tuple-unpacking error.
        raise ValueError(
            f"cfg.task must look like '<domain>-<task>', got {cfg.task!r}"
        )

    # Each backend is imported inside its own branch, so only the selected
    # backend's module gets imported.
    if domain == "mw":  # Meta-World
        from tasks.metaworld import make_metaworld_env
        env = make_metaworld_env(cfg)
    elif domain == "adroit":  # Adroit
        from tasks.adroit import make_adroit_env
        env = make_adroit_env(cfg)
    else:  # DMControl
        from tasks.dmcontrol import make_dmcontrol_env
        env = make_dmcontrol_env(cfg)
    env = DefaultDictWrapper(env)

    cfg.domain = domain
    cfg.obs_shape = tuple(int(x) for x in env.observation_space.shape)
    cfg.action_shape = tuple(int(x) for x in env.action_space.shape)
    # Derived from the int-cast action_shape so it is a plain int as well.
    cfg.action_dim = cfg.action_shape[0]
    return env