diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py
index 455f93ee4..a587360d7 100644
--- a/tensordict/_torch_func.py
+++ b/tensordict/_torch_func.py
@@ -414,10 +414,12 @@ def _stack(
     out: T | None = None,
     strict: bool = False,
     contiguous: bool = False,
-    maybe_dense_stack: bool = False,
+    maybe_dense_stack: bool | None = None,
 ) -> T:
     if not len(list_of_tensordicts):
         raise RuntimeError("list_of_tensordicts cannot be empty")
+    if maybe_dense_stack is None:
+        maybe_dense_stack = lazy_legacy()
     is_tc = any(is_tensorclass(td) for td in list_of_tensordicts)
     if all(is_non_tensor(td) for td in list_of_tensordicts):
         from tensordict.tensorclass import NonTensorData
@@ -457,7 +459,11 @@ def _stack(
     if not _lazy_legacy and not contiguous:
         if maybe_dense_stack:
             with set_lazy_legacy(True):
-                return _stack(list_of_tensordicts, dim=dim)
+                return _stack(
+                    list_of_tensordicts,
+                    dim=dim,
+                    maybe_dense_stack=maybe_dense_stack,
+                )
         else:
             raise RuntimeError(
                 "The sets of keys in the tensordicts to stack are exclusive. "
@@ -490,7 +496,11 @@ def _stack(
             dim = dim - 1
         return LazyStackedTensorDict(
             *[
-                _stack(list(subtds), dim=dim)
+                _stack(
+                    list(subtds),
+                    dim=dim,
+                    maybe_dense_stack=maybe_dense_stack,
+                )
                 for subtds in _zip_strict(
                     *[td.tensordicts for td in list_of_tensordicts]
                 )
@@ -540,7 +550,11 @@ def _stack(
                 # Nested tensors will require a lazy stack
                 if maybe_dense_stack:
                     with set_lazy_legacy(True):
-                        return _stack(list_of_tensordicts, dim=dim)
+                        return _stack(
+                            list_of_tensordicts,
+                            dim=dim,
+                            maybe_dense_stack=maybe_dense_stack,
+                        )
                 else:
                     raise RuntimeError(
                         f"The shapes of the tensors to stack is incompatible: {new_tensor_shape} vs {tensor_shape} for key {key}."
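
A minimal, self-contained sketch (not part of the patch) of the default-resolution pattern the hunks above introduce: a `None` value for `maybe_dense_stack` is resolved once from the global lazy-legacy flag, and the resolved value is then passed explicitly to recursive calls so nested stacks keep the caller's behaviour even while `set_lazy_legacy(True)` temporarily flips the global flag. The `lazy_legacy`/`set_lazy_legacy` helpers below are toy stand-ins for illustration only, not the tensordict API, and `stack` is a hypothetical stand-in for `_stack`.

    import contextlib

    _LAZY_LEGACY = False  # toy stand-in for the global lazy-legacy flag


    def lazy_legacy() -> bool:
        # Stand-in: report the current global flag.
        return _LAZY_LEGACY


    @contextlib.contextmanager
    def set_lazy_legacy(mode: bool):
        # Stand-in: temporarily override the flag, then restore it.
        global _LAZY_LEGACY
        previous, _LAZY_LEGACY = _LAZY_LEGACY, mode
        try:
            yield
        finally:
            _LAZY_LEGACY = previous


    def stack(items, dim=0, maybe_dense_stack=None, _depth=0):
        # Same pattern as the patched _stack: None means "follow the global
        # setting", resolved once at the outermost call ...
        if maybe_dense_stack is None:
            maybe_dense_stack = lazy_legacy()
        if _depth < 2:
            # ... and then forwarded explicitly to recursive calls, so they do
            # not re-read a global flag that may have been flipped in between.
            return stack(
                items, dim=dim, maybe_dense_stack=maybe_dense_stack, _depth=_depth + 1
            )
        return maybe_dense_stack


    with set_lazy_legacy(True):
        assert stack([]) is True   # default now follows the lazy-legacy setting
    assert stack([]) is False      # outside the context, the flag is restored
    assert stack([], maybe_dense_stack=True) is True  # an explicit value still wins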