Skip to content

Commit

Permalink
[Feature] Warn when reset_parameters_recursive is a no-op (#693)
Browse files (browse the repository at this point in the history)
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
  • Loading branch information
matteobettini and vmoens authored Feb 26, 2024
1 parent b5f6c17 commit 9601868
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 20 deletions.
17 changes: 11 additions & 6 deletions tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,17 @@

import torch
from cloudpickle import dumps as cloudpickle_dumps, loads as cloudpickle_loads

from tensordict._td import is_tensor_collection, TensorDictBase
from tensordict._tensordict import _unravel_key_to_tuple, unravel_key_list
from tensordict.functional import make_tensordict

from tensordict.nn.functional_modules import (
_swap_state,
extract_weights_and_buffers,
is_functional,
make_functional,
repopulate_module,
)

from tensordict.nn.utils import (
_auto_make_functional,
_dispatch_td_nn_modules,
Expand Down Expand Up @@ -248,7 +247,6 @@ def __call__(self, func: Callable) -> Callable:

@functools.wraps(func)
def wrapper(_self, *args: Any, **kwargs: Any) -> Any:

if not _dispatch_td_nn_modules():
return func(_self, *args, **kwargs)

Expand Down Expand Up @@ -830,7 +828,11 @@ def reset_parameters_recursive(
False
"""
if parameters is None:
self._reset_parameters(self)
any_reset = self._reset_parameters(self)
if not any_reset:
warnings.warn(
"reset_parameters_recursive was called without the parameters argument and did not find any parameters to reset"
)
return
elif parameters.ndim:
raise RuntimeError(
Expand Down Expand Up @@ -868,13 +870,16 @@ def reset_parameters_recursive(
self._reset_parameters(self)
return sanitized_parameters

def _reset_parameters(self, module: nn.Module) -> None:
def _reset_parameters(self, module: nn.Module) -> bool:
any_reset = False
for child in module.children():
if isinstance(child, nn.Module):
self._reset_parameters(child)
any_reset |= self._reset_parameters(child)

if hasattr(child, "reset_parameters"):
child.reset_parameters()
any_reset |= True
return any_reset


class TensorDictModule(TensorDictModuleBase):
Expand Down
56 changes: 42 additions & 14 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
set_skip_existing,
skip_existing,
)
from torch import distributions as d, nn
from torch import distributions, nn
from torch.distributions import Normal
from torch.utils._pytree import tree_map

Expand Down Expand Up @@ -165,6 +165,16 @@ def test_reset(self):
seq.reset_parameters_recursive()
assert torch.all(old_param != net[0][0].weight.data)

def test_reset_warning(self):
torch.manual_seed(0)
net = nn.ModuleList([nn.Tanh(), nn.ReLU()])
module = TensorDictModule(net, in_keys=["in"], out_keys=["out"])
with pytest.warns(
UserWarning,
match="reset_parameters_recursive was called without the parameters argument and did not find any parameters to reset",
):
module.reset_parameters_recursive()

@pytest.mark.parametrize(
"net",
[
Expand Down Expand Up @@ -335,7 +345,7 @@ def test_stateful_probabilistic_kwargs(
net = TensorDictModule(module=net, in_keys=in_keys, out_keys=out_keys)

kwargs = {
"distribution_class": torch.distributions.Uniform,
"distribution_class": distributions.Uniform,
"distribution_kwargs": {"high": max_dist},
}
if out_keys == ["low"]:
Expand Down Expand Up @@ -2666,7 +2676,6 @@ def test_module_buffer():
],
)
def test_nested_keys_probabilistic_delta(log_prob_key):

policy_module = TensorDictModule(
nn.Linear(1, 1), in_keys=[("data", "states")], out_keys=[("data", "param")]
)
Expand Down Expand Up @@ -2711,7 +2720,6 @@ def test_nested_keys_probabilistic_delta(log_prob_key):
],
)
def test_nested_keys_probabilistic_normal(log_prob_key):

loc_module = TensorDictModule(
nn.Linear(1, 1),
in_keys=[("data", "states")],
Expand Down Expand Up @@ -3083,7 +3091,10 @@ def test_const(self):
)
dist = CompositeDistribution(
params,
distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical},
distribution_map={
"cont": distributions.Normal,
("nested", "disc"): distributions.Categorical,
},
)
assert dist.batch_shape == params.shape
assert len(dist.dists) == 2
Expand All @@ -3100,7 +3111,10 @@ def test_sample(self):
)
dist = CompositeDistribution(
params,
distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical},
distribution_map={
"cont": distributions.Normal,
("nested", "disc"): distributions.Categorical,
},
)
sample = dist.sample()
assert sample.shape == params.shape
Expand All @@ -3121,8 +3135,8 @@ def test_rsample(self):
dist = CompositeDistribution(
params,
distribution_map={
"cont": d.Normal,
("nested", "disc"): d.RelaxedOneHotCategorical,
"cont": distributions.Normal,
("nested", "disc"): distributions.RelaxedOneHotCategorical,
},
extra_kwargs={("nested", "disc"): {"temperature": torch.tensor(1.0)}},
)
Expand All @@ -3147,8 +3161,8 @@ def test_log_prob(self):
dist = CompositeDistribution(
params,
distribution_map={
"cont": d.Normal,
("nested", "disc"): d.RelaxedOneHotCategorical,
"cont": distributions.Normal,
("nested", "disc"): distributions.RelaxedOneHotCategorical,
},
extra_kwargs={("nested", "disc"): {"temperature": torch.tensor(1.0)}},
)
Expand All @@ -3172,7 +3186,11 @@ def test_cdf(self):
[3],
)
dist = CompositeDistribution(
params, distribution_map={"cont": d.Normal, ("nested", "cont"): d.Normal}
params,
distribution_map={
"cont": distributions.Normal,
("nested", "cont"): distributions.Normal,
},
)
sample = dist.rsample((4,))
sample = dist.cdf(sample)
Expand All @@ -3194,7 +3212,11 @@ def test_icdf(self):
[3],
)
dist = CompositeDistribution(
params, distribution_map={"cont": d.Normal, ("nested", "cont"): d.Normal}
params,
distribution_map={
"cont": distributions.Normal,
("nested", "cont"): distributions.Normal,
},
)
sample = dist.rsample((4,))
sample = dist.cdf(sample)
Expand Down Expand Up @@ -3225,7 +3247,10 @@ def test_prob_module(self, interaction, return_log_prob):
)
in_keys = ["params"]
out_keys = ["cont", ("nested", "cont")]
distribution_map = {"cont": d.Normal, ("nested", "cont"): d.Normal}
distribution_map = {
"cont": distributions.Normal,
("nested", "cont"): distributions.Normal,
}
module = ProbabilisticTensorDictModule(
in_keys=in_keys,
out_keys=out_keys,
Expand Down Expand Up @@ -3275,7 +3300,10 @@ def test_prob_module_seq(self, interaction, return_log_prob):
)
in_keys = ["params"]
out_keys = ["cont", ("nested", "cont")]
distribution_map = {"cont": d.Normal, ("nested", "cont"): d.Normal}
distribution_map = {
"cont": distributions.Normal,
("nested", "cont"): distributions.Normal,
}
backbone = TensorDictModule(lambda: None, in_keys=[], out_keys=[])
module = ProbabilisticTensorDictSequential(
backbone,
Expand Down

0 comments on commit 9601868

Please sign in to comment.