Skip to content

Commit ea3b91b

Browse files
SherlockNoMad and cleonard530
authored and committed
Lazy import to avoid circular import issue for DebugMode (pytorch#163381)
as title. Pull Request resolved: pytorch#163381 Approved by: https://github.com/dolpm
1 parent 4de527c commit ea3b91b

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

torch/utils/debug_mode.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22
import contextlib
33

44
import torch
5-
import torch.distributed.tensor as dt
65
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
7-
from torch.distributed.tensor._dtensor_spec import DTensorSpec
86
from torch.utils._dtype_abbrs import dtype_abbrs
97
from torch.utils._python_dispatch import _get_current_dispatch_mode, TorchDispatchMode
108
from torch.utils._pytree import tree_map
@@ -29,7 +27,7 @@ def _stringify_placement(placement) -> str:
2927

3028
def _tensor_debug_string(tensor) -> str:
3129
"""Convert tensor to debug string representation."""
32-
if isinstance(tensor, dt.DTensor):
30+
if isinstance(tensor, torch.distributed.tensor.DTensor):
3331
# omitted device mesh
3432
return f"dt: {dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_placement(tensor.placements)}"
3533
elif isinstance(tensor, FakeTensor):
@@ -41,6 +39,8 @@ def _tensor_debug_string(tensor) -> str:
4139

4240

4341
def _arg_to_str(arg) -> str:
42+
from torch.distributed.tensor._dtensor_spec import DTensorSpec
43+
4444
def to_str(x):
4545
if isinstance(x, torch.Tensor):
4646
return _tensor_debug_string(x)
@@ -86,6 +86,7 @@ def __init__(
8686
record_realtensor=True,
8787
):
8888
super().__init__()
89+
import torch.distributed.tensor # noqa: F401
8990

9091
self.record_torchfunction = record_torchfunction
9192
self.record_faketensor = record_faketensor
@@ -111,7 +112,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
111112
kwargs = {}
112113

113114
# Record the operation with its call depth
114-
if dt.DTensor in types:
115+
if torch.distributed.tensor.DTensor in types:
115116
self.operators.append((func, args, kwargs, self.call_depth))
116117
return NotImplemented
117118
elif FakeTensor in types or isinstance(

0 commit comments

Comments (0)