2828from torch .utils ._pytree import SUPPORTED_NODES , tree_map
2929
3030try :
31- from torch .utils ._pytree import tree_leaves
31+ from torch .utils ._pytree import tree_flatten , tree_leaves , tree_unflatten
3232except ImportError :
33- from torch .utils ._pytree import tree_flatten
33+ from torch .utils ._pytree import tree_flatten , tree_unflatten
3434
3535 def tree_leaves (pytree ):
3636 """Torch 2.0 compatible version of tree_leaves."""
@@ -293,11 +293,13 @@ def check_tensor_id(name, t0, t1):
293293
294294 def _call (* args : torch .Tensor , ** kwargs : torch .Tensor ):
295295 if self .counter >= self ._warmup :
296- tree_map (
297- lambda x , y : x .copy_ (y , non_blocking = True ),
298- (self ._args , self ._kwargs ),
299- (args , kwargs ),
300- )
296+ srcs , dests = [], []
297+ for arg_src , arg_dest in zip (
298+ tree_leaves ((args , kwargs )), self ._flat_tree
299+ ):
300+ self ._maybe_copy_onto_ (arg_src , arg_dest , srcs , dests )
301+ if dests :
302+ torch ._foreach_copy_ (dests , srcs )
301303 torch .cuda .synchronize ()
302304 self .graph .replay ()
303305 if self ._return_unchanged == "clone" :
@@ -322,8 +324,13 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
322324 self .counter += self ._has_cuda
323325 return out
324326 else :
325- args , kwargs = self ._args , self ._kwargs = tree_map (
326- self ._check_device_and_clone , (args , kwargs )
327+ self ._flat_tree , self ._tree_spec = tree_flatten ((args , kwargs ))
328+
329+ self ._flat_tree = tuple (
330+ self ._check_device_and_clone (arg ) for arg in self ._flat_tree
331+ )
332+ args , kwargs = self ._args , self ._kwargs = tree_unflatten (
333+ self ._flat_tree , self ._tree_spec
327334 )
328335
329336 torch .cuda .synchronize ()
@@ -360,6 +367,27 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
360367 _call_func = functools .wraps (self .module )(_call )
361368 self ._call_func = _call_func
362369
@staticmethod
def _maybe_copy_onto_(src, dest, srcs, dests):
    """Stage or perform the copy of one flattened input onto its stored counterpart.

    Args:
        src: the freshly passed input leaf.
        dest: the stored leaf that ``src`` is paired with.
        srcs: list that tensor sources are appended to; the copy itself is
            left to the caller, which is expected to process ``srcs`` and
            ``dests`` together.
        dests: list that the matching tensor destinations are appended to.

    Raises:
        RuntimeError: if ``src != dest`` cannot be evaluated, or its result
            cannot be reduced to a single boolean.
        ValueError: if a non-tensor, non-tensor-collection input differs
            from the stored value.
    """
    if isinstance(src, torch.Tensor):
        # Plain tensors are only queued here, not copied.
        srcs.append(src)
        dests.append(dest)
        return
    if is_tensor_collection(src):
        # Tensor collections are copied in place right away.
        dest.copy_(src)
        return
    try:
        # Coerce inside the ``try``: a comparison may succeed yet return an
        # object whose truth value is undefined (e.g. a multi-element array),
        # and that failure must surface with the same diagnostic as a failed
        # comparison instead of leaking an unrelated error.
        isdiff = bool(src != dest)
    except Exception as err:
        raise RuntimeError(
            "Couldn't assess input value. Make sure your function only takes tensor inputs or that "
            "the input value can be easily checked and is constant. For a better efficiency, avoid "
            "passing non-tensor inputs to your function."
        ) from err
    if isdiff:
        raise ValueError("Varying inputs must be torch.Tensor subclasses.")
390+
363391 @classmethod
364392 def _check_device_and_clone (cls , x ):
365393 if isinstance (x , torch .Tensor ) or is_tensor_collection (x ):
0 commit comments