diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py
index 952563f5ca..7743aec3f1 100644
--- a/py/torch_tensorrt/__init__.py
+++ b/py/torch_tensorrt/__init__.py
@@ -94,7 +94,7 @@ def _find_lib(name, paths):
 
 from torch_tensorrt import fx
 
-if version.parse(torch.__version__) >= version.parse("2.1.dev"):
+if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
     from torch_tensorrt import dynamo
     from torch_tensorrt.dynamo import backend
 
diff --git a/py/torch_tensorrt/_util.py b/py/torch_tensorrt/_util.py
index 28de07fe63..6f7e1a6c83 100644
--- a/py/torch_tensorrt/_util.py
+++ b/py/torch_tensorrt/_util.py
@@ -30,3 +30,11 @@ def get_build_info() -> str:
 
 def set_device(gpu_id):
     _C.set_device(gpu_id)
+
+
+def sanitized_torch_version() -> str:
+    return (
+        torch.__version__
+        if ".nv" not in torch.__version__
+        else torch.__version__.split(".nv")[0]
+    )
diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py
index 31f0b61fff..ecd4384155 100644
--- a/py/torch_tensorrt/dynamo/__init__.py
+++ b/py/torch_tensorrt/dynamo/__init__.py
@@ -1,6 +1,6 @@
-import torch
 from packaging import version
+from torch_tensorrt._util import sanitized_torch_version
 
-if version.parse(torch.__version__) >= version.parse("2.1.dev"):
+if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
     from torch_tensorrt.dynamo import fx_ts_compat
     from .backend import compile
diff --git a/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py b/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py
index 431eaf4469..cd5548b30a 100644
--- a/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py
+++ b/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py
@@ -11,6 +11,7 @@
 from torch.testing._internal.common_utils import run_tests, TestCase
 from torch_tensorrt.fx.passes.lower_basic_pass import fix_reshape_batch_dim
 from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
+from torch_tensorrt._util import sanitized_torch_version
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -43,7 +44,9 @@ def forward(self, x, y):
 %reshape : [num_users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.reshape](args = (), kwargs = {input: %y, acc_out_ty: ((%getitem_1, -1, 3), None, None, None, None, None, None)})
 return reshape
 """
-        if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
+        if version.parse(sanitized_torch_version()) < version.parse(
+            "2.1.0.dev20230620"
+        ):
             expected_graph = expected_graph.replace("num_users", "#users")
 
         assert (
diff --git a/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py b/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py
index 5dc7d8572c..2dc1c404ee 100644
--- a/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py
+++ b/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py
@@ -8,6 +8,8 @@
 import torch.nn as nn
 import torch_tensorrt.fx.passes.remove_duplicate_output_args as dedup
 
+from torch_tensorrt._util import sanitized_torch_version
+
 from torch.testing._internal.common_utils import run_tests, TestCase
 
 _LOGGER = logging.getLogger(__name__)
@@ -57,7 +59,9 @@ def is_leaf_module(self, m, qn):
 return add
 """.strip()
 
-        if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
+        if version.parse(sanitized_torch_version()) < version.parse(
+            "2.1.0.dev20230620"
+        ):
             ttop_graph_expected = ttop_graph_expected.replace("num_users", "#users")
 
         assert (
@@ -71,7 +75,9 @@
 return (x,)
 """.strip()
 
-        if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
+        if version.parse(sanitized_torch_version()) < version.parse(
+            "2.1.0.dev20230620"
+        ):
             ttop_a_graph_expected = ttop_a_graph_expected.replace("num_users", "#users")
 
         assert (
diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py
index 217aee973e..248ec3f920 100644
--- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py
+++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py
@@ -3,10 +3,11 @@
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
 from packaging import version
+from torch_tensorrt._util import sanitized_torch_version
 
 import torch
 
-if version.parse(torch.__version__) >= version.parse("2.dev"):
+if version.parse(sanitized_torch_version()) >= version.parse("2.dev"):
     import torch._dynamo as torchdynamo
 
 from torch.fx.passes.infra.pass_base import PassResult
diff --git a/py/torch_tensorrt/fx/utils.py b/py/torch_tensorrt/fx/utils.py
index e70fc862d0..859529d861 100644
--- a/py/torch_tensorrt/fx/utils.py
+++ b/py/torch_tensorrt/fx/utils.py
@@ -12,7 +12,7 @@
     replace_op_with_indices,
     run_const_fold,
 )
-
+from torch_tensorrt._util import sanitized_torch_version
 from .types import Shape, TRTDataType
 
 
@@ -160,7 +160,7 @@ def nested_decorator(f: Callable):
         def function_wrapper(*args, **kwargs):
             # Parse minimum and current Torch versions
             min_version = version.parse(min_torch_version)
-            current_version = version.parse(torch.__version__)
+            current_version = version.parse(sanitized_torch_version())
 
             if current_version < min_version:
                 raise AssertionError(