Commit

Update
[ghstack-poisoned]
vmoens committed Nov 14, 2024
2 parents 878152d + 9ce2f3f commit 93abaab
Showing 14 changed files with 915 additions and 183 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
@@ -93,7 +93,7 @@ jobs:
cd ./docs
# timeout 7m bash -ic "MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
# bash -ic "PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build
PYOPENGL_PLATFORM=egl MUJOCO_GL=egl TORCHRL_CONSOLE_STREAM=stdout sphinx-build ./source _local_build
cd ..
cp -r docs/_local_build/* "${RUNNER_ARTIFACT_DIR}"
3 changes: 3 additions & 0 deletions docs/requirements.txt
@@ -25,3 +25,6 @@ memory_profiler
pyrender
pytest
vmas
onnxscript
onnxruntime
onnx
3 changes: 3 additions & 0 deletions docs/source/conf.py
@@ -49,6 +49,8 @@
version = f"main ({torchrl.__version__})"
release = "main"

os.environ["TORCHRL_CONSOLE_STREAM"] = "stdout"

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
@@ -95,6 +97,7 @@
"abort_on_example_error": False,
"only_warn_on_example_error": True,
"show_memory": True,
"capture_repr": ("_repr_html_", "__repr__"), # capture representations
}

napoleon_use_ivar = True
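Both the docs workflow and docs/source/conf.py now set TORCHRL_CONSOLE_STREAM=stdout; the handler that consumes the variable is added to torchrl/_utils.py further down in this diff. A minimal sketch of flipping the same switch from user code, assuming the variable is set before torchrl is first imported (the console handler is built at import time):

import os

# Must happen before the first `import torchrl`: the console handler is created
# when torchrl._utils is imported and falls back to the default stream (stderr) otherwise.
os.environ["TORCHRL_CONSOLE_STREAM"] = "stdout"

from torchrl._utils import logger

logger.info("this record now goes to sys.stdout")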
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -104,6 +104,7 @@ Intermediate
tutorials/pretrained_models
tutorials/dqn_with_rnn
tutorials/rb_tutorial
tutorials/export

Advanced
--------
65 changes: 65 additions & 0 deletions test/test_tensordictmodules.py
@@ -8,6 +8,8 @@

import pytest
import torch

import torchrl.modules
from tensordict import LazyStackedTensorDict, pad, TensorDict, unravel_key_list
from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential
from torch import nn
@@ -743,6 +745,41 @@ def test_set_temporal_mode(self):
lstm_module.parameters()
)

def test_python_cudnn(self):
lstm_module = LSTMModule(
input_size=3,
hidden_size=12,
batch_first=True,
dropout=0,
num_layers=2,
in_keys=["observation", "hidden0", "hidden1"],
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
).set_recurrent_mode(True)
obs = torch.rand(10, 20, 3)

hidden0 = torch.rand(10, 20, 2, 12)
hidden1 = torch.rand(10, 20, 2, 12)

is_init = torch.zeros(10, 20, dtype=torch.bool)
assert isinstance(lstm_module.lstm, nn.LSTM)
outs_ref = lstm_module(
observation=obs, hidden0=hidden0, hidden1=hidden1, is_init=is_init
)

lstm_module.make_python_based()
assert isinstance(lstm_module.lstm, torchrl.modules.LSTM)
outs_rl = lstm_module(
observation=obs, hidden0=hidden0, hidden1=hidden1, is_init=is_init
)
torch.testing.assert_close(outs_ref, outs_rl)

lstm_module.make_cudnn_based()
assert isinstance(lstm_module.lstm, nn.LSTM)
outs_cudnn = lstm_module(
observation=obs, hidden0=hidden0, hidden1=hidden1, is_init=is_init
)
torch.testing.assert_close(outs_ref, outs_cudnn)

def test_noncontiguous(self):
lstm_module = LSTMModule(
input_size=3,
@@ -1088,6 +1125,34 @@ def test_set_temporal_mode(self):
gru_module.parameters()
)

def test_python_cudnn(self):
gru_module = GRUModule(
input_size=3,
hidden_size=12,
batch_first=True,
dropout=0,
num_layers=2,
in_keys=["observation", "hidden0"],
out_keys=["intermediate", ("next", "hidden0")],
).set_recurrent_mode(True)
obs = torch.rand(10, 20, 3)

hidden0 = torch.rand(10, 20, 2, 12)

is_init = torch.zeros(10, 20, dtype=torch.bool)
assert isinstance(gru_module.gru, nn.GRU)
outs_ref = gru_module(observation=obs, hidden0=hidden0, is_init=is_init)

gru_module.make_python_based()
assert isinstance(gru_module.gru, torchrl.modules.GRU)
outs_rl = gru_module(observation=obs, hidden0=hidden0, is_init=is_init)
torch.testing.assert_close(outs_ref, outs_rl)

gru_module.make_cudnn_based()
assert isinstance(gru_module.gru, nn.GRU)
outs_cudnn = gru_module(observation=obs, hidden0=hidden0, is_init=is_init)
torch.testing.assert_close(outs_ref, outs_cudnn)

def test_noncontiguous(self):
gru_module = GRUModule(
input_size=3,
2 changes: 1 addition & 1 deletion test/test_trainer.py
@@ -450,7 +450,7 @@ def make_storage():
rb_trainer2.register(trainer2)
if re_init:
trainer2._process_batch_hook(td.to_tensordict().zero_())
trainer2.load_from_file(file)
trainer2.load_from_file(file, weights_only=False)
assert state_dict_has_been_called[0]
assert load_state_dict_has_been_called[0]
assert state_dict_has_been_called_td[0]
24 changes: 21 additions & 3 deletions torchrl/_utils.py
@@ -40,7 +40,17 @@
# Remove all attached handlers
while logger.hasHandlers():
logger.removeHandler(logger.handlers[0])
console_handler = logging.StreamHandler()
stream_handlers = {
"stdout": sys.stdout,
"stderr": sys.stderr,
}
TORCHRL_CONSOLE_STREAM = os.getenv("TORCHRL_CONSOLE_STREAM")
if TORCHRL_CONSOLE_STREAM:
stream_handler = stream_handlers[TORCHRL_CONSOLE_STREAM]
else:
stream_handler = None
console_handler = logging.StreamHandler(stream=stream_handler)

console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s [%(name)s][%(levelname)s] %(message)s")
console_handler.setFormatter(formatter)
@@ -86,17 +96,25 @@ def __exit__(self, exc_type, exc_val, exc_tb):
val[2] = N

@staticmethod
def print(prefix=None): # noqa: T202
def print(prefix=None) -> str: # noqa: T202
"""Prints the state of the timer.
Returns:
the string printed using the logger.
"""
keys = list(timeit._REG)
keys.sort()
string = []
for name in keys:
strings = []
if prefix:
strings.append(prefix)
strings.append(
f"{name} took {timeit._REG[name][0] * 1000:4.4} msec (total = {timeit._REG[name][1]} sec)"
)
logger.info(" -- ".join(strings))
string.append(" -- ".join(strings))
logger.info(string[-1])
return "\n".join(string)

@classmethod
def todict(cls, percall=True, prefix=None):
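Since timeit.print now returns the report it logs, callers can reuse the text; a rough usage sketch (the timer names and workload bodies are made up for illustration):

from torchrl._utils import timeit

with timeit("collect"):
    ...  # hypothetical data-collection step

with timeit("optim"):
    ...  # hypothetical optimisation step

# Logs one line per registered timer and returns the same lines joined by newlines.
report = timeit.print(prefix="train")
with open("timings.log", "a") as f:
    f.write(report + "\n")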
26 changes: 25 additions & 1 deletion torchrl/modules/models/models.py
@@ -201,6 +201,9 @@ def __init__(
if not isinstance(out_features, Number):
_out_features_num = prod(out_features)
self.out_features = out_features
self._reshape_out = not isinstance(
self.out_features, (int, torch.SymInt, Number)
)
self._out_features_num = _out_features_num
self.activation_class = activation_class
self.norm_class = norm_class
@@ -302,7 +305,7 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
inputs = (torch.cat([*inputs], -1),)

out = super().forward(*inputs)
if not isinstance(self.out_features, Number):
if self._reshape_out:
out = out.view(*out.shape[:-1], *self.out_features)
return out

@@ -549,6 +552,27 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
out = out.unflatten(0, batch)
return out

@classmethod
def default_atari_dqn(cls, num_actions: int):
"""Returns the default DQN as presented in the seminal DQN paper.
Args:
num_actions (int): the number of actions in the Atari game's action space.
"""
cnn = ConvNet(
activation_class=torch.nn.ReLU,
num_cells=[32, 64, 64],
kernel_sizes=[8, 4, 3],
strides=[4, 2, 1],
)
mlp = MLP(
activation_class=torch.nn.ReLU,
out_features=num_actions,
num_cells=[512],
)
return nn.Sequential(cnn, mlp)


Conv2dNet = ConvNet

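A rough usage sketch for the new default_atari_dqn constructor; the batch size and the 4x84x84 stacked-frame shape below are illustrative assumptions, not part of the change:

import torch
from torchrl.modules import ConvNet

dqn = ConvNet.default_atari_dqn(num_actions=6)  # e.g. a game with 6 discrete actions

obs = torch.rand(32, 4, 84, 84)  # hypothetical batch of stacked grayscale frames
q_values = dqn(obs)              # one Q-value per action, shape [32, 6]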
1 change: 1 addition & 0 deletions torchrl/modules/tensordict_module/exploration.py
@@ -152,6 +152,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
out = action_tensordict.get(action_key)
eps = self.eps.item()
cond = torch.rand(action_tensordict.shape, device=out.device) < eps
# cond = torch.zeros(action_tensordict.shape, device=out.device, dtype=torch.bool).bernoulli_(eps)
cond = expand_as_right(cond, out)
spec = self.spec
if spec is not None:
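For context on the epsilon-greedy line above: both the active expression and the commented-out alternative draw a Bernoulli(eps) mask over the batch, marking which entries receive a random action. A tiny standalone sketch:

import torch

eps = 0.1
batch_shape = (4,)

cond = torch.rand(batch_shape) < eps                                    # mask used in the module
cond_alt = torch.zeros(batch_shape, dtype=torch.bool).bernoulli_(eps)   # equivalent to the commented variant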
108 changes: 107 additions & 1 deletion torchrl/modules/tensordict_module/rnn.py
@@ -12,7 +12,7 @@

from tensordict.base import NO_DEFAULT

from tensordict.nn import TensorDictModuleBase as ModuleBase
from tensordict.nn import dispatch, TensorDictModuleBase as ModuleBase
from tensordict.utils import expand_as_right, prod, set_lazy_legacy

from torch import nn, Tensor
@@ -467,6 +467,8 @@ def __init__(
raise ValueError("The input lstm must have batch_first=True.")
if bidirectional:
raise ValueError("The input lstm cannot be bidirectional.")
if not hidden_size:
raise ValueError("hidden_size must be passed.")
if python_based:
lstm = LSTM(
input_size=input_size,
@@ -524,6 +526,58 @@ def __init__(
self.out_keys = out_keys
self._recurrent_mode = False

def make_python_based(self) -> LSTMModule:
"""Transforms the LSTM layer in its python-based version.
Returns:
self
"""
if isinstance(self.lstm, LSTM):
return self
lstm = LSTM(
input_size=self.lstm.input_size,
hidden_size=self.lstm.hidden_size,
num_layers=self.lstm.num_layers,
bias=self.lstm.bias,
dropout=self.lstm.dropout,
proj_size=self.lstm.proj_size,
device="meta",
batch_first=self.lstm.batch_first,
bidirectional=self.lstm.bidirectional,
)
from tensordict import from_module

from_module(self.lstm).to_module(lstm)
self.lstm = lstm
return self

def make_cudnn_based(self) -> LSTMModule:
"""Transforms the LSTM layer in its CuDNN-based version.
Returns:
self
"""
if isinstance(self.lstm, nn.LSTM):
return self
lstm = nn.LSTM(
input_size=self.lstm.input_size,
hidden_size=self.lstm.hidden_size,
num_layers=self.lstm.num_layers,
bias=self.lstm.bias,
dropout=self.lstm.dropout,
proj_size=self.lstm.proj_size,
device="meta",
batch_first=self.lstm.batch_first,
bidirectional=self.lstm.bidirectional,
)
from tensordict import from_module

from_module(self.lstm).to_module(lstm)
self.lstm = lstm
return self

def make_tensordict_primer(self):
"""Makes a tensordict primer for the environment.
@@ -644,6 +698,7 @@ def set_recurrent_mode(self, mode: bool = True):
out._recurrent_mode = mode
return out

@dispatch
def forward(self, tensordict: TensorDictBase):
# we want to get an error if the value input is missing, but not the hidden states
defaults = [NO_DEFAULT, None, None]
@@ -1273,6 +1328,56 @@ def __init__(
self.out_keys = out_keys
self._recurrent_mode = False

def make_python_based(self) -> GRUModule:
"""Transforms the GRU layer in its python-based version.
Returns:
self
"""
if isinstance(self.gru, GRU):
return self
gru = GRU(
input_size=self.gru.input_size,
hidden_size=self.gru.hidden_size,
num_layers=self.gru.num_layers,
bias=self.gru.bias,
dropout=self.gru.dropout,
device="meta",
batch_first=self.gru.batch_first,
bidirectional=self.gru.bidirectional,
)
from tensordict import from_module

from_module(self.gru).to_module(gru)
self.gru = gru
return self

def make_cudnn_based(self) -> GRUModule:
"""Transforms the GRU layer in its CuDNN-based version.
Returns:
self
"""
if isinstance(self.gru, nn.GRU):
return self
gru = nn.GRU(
input_size=self.gru.input_size,
hidden_size=self.gru.hidden_size,
num_layers=self.gru.num_layers,
bias=self.gru.bias,
dropout=self.gru.dropout,
device="meta",
batch_first=self.gru.batch_first,
bidirectional=self.gru.bidirectional,
)
from tensordict import from_module

from_module(self.gru).to_module(gru)
self.gru = gru
return self

def make_tensordict_primer(self):
"""Makes a tensordict primer for the environment.
@@ -1389,6 +1494,7 @@ def set_recurrent_mode(self, mode: bool = True):
out._recurrent_mode = mode
return out

@dispatch
@set_lazy_legacy(False)
def forward(self, tensordict: TensorDictBase):
# we want to get an error if the value input is missing, but not the hidden states
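Condensing the new tests above: the @dispatch decorator added to forward lets the recurrent modules be called with plain keyword tensors instead of a TensorDict, and the conversion helpers swap the backend in place while from_module carries the weights across. A short sketch with arbitrary sizes:

import torch
from torchrl.modules import LSTMModule

module = LSTMModule(
    input_size=3,
    hidden_size=12,
    num_layers=2,
    batch_first=True,
    in_keys=["observation", "hidden0", "hidden1"],
    out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
).set_recurrent_mode(True)

obs = torch.rand(10, 20, 3)
hidden0 = torch.rand(10, 20, 2, 12)
hidden1 = torch.rand(10, 20, 2, 12)
is_init = torch.zeros(10, 20, dtype=torch.bool)

# Keyword-tensor call enabled by @dispatch
out_cudnn = module(observation=obs, hidden0=hidden0, hidden1=hidden1, is_init=is_init)

module.make_python_based()  # self.lstm is now torchrl.modules.LSTM, same weights
out_python = module(observation=obs, hidden0=hidden0, hidden1=hidden1, is_init=is_init)

module.make_cudnn_based()   # back to nn.LSTM
torch.testing.assert_close(out_cudnn, out_python)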
(The remaining changed files in this commit are not shown here.)
