Skip to content

Commit

Permalink
[BugFix] Propagate maybe_dense_stack in _stack
Browse files Browse the repository at this point in the history
ghstack-source-id: a1cb1dee6c3665bc164e8c87585114e80650798c
Pull Request resolved: #1036
  • Loading branch information
vmoens committed Oct 8, 2024
1 parent d147be4 commit 1e32195
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions tensordict/_torch_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,10 +414,12 @@ def _stack(
out: T | None = None,
strict: bool = False,
contiguous: bool = False,
maybe_dense_stack: bool = False,
maybe_dense_stack: bool | None = None,
) -> T:
if not len(list_of_tensordicts):
raise RuntimeError("list_of_tensordicts cannot be empty")
if maybe_dense_stack is None:
maybe_dense_stack = lazy_legacy()
is_tc = any(is_tensorclass(td) for td in list_of_tensordicts)
if all(is_non_tensor(td) for td in list_of_tensordicts):
from tensordict.tensorclass import NonTensorData
Expand Down Expand Up @@ -457,7 +459,11 @@ def _stack(
if not _lazy_legacy and not contiguous:
if maybe_dense_stack:
with set_lazy_legacy(True):
return _stack(list_of_tensordicts, dim=dim)
return _stack(
list_of_tensordicts,
dim=dim,
maybe_dense_stack=maybe_dense_stack,
)
else:
raise RuntimeError(
"The sets of keys in the tensordicts to stack are exclusive. "
Expand Down Expand Up @@ -490,7 +496,11 @@ def _stack(
dim = dim - 1
return LazyStackedTensorDict(
*[
_stack(list(subtds), dim=dim)
_stack(
list(subtds),
dim=dim,
maybe_dense_stack=maybe_dense_stack,
)
for subtds in _zip_strict(
*[td.tensordicts for td in list_of_tensordicts]
)
Expand Down Expand Up @@ -540,7 +550,11 @@ def _stack(
# Nested tensors will require a lazy stack
if maybe_dense_stack:
with set_lazy_legacy(True):
return _stack(list_of_tensordicts, dim=dim)
return _stack(
list_of_tensordicts,
dim=dim,
maybe_dense_stack=maybe_dense_stack,
)
else:
raise RuntimeError(
f"The shapes of the tensors to stack is incompatible: {new_tensor_shape} vs {tensor_shape} for key {key}."
Expand Down

0 comments on commit 1e32195

Please sign in to comment.