22import contextlib
33
44import torch
5- import torch .distributed .tensor as dt
65from torch ._subclasses .fake_tensor import FakeTensor , FakeTensorMode
7- from torch .distributed .tensor ._dtensor_spec import DTensorSpec
86from torch .utils ._dtype_abbrs import dtype_abbrs
97from torch .utils ._python_dispatch import _get_current_dispatch_mode , TorchDispatchMode
108from torch .utils ._pytree import tree_map
@@ -29,7 +27,7 @@ def _stringify_placement(placement) -> str:
2927
3028def _tensor_debug_string (tensor ) -> str :
3129 """Convert tensor to debug string representation."""
32- if isinstance (tensor , dt .DTensor ):
30+ if isinstance (tensor , torch . distributed . tensor .DTensor ):
3331 # omitted device mesh
3432 return f"dt: { dtype_abbrs [tensor .dtype ]} { _stringify_shape (tensor .shape )} { _stringify_placement (tensor .placements )} "
3533 elif isinstance (tensor , FakeTensor ):
@@ -41,6 +39,8 @@ def _tensor_debug_string(tensor) -> str:
4139
4240
4341def _arg_to_str (arg ) -> str :
42+ from torch .distributed .tensor ._dtensor_spec import DTensorSpec
43+
4444 def to_str (x ):
4545 if isinstance (x , torch .Tensor ):
4646 return _tensor_debug_string (x )
@@ -86,6 +86,7 @@ def __init__(
8686 record_realtensor = True ,
8787 ):
8888 super ().__init__ ()
89+ import torch .distributed .tensor # noqa: F401
8990
9091 self .record_torchfunction = record_torchfunction
9192 self .record_faketensor = record_faketensor
@@ -111,7 +112,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
111112 kwargs = {}
112113
113114 # Record the operation with its call depth
114- if dt .DTensor in types :
115+ if torch . distributed . tensor .DTensor in types :
115116 self .operators .append ((func , args , kwargs , self .call_depth ))
116117 return NotImplemented
117118 elif FakeTensor in types or isinstance (
0 commit comments