[Quality] Fewer recompiles with tensordict #1015

Merged · 7 commits · Oct 3, 2024
13 changes: 1 addition & 12 deletions tensordict/base.py
@@ -192,8 +192,6 @@ def has_transfer(self):

 _device_recorder = _RecordDeviceTransfer()

-_TENSOR_COLLECTION_MEMO = {}
-

 class TensorDictBase(MutableMapping):
     """TensorDictBase is an abstract parent class for TensorDicts, a torch.Tensor data container."""
@@ -10484,16 +10482,7 @@ def _register_tensor_class(cls):


 def _is_tensor_collection(datatype):
-    out = _TENSOR_COLLECTION_MEMO.get(datatype)
-    if out is None:
-        if issubclass(datatype, TensorDictBase):
-            out = True
-        elif _is_tensorclass(datatype):
-            out = True
-        else:
-            out = False
-        _TENSOR_COLLECTION_MEMO[datatype] = out
-    return out
+    return issubclass(datatype, TensorDictBase) or _is_tensorclass(datatype)


 def is_tensor_collection(datatype: type | Any) -> bool:
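The simplification above drops the `_TENSOR_COLLECTION_MEMO` lookup: a plain `issubclass`/`_is_tensorclass` check is cheap and, unlike a growing module-level dict, gives `torch.compile` nothing extra to guard on, which is presumably where the "fewer recompiles" of the PR title come from. A minimal sketch of the public wrapper's behavior, assuming a standard tensordict install (the sample values below are illustrative, not from the PR):

```python
import torch
from tensordict import TensorDict, is_tensor_collection

td = TensorDict({"a": torch.zeros(3)}, batch_size=[3])
assert is_tensor_collection(td)          # instances of TensorDictBase pass
assert is_tensor_collection(TensorDict)  # so does the type itself
assert not is_tensor_collection(torch.zeros(3))  # a plain tensor does not
```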
14 changes: 13 additions & 1 deletion tensordict/nn/cudagraphs.py
@@ -222,6 +222,8 @@ def _call(
             return result

         if not self._has_cuda or self.counter < self._warmup - 1:
+            # We must clone the data because providing non-contiguous data will fail later when we clone
+            tensordict.apply(self._clone, out=tensordict)
             if self._has_cuda:
                 torch.cuda.synchronize()
             with self._warmup_stream_cm:
@@ -243,6 +245,7 @@
             )

         tree_map(self._check_non_tensor, (args, kwargs))
+        tensordict.apply(self._clone, out=tensordict)
         self._tensordict = tensordict.copy()
         if tensordict_out is not None:
             td_out_save = tensordict_out.copy()
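For context, `TensorDict.apply` maps a function over every tensor leaf; passing `out=tensordict` writes the results back into the same TensorDict, so the warmup and capture passes above run on freshly cloned storage. A standalone sketch of that pattern, assuming a standard tensordict install (the free-standing `_clone` helper is illustrative and mirrors the classmethod added in this diff):

```python
import torch
from tensordict import TensorDict

def _clone(x):
    # Mirrors CudaGraphModule._clone: refuse differentiable inputs, then clone.
    if x.requires_grad:
        raise RuntimeError("inputs that require grad are not supported")
    return x.clone()

td = TensorDict({"obs": torch.randn(4, 3)}, batch_size=[4])
td.apply(_clone, out=td)  # every leaf is replaced by its clone, in place
```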
@@ -307,6 +310,7 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
             return result

         if not self._has_cuda or self.counter < self._warmup - 1:
+            args, kwargs = tree_map(self._clone, (args, kwargs))
             if self._has_cuda:
                 torch.cuda.synchronize()
             with self._warmup_stream_cm:
@@ -316,7 +320,7 @@
             self.counter += self._has_cuda
             return out
         else:
-            self._args, self._kwargs = tree_map(
+            args, kwargs = self._args, self._kwargs = tree_map(
                 self._check_device_and_clone, (args, kwargs)
             )

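On the tensor-only path above, the same cloning is done with `torch.utils._pytree.tree_map`, which applies a function to every leaf of the nested `(args, kwargs)` structure. A standalone sketch of that helper (the `_clone` function here is illustrative):

```python
import torch
from torch.utils._pytree import tree_map

def _clone(x):
    # Clone tensor leaves; leave anything else untouched.
    return x.clone() if isinstance(x, torch.Tensor) else x

args, kwargs = (torch.ones(2),), {"scale": torch.tensor(2.0)}
args, kwargs = tree_map(_clone, (args, kwargs))
```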
@@ -371,6 +375,14 @@ def _check_device_and_clone(cls, x):
             return x.clone()
         return x

+    @classmethod
+    def _clone(cls, x):
+        if isinstance(x, torch.Tensor) or is_tensor_collection(x):
+            if x.requires_grad:
+                raise RuntimeError(cls._REQUIRES_GRAD_ERROR)
+            return x.clone()
+        return x
+
     @classmethod
     def _check_device_and_grad(cls, x):
         if isinstance(x, torch.Tensor):
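Taken together, the warmup path now clones its inputs before recording, so callers cannot accidentally alias the tensors that the CUDA graph captures. A usage sketch, assuming a CUDA device is available and that `CudaGraphModule` accepts a `warmup` argument, as the `self._warmup` attribute above suggests:

```python
import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule, TensorDictModule

if torch.cuda.is_available():
    module = TensorDictModule(
        torch.nn.Linear(3, 4).cuda(), in_keys=["x"], out_keys=["y"]
    )
    gm = CudaGraphModule(module, warmup=2)
    td = TensorDict({"x": torch.randn(8, 3, device="cuda")}, batch_size=[8])
    for _ in range(4):  # early calls warm up; later calls replay the captured graph
        out = gm(td)
```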