diff --git a/tensordict/_td.py b/tensordict/_td.py index df6804077..bd2d3142a 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -1062,7 +1062,7 @@ def from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None): "Cannot pass both batch_size and batch_dims to `from_dict`." ) - batch_size_set = [] if batch_size is None else batch_size + batch_size_set = torch.Size(()) if batch_size is None else batch_size for key, value in list(input_dict.items()): if isinstance(value, (dict,)): # we don't know if another tensor of smaller size is coming diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 8f02faf13..53df615aa 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -14,7 +14,7 @@ import torch from cloudpickle import dumps as cloudpickle_dumps, loads as cloudpickle_loads from tensordict._td import is_tensor_collection, TensorDictBase -from tensordict._tensordict import unravel_key_list +from tensordict._tensordict import _unravel_key_to_tuple, unravel_key_list from tensordict.functional import make_tensordict from tensordict.nn.functional_modules import ( @@ -255,9 +255,7 @@ def wrapper(_self, *args: Any, **kwargs: Any) -> Any: if isinstance(dest, str): dest = getattr(_self, dest) for key in source: - expected_key = ( - self.separator.join(key) if isinstance(key, tuple) else key - ) + expected_key = self.separator.join(_unravel_key_to_tuple(key)) if len(args): tensordict_values[key] = args[0] args = args[1:]