[Feature] from_modules expand_identical kwarg #911

Merged (3 commits) on Jul 23, 2024
34 changes: 34 additions & 0 deletions tensordict/base.py
@@ -1089,6 +1089,7 @@ def from_modules(
        lock: bool = True,
        use_state_dict: bool = False,
        lazy_stack: bool = False,
        expand_identical: bool = False,
    ):
        """Retrieves the parameters of several modules for ensemble learning/expectation-of-features applications through vmap.

@@ -1134,6 +1135,9 @@ def from_modules(
                or :meth:`~torch.optim.Optimizer.zero_grad` will take longer
                to be executed. In general, ``lazy_stack`` should be reserved
                to very few use cases.
            expand_identical (bool, optional): if ``True`` and the same parameter (same
                identity) is being stacked to itself, an expanded version of this parameter
                will be returned instead. This argument is ignored when ``lazy_stack=True``.

        Examples:
            >>> from torch import nn
@@ -1200,6 +1204,36 @@ def from_modules(
"lasy_stack=True is not compatible with lazy modules."
)
params = LazyStackedTensorDict.lazy_stack(param_list)
elif expand_identical:
from tensordict._torch_func import _stack_uninit_params

# Check the keys
# If not expand_identical, `stack` takes care of that check but
# here we use apply which will ignore keys that are in one TD but not another
sets = [set(param.keys(True, True)) for param in param_list]
for set_ in sets[1:]:
if set_ != sets[0]:
raise ValueError(
f"All key sets must match. "
f"Got {set_.symmetric_difference(sets[0])} in one but not another."
)

def maybe_stack(*params):
param = params[0]
if isinstance(param, UninitializedTensorMixin):
return _stack_uninit_params(params, 0)
if len(set(params)) == 1:
return param.expand((len(params), *param.shape))
result = torch.stack(params)
if isinstance(param, nn.Parameter):
return nn.Parameter(result.detach(), param.requires_grad)
return Buffer(result)

params = param_list[0]._fast_apply(
maybe_stack,
*param_list[1:],
batch_size=torch.Size([len(param_list), *param_list[0].batch_size]),
)
else:
with set_lazy_legacy(False), torch.no_grad():
params = torch.stack(param_list)
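
For context (this note is not part of the diff): a minimal usage sketch of the new flag, modelled on the test added at the bottom of this PR. Module names and shapes here are illustrative assumptions; the point is that parameters shared by identity across ensemble members are expanded rather than stacked into fresh `nn.Parameter`s.

```python
import torch
from torch import nn
from tensordict import TensorDict

# Three ensemble members sharing the *same* first layer (same object identity);
# only the second layer differs between members.
shared = nn.Linear(3, 3)
modules = [nn.Sequential(shared, nn.Linear(3, 4)) for _ in range(3)]

# Shared parameters are expanded (no new nn.Parameter, storage reused);
# distinct parameters are stacked along a new leading dim of size 3.
params = TensorDict.from_modules(*modules, expand_identical=True)

# Meta-device skeleton used to run the ensemble functionally under vmap.
empty = nn.Sequential(nn.Linear(3, 3, device="meta"), nn.Linear(3, 4, device="meta"))

def exec_module(p, x):
    with p.to_module(empty):
        return empty(x)

y = torch.vmap(exec_module, (0, None))(params, torch.zeros(3))
print(y.shape)  # expected: torch.Size([3, 4])
```
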
2 changes: 2 additions & 0 deletions tensordict/nn/params.py
@@ -112,7 +112,9 @@ def _maybe_make_param_or_buffer(tensor):
        and tensor.dtype in (torch.float, torch.double, torch.half)
    ):
        # convert all non-parameters to buffers
        # dataptr = tensor.data.data_ptr()
        tensor = Buffer(tensor)
        # assert tensor.data.data_ptr() == dataptr
    return tensor


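
The commented-out lines above record the invariant being checked: converting a floating-point tensor to a `Buffer` keeps the underlying storage (same `data_ptr`). A hedged sketch of that check, calling the private helper shown in this hunk (import path assumed from the diff, not guaranteed API):

```python
import torch
from tensordict.nn.params import _maybe_make_param_or_buffer  # private helper shown above

t = torch.randn(4)  # float tensor, not an nn.Parameter, so it gets converted to a buffer
dataptr = t.data.data_ptr()
out = _maybe_make_param_or_buffer(t)
assert out.data.data_ptr() == dataptr  # storage preserved, mirroring the commented assert
```
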
10 changes: 4 additions & 6 deletions tensordict/utils.py
@@ -1853,13 +1853,11 @@ class Buffer(Tensor, metaclass=_ParameterMeta):
    def __new__(cls, data=None, requires_grad=False):
        if data is None:
            data = torch.empty(0)
-       if type(data) is Tensor or type(data) is Buffer:
-           # For ease of BC maintenance, keep this path for standard Tensor.
-           # Eventually (tm), we should change the behavior for standard Tensor to match.
-           return Tensor._make_subclass(cls, data, requires_grad)
-
        # Path for custom tensors: set a flag on the instance to indicate parameter-ness.
-       t = data.detach().requires_grad_(requires_grad)
+       if requires_grad:
+           t = data.detach().requires_grad_(requires_grad)
+       else:
+           t = data
        t._is_buffer = True
        return t

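
As a side note rather than part of the diff: a small sketch of the behavior the revised `Buffer.__new__` encodes, assuming `Buffer` is imported from `tensordict.utils`, where this diff defines it. With `requires_grad=False` (the default) the input tensor is reused in place and merely flagged, which keeps its storage pointer unchanged; with `requires_grad=True` a detached, grad-tracking tensor is returned instead.

```python
import torch
from tensordict.utils import Buffer  # import path assumed from the diff above

t = torch.randn(3)
b = Buffer(t)  # requires_grad defaults to False

# Reuse path: no detach, no copy; the tensor is simply flagged as a buffer.
assert b.data_ptr() == t.data_ptr()
assert getattr(b, "_is_buffer", False)

# Grad path: a detached tensor with requires_grad=True is returned.
b_grad = Buffer(torch.randn(3), requires_grad=True)
assert b_grad.requires_grad
```
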
33 changes: 33 additions & 0 deletions test/test_tensordict.py
Expand Up @@ -982,6 +982,39 @@ def get_leaf(leaf):
        assert p.grad is None
        assert all(param.grad is not None for param in params.values(True, True))

    @pytest.mark.parametrize("as_module", [False, True])
    def test_from_modules_expand(self, as_module):
        empty_module = nn.Sequential(
            nn.Linear(3, 3, device="meta"), nn.Linear(3, 4, device="meta")
        )
        module0 = nn.Linear(3, 3)
        modules = [nn.Sequential(module0, nn.Linear(3, 4)) for _ in range(3)]
        params = TensorDict.from_modules(
            *modules, as_module=as_module, expand_identical=True
        )
        assert not isinstance(params["0", "weight"], nn.Parameter)
        assert params["0", "weight"].data.data_ptr() == module0.weight.data.data_ptr()
        assert isinstance(params["1", "weight"], nn.Parameter)
        assert (
            params["1", "weight"].data.data_ptr()
            != modules[0][1].weight.data.data_ptr()
        )

        def exec_module(params, x):
            with params.to_module(empty_module):
                return empty_module(x)

        x = torch.zeros(3)
        y = torch.vmap(exec_module, (0, None))(params, x)
        y.sum().backward()
        for k, p in modules[0].named_parameters():
            assert p.grad is None if k.startswith("1") else p.grad is not None
        assert all(
            param.grad is not None
            for param in params.values(True, True)
            if isinstance(param, nn.Parameter)
        )

    @pytest.mark.parametrize("as_module", [False, True])
    @pytest.mark.parametrize("lazy_stack", [False, True])
    @pytest.mark.parametrize("device", get_available_devices())