Skip to content

Commit 77033c8

Browse files
move is_thor() is_tegra_platform() from dynamo.utils to utils to avoid circular import (#3851)
1 parent 607a7b8 commit 77033c8

File tree

12 files changed

+17
-23
lines changed

12 files changed

+17
-23
lines changed

py/torch_tensorrt/_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ def is_tensorrt_version_supported(min_version: str) -> bool:
6565
return False
6666

6767

68+
def is_tegra_platform() -> bool:
69+
if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]:
70+
return True
71+
return False
72+
73+
6874
def is_thor() -> bool:
6975
if torch.cuda.get_device_capability() in [(11, 0)]:
7076
return True

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from torch._dynamo.backends.common import aot_autograd
1111
from torch._dynamo.utils import detect_fake_mode
1212
from torch._functorch.aot_autograd import aot_export_joint_simple
13+
from torch_tensorrt._utils import is_tegra_platform
1314
from torch_tensorrt.dynamo import CompilationSettings
1415
from torch_tensorrt.dynamo._compiler import compile_module
1516
from torch_tensorrt.dynamo.lowering import (
@@ -20,7 +21,6 @@
2021
repair_input_aliasing,
2122
)
2223
from torch_tensorrt.dynamo.utils import (
23-
is_tegra_platform,
2424
parse_dynamo_kwargs,
2525
prepare_inputs,
2626
set_log_level,

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from typing import Any, Callable, Optional, Sequence, Union
33

44
import torch
5+
from torch_tensorrt._utils import is_tegra_platform
56
from torch_tensorrt.dynamo._settings import CompilationSettings
6-
from torch_tensorrt.dynamo.utils import is_tegra_platform
77

88
from .complex_graph_rewrite import complex_graph_detection
99
from .constant_folding import constant_fold

py/torch_tensorrt/dynamo/utils.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -853,15 +853,3 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]
853853
f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple/list of torch.fx.node.Node"
854854
)
855855
return output_dtypes
856-
857-
858-
def is_tegra_platform() -> bool:
859-
if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]:
860-
return True
861-
return False
862-
863-
864-
def is_thor() -> bool:
865-
if torch.cuda.get_device_capability() in [(11, 0)]:
866-
return True
867-
return False

tests/py/dynamo/conversion/test_arange_aten.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch_tensorrt
66
from parameterized import parameterized
77
from torch.testing._internal.common_utils import run_tests
8-
from torch_tensorrt.dynamo.utils import is_tegra_platform, is_thor
8+
from torch_tensorrt._utils import is_tegra_platform, is_thor
99

1010
from .harness import DispatchTestCase
1111

tests/py/dynamo/conversion/test_cumsum_aten.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch_tensorrt
66
from parameterized import parameterized
77
from torch.testing._internal.common_utils import run_tests
8-
from torch_tensorrt.dynamo.utils import is_tegra_platform, is_thor
8+
from torch_tensorrt._utils import is_tegra_platform, is_thor
99

1010
from .harness import DispatchTestCase
1111

tests/py/dynamo/conversion/test_index_aten.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from parameterized import parameterized
77
from torch.testing._internal.common_utils import run_tests
88
from torch_tensorrt import Input
9-
from torch_tensorrt.dynamo.utils import is_tegra_platform, is_thor
9+
from torch_tensorrt._utils import is_tegra_platform, is_thor
1010

1111
from .harness import DispatchTestCase
1212

tests/py/dynamo/conversion/test_nonzero_aten.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from parameterized import parameterized
77
from torch.testing._internal.common_utils import run_tests
88
from torch_tensorrt import Input
9-
from torch_tensorrt.dynamo.utils import is_tegra_platform, is_thor
9+
from torch_tensorrt._utils import is_tegra_platform, is_thor
1010

1111
from .harness import DispatchTestCase
1212

tests/py/dynamo/conversion/test_sym_size.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch.nn as nn
55
from parameterized import parameterized
66
from torch.testing._internal.common_utils import run_tests
7-
from torch_tensorrt.dynamo.utils import is_thor
7+
from torch_tensorrt._utils import is_thor
88

99
from .harness import DispatchTestCase
1010

tests/py/dynamo/runtime/test_000_compilation_settings.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
import torch
55
import torch_tensorrt
66
from torch.testing._internal.common_utils import TestCase, run_tests
7-
from torch_tensorrt._utils import is_tensorrt_version_supported
8-
from torch_tensorrt.dynamo.utils import is_tegra_platform
7+
from torch_tensorrt._utils import is_tegra_platform, is_tensorrt_version_supported
98

109
from packaging.version import Version
1110

0 commit comments

Comments
 (0)