wrappers.py
import jax
import jax.numpy as jnp
import chex
import numpy as np
from flax import struct
from functools import partial
from typing import Optional, Tuple, Union, Any

# Taken from https://github.com/MichaelTMatthews/Craftax_Baselines/blob/main/wrappers.py


class GymnaxWrapper(object):
    """Base class for Gymnax wrappers."""

    def __init__(self, env):
        self._env = env

    # provide proxy access to regular attributes of wrapped object
    def __getattr__(self, name):
        return getattr(self._env, name)


class BatchEnvWrapper(GymnaxWrapper):
    """Batches reset and step functions."""

    def __init__(self, env, num_envs: int):
        super().__init__(env)

        self.num_envs = num_envs

        self.reset_fn = jax.vmap(self._env.reset, in_axes=(0, None))
        self.step_fn = jax.vmap(self._env.step, in_axes=(0, 0, 0, None))

    @partial(jax.jit, static_argnums=(0, 2))
    def reset(self, rng, params=None):
        rng, _rng = jax.random.split(rng)
        rngs = jax.random.split(_rng, self.num_envs)
        obs, env_state = self.reset_fn(rngs, params)
        return obs, env_state

    @partial(jax.jit, static_argnums=(0, 4))
    def step(self, rng, state, action, params=None):
        rng, _rng = jax.random.split(rng)
        rngs = jax.random.split(_rng, self.num_envs)
        obs, state, reward, done, info = self.step_fn(rngs, state, action, params)
        return obs, state, reward, done, info
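

# Illustrative usage sketch (not part of the original file): how BatchEnvWrapper
# might be driven with a gymnax environment. It assumes gymnax is installed and
# that "CartPole-v1" is a valid environment name; the function and its defaults
# are hypothetical and only show the batched reset/step calling convention.
def _example_batched_rollout(num_envs: int = 8, num_steps: int = 4):
    import gymnax  # imported locally so the module does not require gymnax

    env, env_params = gymnax.make("CartPole-v1")
    env = BatchEnvWrapper(env, num_envs=num_envs)

    rng = jax.random.PRNGKey(0)
    rng, reset_rng = jax.random.split(rng)
    # One call resets all num_envs workers; obs has a leading batch dimension.
    obs, state = env.reset(reset_rng, env_params)

    for _ in range(num_steps):
        rng, action_rng, step_rng = jax.random.split(rng, 3)
        # Sample one action per environment in the batch.
        actions = jax.vmap(env.action_space(env_params).sample)(
            jax.random.split(action_rng, num_envs)
        )
        obs, state, reward, done, info = env.step(step_rng, state, actions, env_params)

    return obs, state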


class AutoResetEnvWrapper(GymnaxWrapper):
    """Provides standard auto-reset functionality, matching the Gymnax default behaviour."""

    def __init__(self, env):
        super().__init__(env)

    @partial(jax.jit, static_argnums=(0, 2))
    def reset(self, key, params=None):
        return self._env.reset(key, params)

    @partial(jax.jit, static_argnums=(0, 4))
    def step(self, rng, state, action, params=None):
        rng, _rng = jax.random.split(rng)
        obs_st, state_st, reward, done, info = self._env.step(
            _rng, state, action, params
        )

        rng, _rng = jax.random.split(rng)
        obs_re, state_re = self._env.reset(_rng, params)

        # Auto-reset environment based on termination
        def auto_reset(done, state_re, state_st, obs_re, obs_st):
            state = jax.tree_map(
                lambda x, y: jax.lax.select(done, x, y), state_re, state_st
            )
            obs = jax.lax.select(done, obs_re, obs_st)
            return obs, state

        obs, state = auto_reset(done, state_re, state_st, obs_re, obs_st)

        return obs, state, reward, done, info
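

# Illustrative composition sketch (not part of the original file): AutoResetEnvWrapper
# is applied to the single environment before batching, so each worker resets itself
# independently whenever it terminates. The environment name is an assumption.
def _example_auto_reset_setup(num_envs: int = 8):
    import gymnax  # imported locally so the module does not require gymnax

    env, env_params = gymnax.make("CartPole-v1")
    env = AutoResetEnvWrapper(env)          # reset each env when done is True
    env = BatchEnvWrapper(env, num_envs)    # then vectorise reset/step over the batch
    return env, env_params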


class OptimisticResetVecEnvWrapper(GymnaxWrapper):
    """
    Provides efficient 'optimistic' resets.
    The wrapper also necessarily handles the batching of environment steps and resetting.
    reset_ratio: the number of environment workers per environment reset. Higher means
    more efficient but a higher chance of duplicate resets.
    """

    def __init__(self, env, num_envs: int, reset_ratio: int):
        super().__init__(env)

        self.num_envs = num_envs
        self.reset_ratio = reset_ratio
        assert (
            num_envs % reset_ratio == 0
        ), "Reset ratio must perfectly divide num envs."
        self.num_resets = self.num_envs // reset_ratio

        self.reset_fn = jax.vmap(self._env.reset, in_axes=(0, None))
        self.step_fn = jax.vmap(self._env.step, in_axes=(0, 0, 0, None))

    @partial(jax.jit, static_argnums=(0, 2))
    def reset(self, rng, params=None):
        rng, _rng = jax.random.split(rng)
        rngs = jax.random.split(_rng, self.num_envs)
        obs, env_state = self.reset_fn(rngs, params)
        return obs, env_state

    @partial(jax.jit, static_argnums=(0, 4))
    def step(self, rng, state, action, params=None):
        rng, _rng = jax.random.split(rng)
        rngs = jax.random.split(_rng, self.num_envs)
        obs_st, state_st, reward, done, info = self.step_fn(rngs, state, action, params)

        rng, _rng = jax.random.split(rng)
        rngs = jax.random.split(_rng, self.num_resets)
        obs_re, state_re = self.reset_fn(rngs, params)

        rng, _rng = jax.random.split(rng)
        reset_indexes = jnp.arange(self.num_resets).repeat(self.reset_ratio)

        being_reset = jax.random.choice(
            _rng,
            jnp.arange(self.num_envs),
            shape=(self.num_resets,),
            p=done,
            replace=False,
        )
        reset_indexes = reset_indexes.at[being_reset].set(jnp.arange(self.num_resets))

        obs_re = obs_re[reset_indexes]
        state_re = jax.tree_map(lambda x: x[reset_indexes], state_re)

        # Auto-reset environment based on termination
        def auto_reset(done, state_re, state_st, obs_re, obs_st):
            state = jax.tree_map(
                lambda x, y: jax.lax.select(done, x, y), state_re, state_st
            )
            obs = jax.lax.select(done, obs_re, obs_st)
            return state, obs

        state, obs = jax.vmap(auto_reset)(done, state_re, state_st, obs_re, obs_st)

        return obs, state, reward, done, info
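

# Illustrative usage sketch (not part of the original file): the optimistic-reset
# wrapper replaces the AutoResetEnvWrapper + BatchEnvWrapper combination and only
# computes num_envs // reset_ratio fresh reset states per step, sharing them among
# the workers that terminated. reset_ratio must divide num_envs exactly. The
# environment name and default values are assumptions.
def _example_optimistic_reset_setup(num_envs: int = 64, reset_ratio: int = 16):
    import gymnax  # imported locally so the module does not require gymnax

    env, env_params = gymnax.make("CartPole-v1")
    # One reset state can be reused by up to reset_ratio terminating workers, so
    # duplicate reset states are possible when many environments finish at once.
    env = OptimisticResetVecEnvWrapper(env, num_envs=num_envs, reset_ratio=reset_ratio)
    return env, env_params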


@struct.dataclass
class LogEnvState:
    env_state: Any
    episode_returns: float
    episode_lengths: int
    returned_episode_returns: float
    returned_episode_lengths: int
    timestep: int


class LogWrapper(GymnaxWrapper):
    """Log the episode returns and lengths."""

    def __init__(self, env):
        super().__init__(env)

    @partial(jax.jit, static_argnums=(0, 2))
    def reset(self, key: chex.PRNGKey, params=None):
        obs, env_state = self._env.reset(key, params)
        state = LogEnvState(env_state, 0.0, 0, 0.0, 0, 0)
        return obs, state

    @partial(jax.jit, static_argnums=(0, 4))
    def step(
        self,
        key: chex.PRNGKey,
        state,
        action: Union[int, float],
        params=None,
    ):
        obs, env_state, reward, done, info = self._env.step(
            key, state.env_state, action, params
        )
        new_episode_return = state.episode_returns + reward
        new_episode_length = state.episode_lengths + 1
        state = LogEnvState(
            env_state=env_state,
            episode_returns=new_episode_return * (1 - done),
            episode_lengths=new_episode_length * (1 - done),
            returned_episode_returns=state.returned_episode_returns * (1 - done)
            + new_episode_return * done,
            returned_episode_lengths=state.returned_episode_lengths * (1 - done)
            + new_episode_length * done,
            timestep=state.timestep + 1,
        )
        info["returned_episode_returns"] = state.returned_episode_returns
        info["returned_episode_lengths"] = state.returned_episode_lengths
        info["timestep"] = state.timestep
        info["returned_episode"] = done
        return obs, state, reward, done, info
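

# Illustrative usage sketch (not part of the original file): LogWrapper wraps the
# base environment so episode statistics accumulate inside LogEnvState and are
# exposed through the info dict. Assumes gymnax and the "CartPole-v1" environment;
# the function is hypothetical and only shows how the logged fields are read.
def _example_logged_step():
    import gymnax  # imported locally so the module does not require gymnax

    env, env_params = gymnax.make("CartPole-v1")
    env = LogWrapper(env)

    rng = jax.random.PRNGKey(0)
    rng, reset_rng, action_rng, step_rng = jax.random.split(rng, 4)
    obs, state = env.reset(reset_rng, env_params)

    action = env.action_space(env_params).sample(action_rng)
    obs, state, reward, done, info = env.step(step_rng, state, action, env_params)

    # Completed-episode statistics; only meaningful on steps where done is True,
    # as flagged by info["returned_episode"].
    return info["returned_episode_returns"], info["returned_episode_lengths"]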