Skip to content

Commit 7e45bcc

Browse files
author
Vincent Moens
committed
[BugFix] Compatibility with non-tensor inputs in CudaGraphModule
ghstack-source-id: f5a4845 Pull Request resolved: #1039
1 parent 49d226c commit 7e45bcc

File tree

2 files changed

+50
-9
lines changed

2 files changed

+50
-9
lines changed

tensordict/nn/cudagraphs.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828
from torch.utils._pytree import SUPPORTED_NODES, tree_map
2929

3030
try:
31-
from torch.utils._pytree import tree_leaves
31+
from torch.utils._pytree import tree_flatten, tree_leaves, tree_unflatten
3232
except ImportError:
33-
from torch.utils._pytree import tree_flatten
33+
from torch.utils._pytree import tree_flatten, tree_unflatten
3434

3535
def tree_leaves(pytree):
3636
"""Torch 2.0 compatible version of tree_leaves."""
@@ -293,11 +293,13 @@ def check_tensor_id(name, t0, t1):
293293

294294
def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
295295
if self.counter >= self._warmup:
296-
tree_map(
297-
lambda x, y: x.copy_(y, non_blocking=True),
298-
(self._args, self._kwargs),
299-
(args, kwargs),
300-
)
296+
srcs, dests = [], []
297+
for arg_src, arg_dest in zip(
298+
tree_leaves((args, kwargs)), self._flat_tree
299+
):
300+
self._maybe_copy_onto_(arg_src, arg_dest, srcs, dests)
301+
if dests:
302+
torch._foreach_copy_(dests, srcs)
301303
torch.cuda.synchronize()
302304
self.graph.replay()
303305
if self._return_unchanged == "clone":
@@ -322,8 +324,13 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
322324
self.counter += self._has_cuda
323325
return out
324326
else:
325-
args, kwargs = self._args, self._kwargs = tree_map(
326-
self._check_device_and_clone, (args, kwargs)
327+
self._flat_tree, self._tree_spec = tree_flatten((args, kwargs))
328+
329+
self._flat_tree = tuple(
330+
self._check_device_and_clone(arg) for arg in self._flat_tree
331+
)
332+
args, kwargs = self._args, self._kwargs = tree_unflatten(
333+
self._flat_tree, self._tree_spec
327334
)
328335

329336
torch.cuda.synchronize()
@@ -360,6 +367,27 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
360367
_call_func = functools.wraps(self.module)(_call)
361368
self._call_func = _call_func
362369

370+
@staticmethod
371+
def _maybe_copy_onto_(src, dest, srcs, dests):
372+
if isinstance(src, torch.Tensor):
373+
srcs.append(src)
374+
dests.append(dest)
375+
return
376+
if is_tensor_collection(src):
377+
dest.copy_(src)
378+
return
379+
isdiff = False
380+
try:
381+
isdiff = src != dest
382+
except Exception as err:
383+
raise RuntimeError(
384+
"Couldn't assess input value. Make sure your function only takes tensor inputs or that "
385+
"the input value can be easily checked and is constant. For a better efficiency, avoid "
386+
"passing non-tensor inputs to your function."
387+
) from err
388+
if isdiff:
389+
raise ValueError("Varying inputs must be torch.Tensor subclasses.")
390+
363391
@classmethod
364392
def _check_device_and_clone(cls, x):
365393
if isinstance(x, torch.Tensor) or is_tensor_collection(x):

test/test_compile.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,6 +1056,19 @@ def test_td_input_non_tdmodule(self, compiled):
10561056
if i == 5:
10571057
assert not func._is_tensordict_module
10581058

1059+
def test_td_input_non_tdmodule_nontensor(self, compiled):
1060+
func = lambda x, y: x + y
1061+
func = self._make_cudagraph(func, compiled)
1062+
for i in range(10):
1063+
assert func(torch.zeros(()), 1.0) == 1.0
1064+
if i == 5:
1065+
assert not func._is_tensordict_module
1066+
if torch.cuda.is_available():
1067+
with pytest.raises(
1068+
ValueError, match="Varying inputs must be torch.Tensor subclasses."
1069+
):
1070+
func(torch.zeros(()), 2.0)
1071+
10591072

10601073
if __name__ == "__main__":
10611074
args, unknown = argparse.ArgumentParser().parse_known_args()

0 commit comments

Comments (0)