[BugFix, Doc] Fix tutos (#1107)
vmoens authored Apr 28, 2023
1 parent 25370f7 commit 50f0db0
Showing 14 changed files with 86 additions and 25 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/docs.yml
@@ -72,8 +72,7 @@ jobs:
id: build_doc
run: |
cd ./docs
#timeout 7m bash -ic "MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build
timeout 7m bash -ic "MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
cd ..
- name: Install rsync 📚
run: |
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -50,6 +50,7 @@ Intermediate

tutorials/torch_envs
tutorials/pretrained_models
tutorials/dqn_with_rnn.py

Advanced
--------
8 changes: 6 additions & 2 deletions torchrl/collectors/collectors.py
@@ -496,7 +496,9 @@ def __init__(
):
self.closed = True

exploration_type = _convert_exploration_type(exploration_mode, exploration_type)
exploration_type = _convert_exploration_type(
exploration_mode=exploration_mode, exploration_type=exploration_type
)
if create_env_kwargs is None:
create_env_kwargs = {}
if not isinstance(create_env_fn, EnvBase):
@@ -1049,7 +1051,9 @@ def __init__(
devices=None,
storing_devices=None,
):
exploration_type = _convert_exploration_type(exploration_mode, exploration_type)
exploration_type = _convert_exploration_type(
exploration_mode=exploration_mode, exploration_type=exploration_type
)
self.closed = True
self.create_env_fn = create_env_fn
self.num_workers = len(create_env_fn)
4 changes: 3 additions & 1 deletion torchrl/collectors/distributed/generic.py
@@ -385,7 +385,9 @@ def __init__(
launcher="submitit",
tcp_port=None,
):
exploration_type = _convert_exploration_type(exploration_mode, exploration_type)
exploration_type = _convert_exploration_type(
exploration_mode=exploration_mode, exploration_type=exploration_type
)

if collector_class == "async":
collector_class = MultiaSyncDataCollector
4 changes: 3 additions & 1 deletion torchrl/collectors/distributed/rpc.py
@@ -238,7 +238,9 @@ def __init__(
visible_devices=None,
tensorpipe_options=None,
):
exploration_type = _convert_exploration_type(exploration_mode, exploration_type)
exploration_type = _convert_exploration_type(
exploration_mode=exploration_mode, exploration_type=exploration_type
)
if collector_class == "async":
collector_class = MultiaSyncDataCollector
elif collector_class == "sync":
4 changes: 3 additions & 1 deletion torchrl/collectors/distributed/sync.py
@@ -242,7 +242,9 @@ def __init__(
launcher="submitit",
tcp_port=None,
):
exploration_type = _convert_exploration_type(exploration_mode, exploration_type)
exploration_type = _convert_exploration_type(
exploration_mode=exploration_mode, exploration_type=exploration_type
)

if collector_class == "async":
collector_class = MultiaSyncDataCollector
27 changes: 26 additions & 1 deletion torchrl/envs/libs/gym.py
@@ -337,12 +337,37 @@ class GymWrapper(GymLikeEnv):
git_url = "https://github.com/openai/gym"
libname = "gym"

@staticmethod
def get_library_name(env):
# try gym
try:
import gym

if isinstance(env.action_space, gym.spaces.space.Space):
return gym
except ImportError:
pass
try:
import gymnasium

if isinstance(env.action_space, gymnasium.spaces.space.Space):
return gymnasium
except ImportError:
pass
raise RuntimeError(
f"Could not find the library of env {env}. Please file an issue on torchrl github repo."
)

def __init__(self, env=None, categorical_action_encoding=False, **kwargs):
if env is not None:
kwargs["env"] = env
self._seed_calls_reset = None
self._categorical_action_encoding = categorical_action_encoding
super().__init__(**kwargs)
if "env" in kwargs:
with set_gym_backend(self.get_library_name(kwargs["env"])):
super().__init__(**kwargs)
else:
super().__init__(**kwargs)

def _check_kwargs(self, kwargs: Dict):
if "env" not in kwargs:
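
The new get_library_name helper lets GymWrapper work out whether a given env object was built with gym or with gymnasium (by inspecting its action_space) and initialize itself under the matching backend. A minimal usage sketch, assuming torchrl and gymnasium are installed; the env name is only an illustration:

import gymnasium
from torchrl.envs.libs.gym import GymWrapper

# the wrapper detects that the env comes from gymnasium and sets the
# gym backend accordingly, so gym- and gymnasium-made envs wrap alike
env = GymWrapper(gymnasium.make("Pendulum-v1"))
print(env.reset())
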
2 changes: 1 addition & 1 deletion torchrl/envs/utils.py
@@ -31,7 +31,7 @@
AVAILABLE_LIBRARIES = {pkg.key for pkg in pkg_resources.working_set}


def _convert_exploration_type(exploration_mode, exploration_type):
def _convert_exploration_type(*, exploration_mode, exploration_type):
if exploration_mode is not None:
return ExplorationType.from_str(exploration_mode)
return exploration_type
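
Making both arguments keyword-only forces call sites (the collector constructors updated above) to state explicitly which value is the legacy string and which is the enum, instead of relying on argument order. A self-contained sketch of the behaviour, using a stand-in enum since the real ExplorationType lives in torchrl.envs.utils:

from enum import Enum

class ExplorationType(Enum):  # stand-in for torchrl's enum
    RANDOM = "random"
    MODE = "mode"
    MEAN = "mean"

    @classmethod
    def from_str(cls, value):
        return cls(value.lower())

def _convert_exploration_type(*, exploration_mode, exploration_type):
    # the legacy string, when provided, takes precedence over the enum
    if exploration_mode is not None:
        return ExplorationType.from_str(exploration_mode)
    return exploration_type

print(_convert_exploration_type(exploration_mode="mean", exploration_type=None))
# a positional call such as _convert_exploration_type("mean", None) now raises TypeError
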
12 changes: 7 additions & 5 deletions tutorials/sphinx-tutorials/coding_ddpg.py
@@ -268,10 +268,8 @@ def _loss_actor(
tensordict,
) -> torch.Tensor:
td_copy = tensordict.select(*self.actor_in_keys)
# Get an action from the actor network
td_copy = self.actor_network(
td_copy,
)
# Get an action from the actor network: since we made it functional, we need to pass the params
td_copy = self.actor_network(td_copy, params=self.actor_network_params)
# get the value associated with that action
td_copy = self.value_network(
td_copy,
@@ -482,6 +480,7 @@ def make_env(from_pixels=False):
CatTensors,
DoubleToFloat,
EnvCreator,
InitTracker,
ObservationNorm,
ParallelEnv,
RewardScaling,
@@ -536,6 +535,9 @@ def make_transformed_env(

env.append_transform(StepCounter(max_frames_per_traj))

# We need a marker for the start of trajectories for our OU exploration:
env.append_transform(InitTracker())

return env
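
InitTracker is the transform that provides this marker. A minimal sketch of what it adds to a rollout, assuming torchrl and gymnasium are installed and that the transform writes an "is_init" flag (the env name is only an illustration):

import gymnasium
from torchrl.envs import InitTracker, TransformedEnv
from torchrl.envs.libs.gym import GymWrapper

env = TransformedEnv(GymWrapper(gymnasium.make("Pendulum-v1")), InitTracker())
rollout = env.rollout(3)
# the flag is True on the first step after a reset, which lets stateful
# exploration such as the Ornstein-Uhlenbeck wrapper re-initialize its noise
print(rollout["is_init"])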


@@ -889,7 +891,7 @@ def make_recorder(actor_model_explore, transform_state_dict, record_interval):
record_frames=1000,
policy_exploration=actor_model_explore,
environment=environment,
exploration_type="mode",
exploration_type=ExplorationType.MEAN,
record_interval=record_interval,
)
return recorder_obj
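
The _loss_actor fix above follows from how functional TensorDictModules work: once the parameters have been extracted into a tensordict, they must be handed back at every call. A rough sketch, under the assumption that tensordict exposes make_functional as the tutorial's version does:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, make_functional

actor = TensorDictModule(
    torch.nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
)
params = make_functional(actor)  # parameters now live in a TensorDict
data = TensorDict({"observation": torch.randn(4, 3)}, batch_size=[4])
actor(data, params=params)  # and must be passed explicitly per call
print(data["action"].shape)  # torch.Size([4, 1])
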
5 changes: 3 additions & 2 deletions tutorials/sphinx-tutorials/coding_dqn.py
@@ -438,7 +438,8 @@ def get_collector(


def get_loss_module(actor, gamma):
loss_module = DQNLoss(actor, gamma=gamma, delay_value=True)
loss_module = DQNLoss(actor, delay_value=True)
loss_module.make_value_estimator(gamma=gamma)
target_updater = SoftUpdate(loss_module)
return loss_module, target_updater

@@ -617,7 +618,7 @@ def get_loss_module(actor, gamma):
frame_skip=1,
policy_exploration=actor_explore,
environment=test_env,
exploration_type="mode",
exploration_type=ExplorationType.MODE,
log_keys=[("next", "reward")],
out_keys={("next", "reward"): "rewards"},
log_pbar=True,
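
Both recorder fixes in this commit swap the old string flag ("mode") for the ExplorationType enum. A short sketch of the enum-based API, assuming torchrl is installed:

from torchrl.envs.utils import ExplorationType, set_exploration_type

# scope how a stochastic policy acts: MODE and MEAN pick a deterministic
# statistic of the action distribution, RANDOM samples from it
with set_exploration_type(ExplorationType.MODE):
    pass  # policy calls made here would use the distribution mode
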
4 changes: 1 addition & 3 deletions tutorials/sphinx-tutorials/dqn_with_rnn.py
@@ -342,9 +342,7 @@
# For the sake of efficiency, we're only running a few thousand iterations
# here. In a real setting, the total number of frames should be set to 1M.
#
collector = SyncDataCollector(
env, stoch_policy, frames_per_batch=50, total_frames=200
)
collector = SyncDataCollector(env, stoch_policy, frames_per_batch=50, total_frames=200)
rb = TensorDictReplayBuffer(
storage=LazyMemmapStorage(20_000), batch_size=4, prefetch=10
)
6 changes: 3 additions & 3 deletions tutorials/sphinx-tutorials/pendulum.py
@@ -845,17 +845,17 @@ def simple_rollout(steps=100):
for _ in pbar:
init_td = env.reset(env.gen_params(batch_size=[batch_size]))
rollout = env.rollout(100, policy, tensordict=init_td, auto_reset=False)
traj_return = rollout["reward"].mean()
traj_return = rollout["next", "reward"].mean()
(-traj_return).backward()
gn = torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
optim.step()
optim.zero_grad()
pbar.set_description(
f"reward: {traj_return: 4.4f}, "
f"last reward: {rollout[..., -1]['reward'].mean(): 4.4f}, gradient norm: {gn: 4.4}"
f"last reward: {rollout[..., -1]['next', 'reward'].mean(): 4.4f}, gradient norm: {gn: 4.4}"
)
logs["return"].append(traj_return.item())
logs["last_reward"].append(rollout[..., -1]["reward"].mean().item())
logs["last_reward"].append(rollout[..., -1]["next", "reward"].mean().item())
scheduler.step()


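
These pendulum fixes apply TorchRL's data convention: the outcome of a step (next observation, reward, done) lives under the "next" sub-tensordict, so returns must be read from ("next", "reward"). A self-contained sketch of that layout, using tensordict only:

import torch
from tensordict import TensorDict

rollout = TensorDict(
    {
        "observation": torch.randn(100, 3),
        "action": torch.randn(100, 1),
        "next": {
            "observation": torch.randn(100, 3),
            "reward": torch.randn(100, 1),
            "done": torch.zeros(100, 1, dtype=torch.bool),
        },
    },
    batch_size=[100],
)
print(rollout["next", "reward"].mean())    # mean reward over the rollout
print(rollout[..., -1]["next", "reward"])  # reward of the last step
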
23 changes: 23 additions & 0 deletions tutorials/sphinx-tutorials/run_local.sh
@@ -0,0 +1,23 @@
#!/bin/bash

set -e
set -v

# Allows you to run all the tutorials without building the docset.

DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

# loop through all the .py files in the directory
for file in $(ls -r "$DIR"/*.py)
do
# execute each Python script using the 'exec' function
echo $file
python -c """
with open('$file') as f:
source = f.read()
code = compile(source, '$file', 'exec')
exec(code)
"""
done
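
For reference (not part of the commit), a rough Python equivalent of this loop using only the standard library:

import pathlib
import runpy

here = pathlib.Path(__file__).resolve().parent
for script in sorted(here.glob("*.py"), reverse=True):  # mirrors `ls -r`
    print(script)
    runpy.run_path(str(script), run_name="__main__")
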
8 changes: 5 additions & 3 deletions tutorials/sphinx-tutorials/torchrl_demo.py
@@ -732,10 +732,12 @@ def forward(self, obs, action):
tensordict = TensorDict(
{
"observation": torch.randn(10, 3),
"next": {"observation": torch.randn(10, 3)},
"reward": torch.randn(10, 1),
"next": {
"observation": torch.randn(10, 3),
"reward": torch.randn(10, 1),
"done": torch.zeros(10, 1, dtype=torch.bool),
},
"action": torch.randn(10, 1),
"done": torch.zeros(10, 1, dtype=torch.bool),
},
batch_size=[10],
device="cpu",
