Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dictionary Observations #243

Merged
merged 96 commits into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
96 commits
Select commit Hold shift + click to select a range
c8d0914
First commit
J-Travnik Nov 24, 2020
b7f3386
Fixing missing refs from a quick merge from master
J-Travnik Nov 25, 2020
21fecd3
Reformat
araffin Nov 26, 2020
b2a1c14
Adding DictBuffers
J-Travnik Nov 26, 2020
1cda60b
Adding DictBuffers
J-Travnik Nov 26, 2020
8a04e61
Reformat
araffin Nov 27, 2020
86b3c14
Minor reformat
araffin Nov 27, 2020
f60d439
added slow dict test. Added SACMultiInputPolicy for future. Added pri…
J-Travnik Nov 27, 2020
3d2e041
Merge branch 'feat/dict_observations' of https://github.com/J-Travnik…
J-Travnik Nov 27, 2020
da1de6e
Ran black on buffers
J-Travnik Nov 27, 2020
51249da
Ran isort
J-Travnik Nov 27, 2020
761a67f
Adding StackedObservations classes used within VecStackEnvs wrappers.…
J-Travnik Nov 30, 2020
3cb69f5
Running isort :facepalm
J-Travnik Nov 30, 2020
82fe425
Fixed typing issues
araffin Dec 1, 2020
201799d
Adding docstrings and typing. Using util for moving data to device.
J-Travnik Dec 1, 2020
683bbf2
Fixed trailing commas
J-Travnik Dec 1, 2020
887e007
Merging pull of previous format
J-Travnik Dec 1, 2020
15ceb35
Fix types
araffin Dec 2, 2020
f9cab8a
Minor edits
araffin Dec 2, 2020
5b178f4
Avoid duplicating code
araffin Dec 2, 2020
d692027
Fix calls to parents
araffin Dec 2, 2020
b5249ec
Merge branch 'master' into feat/dict_observations
araffin Dec 2, 2020
8b22f96
Adding assert to buffers. Updating changelong
J-Travnik Dec 2, 2020
70dfa83
Running format on buffers
J-Travnik Dec 2, 2020
b2b7d6f
Merge branch 'master' into feat/dict_observations
araffin Dec 6, 2020
a006b5a
Adding multi-input policies to dqn,td3,a2c. Fixing warnings. Fixed bu…
J-Travnik Dec 8, 2020
a94e6df
Merge branch 'master' into feat/dict_observations
araffin Dec 8, 2020
12361ae
Fixing warnings, splitting is_vectorized_observation into multiple fu…
J-Travnik Dec 8, 2020
9c6390b
Merge branch 'feat/dict_observations' of https://github.com/J-Travnik…
J-Travnik Dec 8, 2020
51cb4e4
Merge branch 'master' into feat/dict_observations
araffin Dec 10, 2020
ce0f1a4
Created envs folder in common. Updated imports. Moved stacked_obs to …
J-Travnik Dec 14, 2020
9eee82a
Moved envs to envs directory. Moved stacked obs to vec_envs. Started …
J-Travnik Dec 14, 2020
c6a8705
Merge branch 'master' into feat/dict_observations
araffin Dec 16, 2020
c3d2138
Fixes
araffin Dec 16, 2020
c893faa
Merged with master. Added miniscule delay to prevent zero divide on d…
J-Travnik Jan 4, 2021
935cef9
Running code style
J-Travnik Jan 4, 2021
a07497b
Update docstrings on torch_layers
Miffyli Jan 6, 2021
96d1e64
Decapitalize non-constant variables
Miffyli Jan 6, 2021
245f4ab
Using NatureCNN architecture in combined extractor. Increasing img si…
J-Travnik Jan 6, 2021
715fec8
merged with latest
J-Travnik Jan 6, 2021
4dc1625
Update doc
araffin Jan 8, 2021
57c1926
Update doc
araffin Jan 8, 2021
0fa3650
Fix format
araffin Jan 8, 2021
20b217a
Merge branch 'master' into feat/dict_observations
araffin Jan 8, 2021
f6ab0bc
Merge branch 'master' into feat/dict_observations
araffin Jan 11, 2021
6206b36
Removing NineRoom env. Using nested preprocess. Removing mutable defa…
J-Travnik Jan 12, 2021
f064972
running code style
J-Travnik Jan 12, 2021
90d2577
Passing channel check through to stacked dict observations.
J-Travnik Jan 14, 2021
8f37cb2
Running black
J-Travnik Jan 14, 2021
2984756
Adding channel control to SimpleMultiObsEnv. Passing check_channels t…
J-Travnik Jan 15, 2021
324ef43
Remove optimize memory for dict buffers
araffin Jan 18, 2021
2fdcfc6
Update doc
araffin Jan 18, 2021
2bab0a3
Move identity env
araffin Jan 18, 2021
b1ec40d
Minor edits + bump version
araffin Jan 18, 2021
12d42e9
Update doc
araffin Jan 18, 2021
5f45044
Fix doc build
araffin Jan 18, 2021
510821b
Bug fixes + add support for more type of dict env
araffin Jan 18, 2021
0b09976
Merge branch 'master' into feat/dict_observations
araffin Jan 21, 2021
8d9183f
Fixes + add multi env test
araffin Jan 25, 2021
04170df
Merge branch 'master' into feat/dict_observations
araffin Feb 1, 2021
b9c4f05
Merge branch 'master' into feat/dict_observations
araffin Feb 6, 2021
3bb747a
Add support for vectranspose
Miffyli Feb 18, 2021
cda8c21
Fix stacked obs for dict and add tests
Miffyli Feb 19, 2021
f770217
Add check for nested spaces. Fix dict-subprocvecenv test
Miffyli Feb 19, 2021
5cbde19
Fix (single) pytype error
Miffyli Feb 19, 2021
4464744
Simplify CombinedExtractor
Miffyli Feb 19, 2021
1f5553a
Fix tests
araffin Feb 19, 2021
dda6990
Merge branch 'master' into feat/dict_observations
araffin Feb 19, 2021
6716567
Fix check
araffin Feb 19, 2021
e756793
Merge branch 'master' into feat/dict_observations
araffin Mar 2, 2021
32b899f
Fix for net_arch with dict and vector obs
araffin Mar 2, 2021
6652df3
Merge branch 'master' into feat/dict_observations
araffin Mar 2, 2021
ec3356e
Fixes
araffin Mar 2, 2021
a369bb1
Merge branch 'master' into feat/dict_observations
araffin Mar 6, 2021
4f787fa
Add consistency test
araffin Mar 9, 2021
e945ec1
Update env checker
araffin Mar 9, 2021
12c8be0
Merge branch 'master' into feat/dict_observations
araffin Mar 17, 2021
a4851b1
Merge branch 'master' into feat/dict_observations
araffin Mar 25, 2021
4138f96
Add some docs on dict obs
Miffyli Apr 5, 2021
4f12135
Merge branch 'master' into feat/dict_observations
araffin Apr 6, 2021
613a141
Merge branch 'master' into feat/dict_observations
araffin Apr 21, 2021
0bcfa11
Update default CNN feature vector size
Miffyli Apr 22, 2021
10f4f6b
Merge branch 'master' into feat/dict_observations
araffin Apr 27, 2021
652a6d0
Refactor HER (#351)
araffin May 3, 2021
bcd97cd
Merge remote-tracking branch 'github/tmp/dict-obs' into feat/dict_obs…
araffin May 3, 2021
89607cf
Update doc and minor fixes
araffin May 3, 2021
495cf5d
Update doc
araffin May 3, 2021
0ea5c61
Added note about MultiInputPolicy in error of NatureCNN
J-Travnik May 8, 2021
f8351ab
Merge branch 'master' into feat/dict_observations
araffin May 10, 2021
94cb760
Address comments
Miffyli May 10, 2021
5d56c34
Naming clarifications
Miffyli May 10, 2021
ab75dcd
merge master
Miffyli May 10, 2021
c30916e
Actually saving the file would be nice
Miffyli May 10, 2021
0acea97
Fix edge case when doing online sampling with HER
araffin May 11, 2021
d6a59f9
Cleanup
araffin May 11, 2021
ce848fb
Add sanity check
araffin May 11, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Pre-Release 0.11.0a1 (WIP)

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Breaks HER as HER needs to be updated to use the new dictionary observations
J-Travnik marked this conversation as resolved.
Show resolved Hide resolved
- ``evaluate_policy`` now returns rewards/episode lengths from a ``Monitor`` wrapper if one is present,
this allows to return the unnormalized reward in the case of Atari games for instance.
- Renamed ``common.vec_env.is_wrapped`` to ``common.vec_env.is_vecenv_wrapped`` to avoid confusion
Expand All @@ -19,13 +20,16 @@ New Features:
automatic check for image spaces.
- ``VecFrameStack`` now has a ``channels_order`` argument to tell if observations should be stacked
on the first or last observation dimension (originally always stacked on last).
- Add support for dictionary observations in both RolloutBuffer (need to be tested in ReplayBuffer)
- Added simple 4x4 and 9room test environments
- Added ``common.env_util.is_wrapped`` and ``common.env_util.unwrap_wrapper`` functions for checking/unwrapping
an environment for specific wrapper.
- Added ``env_is_wrapped()`` method for ``VecEnv`` to check if its environments are wrapped
with given Gym wrappers.
- Added ``monitor_kwargs`` parameter to ``make_vec_env`` and ``make_atari_env``
- Wrap the environments automatically with a ``Monitor`` wrapper when possible.


Bug Fixes:
^^^^^^^^^^
- Fixed bug where code added VecTranspose on channel-first image environments (thanks @qxcv)
Expand Down
87 changes: 87 additions & 0 deletions multi_input_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import argparse
J-Travnik marked this conversation as resolved.
Show resolved Hide resolved
import gym
import numpy as np

from stable_baselines3 import PPO, SAC
from stable_baselines3.common.policies import MultiInputActorCriticPolicy
from stable_baselines3.common.vec_env import (
DummyVecEnv,
VecFrameStack,
VecTransposeImage,
)

from stable_baselines3.common.multi_input_envs import (
SimpleMultiObsEnv,
NineRoomMultiObsEnv,
)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Runs the multi_input_tests script")
parser.add_argument(
"--timesteps",
type=int,
default=30000,
help="Number of timesteps to train for (default: 20000)",
)
parser.add_argument(
"--num_envs",
type=int,
default=10,
help="Number of environments to use (default: 10)",
)
parser.add_argument(
"--frame_stacks",
type=int,
default=1,
help="Number of stacked frames to use (default: 4)",
)
parser.add_argument(
"--room9",
action="store_true",
help="If true, uses more complex 9 room environment",
)
args = parser.parse_args()

ENV_CLS = NineRoomMultiObsEnv if args.room9 else SimpleMultiObsEnv

make_env = lambda: ENV_CLS(random_start=True)

env = DummyVecEnv([make_env for i in range(args.num_envs)])
if args.frame_stacks > 1:
env = VecFrameStack(env, n_stack=args.frame_stacks)

model = PPO(MultiInputActorCriticPolicy, env)

model.learn(args.timesteps)
env.close()
print("Done training, starting testing")

make_env = lambda: ENV_CLS(random_start=False)
test_env = DummyVecEnv([make_env])
if args.frame_stacks > 1:
test_env = VecFrameStack(test_env, n_stack=args.frame_stacks)

obs = test_env.reset()
num_episodes = 1
trajectories = [[]]
i_step, i_episode = 0, 0
while i_episode < num_episodes:
action, _states = model.predict(obs, deterministic=False)
obs, reward, done, info = test_env.step(action)
test_env.render()
trajectories[-1].append((test_env.get_attr("state")[0], action[0]))

i_step += 1

if done[0]:
if info[0]["got_to_end"]:
print(f"Episode {i_episode} : Got to end in {i_step} steps")
else:
print(f"Episode {i_episode} : Did not get to end")
obs = test_env.reset()
i_step = 0
trajectories.append([])
i_episode += 1

test_env.close()
4 changes: 3 additions & 1 deletion stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> Ve
env = VecTransposeImage(env)

# check if wrapper for dict support is needed when using HER
if isinstance(env.observation_space, gym.spaces.dict.Dict):
if isinstance(env.observation_space, gym.spaces.dict.Dict) and set(env.observation_space.spaces.keys()) == set(
araffin marked this conversation as resolved.
Show resolved Hide resolved
["observation", "desired_goal", "achieved_goal"]
):
env = ObsDictWrapper(env)

return env
Expand Down
156 changes: 141 additions & 15 deletions stable_baselines3/common/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
psutil = None

from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
from stable_baselines3.common.type_aliases import ReplayBufferSamples, RolloutBufferSamples
from stable_baselines3.common.type_aliases import (
DictReplayBufferSamples,
DictRolloutBufferSamples,
ReplayBufferSamples,
RolloutBufferSamples,
)
from stable_baselines3.common.vec_env import VecNormalize


Expand Down Expand Up @@ -42,6 +47,7 @@ def __init__(
self.observation_space = observation_space
self.action_space = action_space
self.obs_shape = get_obs_shape(observation_space)
self.is_dict_data = isinstance(self.observation_space, spaces.Dict)
self.action_dim = get_action_dim(action_space)
self.pos = 0
self.full = False
Expand Down Expand Up @@ -130,7 +136,8 @@ def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:

@staticmethod
def _normalize_obs(
obs: Union[np.ndarray, Dict[str, np.ndarray]], env: Optional[VecNormalize] = None
obs: Union[np.ndarray, Dict[str, np.ndarray]],
env: Optional[VecNormalize] = None,
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
if env is not None:
return env.normalize_obs(obs)
Expand Down Expand Up @@ -177,20 +184,50 @@ def __init__(
mem_available = psutil.virtual_memory().available

self.optimize_memory_usage = optimize_memory_usage
self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)

if self.is_dict_data:
self.observations = {
key: np.zeros((self.buffer_size, self.n_envs) + _obs_shape) for key, _obs_shape in self.obs_shape.items()
}
else:
self.observations = np.zeros(
(self.buffer_size, self.n_envs) + self.obs_shape,
dtype=observation_space.dtype,
)
if optimize_memory_usage:
# `observations` contains also the next observation
self.next_observations = None
else:
self.next_observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)
if self.is_dict_data:
self.next_observations = {
key: np.zeros((self.buffer_size, self.n_envs) + _obs_shape) for key, _obs_shape in self.obs_shape.items()
}
else:
self.next_observations = np.zeros(
(self.buffer_size, self.n_envs) + self.obs_shape,
dtype=observation_space.dtype,
)
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

if psutil is not None:
total_memory_usage = self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
obs_nbytes = 0
if self.is_dict_data:
for key, obs in self.observations.items():
obs_nbytes += obs.nbytes
else:
obs_nbytes = self.observations.nbytes

total_memory_usage = obs_nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
if self.next_observations is not None:
total_memory_usage += self.next_observations.nbytes
next_obs_nbytes = 0
if self.is_dict_data:
for key, obs in self.observations.items():
next_obs_nbytes += obs.nbytes
else:
next_obs_nbytes = self.next_observations.nbytes
total_memory_usage += next_obs_nbytes

if total_memory_usage > mem_available:
# Convert to GB
Expand All @@ -201,13 +238,34 @@ def __init__(
f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
)

def add(self, obs: np.ndarray, next_obs: np.ndarray, action: np.ndarray, reward: np.ndarray, done: np.ndarray) -> None:
def add(
self,
obs: Union[np.ndarray, dict],
next_obs: np.ndarray,
action: np.ndarray,
reward: np.ndarray,
done: np.ndarray,
) -> None:
# Copy to avoid modification by reference
self.observations[self.pos] = np.array(obs).copy()

if self.is_dict_data:
for key in self.observations.keys():
self.observations[key][self.pos] = np.array(obs[key]).copy()
else:
self.observations[self.pos] = np.array(obs).copy()

if self.optimize_memory_usage:
self.observations[(self.pos + 1) % self.buffer_size] = np.array(next_obs).copy()
if self.is_dict_data:
for key in self.observations.keys():
self.observations[key][(self.pos + 1) % self.buffer_size] = np.array(next_obs[key]).copy()
else:
self.observations[(self.pos + 1) % self.buffer_size] = np.array(next_obs).copy()
else:
self.next_observations[self.pos] = np.array(next_obs).copy()
if self.is_dict_data:
for key in self.next_observations.keys():
self.next_observations[key][self.pos] = np.array(next_obs[key]).copy()
else:
self.next_observations[self.pos] = np.array(next_obs).copy()

self.actions[self.pos] = np.array(action).copy()
self.rewards[self.pos] = np.array(reward).copy()
Expand Down Expand Up @@ -241,6 +299,35 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB
return self._get_samples(batch_inds, env=env)

def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:

if self.is_dict_data:
if self.optimize_memory_usage:
next_obs = {
key: self.to_torch(
self._normalize_obs(
obs[(batch_inds + 1) % self.buffer_size, 0, :],
env,
)
)
for key, obs in self.observations.items()
}
else:
next_obs = {
key: self.to_torch(self._normalize_obs(obs[batch_inds, 0, :], env))
for key, obs in self.next_observations.items()
}

normalized_obs = {
key: self.to_torch(self._normalize_obs(obs[batch_inds, 0, :], env)) for key, obs in self.observations.items()
}
return DictReplayBufferSamples(
observations=normalized_obs,
actions=self.to_torch(self.actions[batch_inds]),
next_observations=next_obs,
dones=self.to_torch(self.dones[batch_inds]),
returns=self.to_torch(self._normalize_reward(self.rewards[batch_inds], env)),
)

if self.optimize_memory_usage:
next_obs = self._normalize_obs(self.observations[(batch_inds + 1) % self.buffer_size, 0, :], env)
else:
Expand Down Expand Up @@ -293,13 +380,24 @@ def __init__(
super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
self.gae_lambda = gae_lambda
self.gamma = gamma
self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
self.observations, self.actions, self.rewards, self.advantages = (
None,
None,
None,
None,
)
self.returns, self.dones, self.values, self.log_probs = None, None, None, None
self.generator_ready = False
self.reset()

def reset(self) -> None:
self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32)

if self.is_dict_data:
self.observations = {}
for (key, obs_input_shape) in self.obs_shape.items():
self.observations[key] = np.zeros((self.buffer_size, self.n_envs) + obs_input_shape, dtype=np.float32)
else:
self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32)
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
Expand Down Expand Up @@ -342,7 +440,13 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarra
self.returns = self.advantages + self.values

def add(
self, obs: np.ndarray, action: np.ndarray, reward: np.ndarray, done: np.ndarray, value: th.Tensor, log_prob: th.Tensor
self,
obs: Union[np.ndarray, dict],
action: np.ndarray,
reward: np.ndarray,
done: np.ndarray,
value: th.Tensor,
log_prob: th.Tensor,
) -> None:
"""
:param obs: Observation
Expand All @@ -358,7 +462,11 @@ def add(
# Reshape 0-d tensor to avoid error
log_prob = log_prob.reshape(-1, 1)

self.observations[self.pos] = np.array(obs).copy()
if self.is_dict_data:
for key in self.observations.keys():
self.observations[key][self.pos] = np.array(obs[key]).copy()
else:
self.observations[self.pos] = np.array(obs).copy()
self.actions[self.pos] = np.array(action).copy()
self.rewards[self.pos] = np.array(reward).copy()
self.dones[self.pos] = np.array(done).copy()
Expand All @@ -373,7 +481,15 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSample
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
if not self.generator_ready:
for tensor in ["observations", "actions", "values", "log_probs", "advantages", "returns"]:

_tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]
if self.is_dict_data:
for key, obs in self.observations.items():
self.observations[key] = self.swap_and_flatten(obs)
else:
_tensor_names.append("observations")

for tensor in _tensor_names:
self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
self.generator_ready = True

Expand All @@ -387,6 +503,16 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSample
start_idx += batch_size

def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
if self.is_dict_data:
return DictRolloutBufferSamples(
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
actions=self.to_torch(self.actions[batch_inds]),
old_values=self.to_torch(self.values[batch_inds].flatten()),
old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
advantages=self.to_torch(self.advantages[batch_inds].flatten()),
returns=self.to_torch(self.returns[batch_inds].flatten()),
)

data = (
self.observations[batch_inds],
self.actions[batch_inds],
Expand Down
Loading