From 38a2c4cf1694d39d5cf0e17df4db07300d0955fd Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Mon, 17 Jul 2023 10:54:56 -0700 Subject: [PATCH 1/2] fix: Repair version checking system for Torch - Move _C utilities into `ts` directory - Address version parsing issue for NV versions of Torch - Add specialized check for NV Torch versions such as `2.0.0.nv23.05` --- py/torch_tensorrt/__init__.py | 5 +++-- py/torch_tensorrt/_utils.py | 9 +++++++++ py/torch_tensorrt/dynamo/__init__.py | 4 ++-- .../fx/test/passes/test_fix_reshape_batch_dim.py | 5 ++++- .../test/passes/test_remove_duplicate_output_args.py | 10 ++++++++-- .../fx/tracer/dispatch_tracer/aten_tracer.py | 3 ++- py/torch_tensorrt/fx/utils.py | 4 ++-- py/torch_tensorrt/{ => ts}/_util.py | 0 8 files changed, 30 insertions(+), 10 deletions(-) create mode 100644 py/torch_tensorrt/_utils.py rename py/torch_tensorrt/{ => ts}/_util.py (100%) diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py index 952563f5ca..5470ee8458 100644 --- a/py/torch_tensorrt/__init__.py +++ b/py/torch_tensorrt/__init__.py @@ -84,7 +84,7 @@ def _find_lib(name, paths): import torch from torch_tensorrt._compile import * -from torch_tensorrt._util import * +from torch_tensorrt.ts._util import * from torch_tensorrt import ts from torch_tensorrt import ptq from torch_tensorrt._enums import * @@ -93,8 +93,9 @@ def _find_lib(name, paths): from torch_tensorrt._Device import Device from torch_tensorrt import fx +from torch_tensorrt._utils import sanitized_torch_version -if version.parse(torch.__version__) >= version.parse("2.1.dev"): +if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"): from torch_tensorrt import dynamo from torch_tensorrt.dynamo import backend diff --git a/py/torch_tensorrt/_utils.py b/py/torch_tensorrt/_utils.py new file mode 100644 index 0000000000..f0ce6507d9 --- /dev/null +++ b/py/torch_tensorrt/_utils.py @@ -0,0 +1,9 @@ +import torch + + +def sanitized_torch_version() -> str: + return ( + torch.__version__ + if ".nv" not in torch.__version__ + else torch.__version__.split(".nv")[0] + ) diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py index 31f0b61fff..0a40b118bb 100644 --- a/py/torch_tensorrt/dynamo/__init__.py +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -1,6 +1,6 @@ -import torch from packaging import version +from torch_tensorrt._utils import sanitized_torch_version -if version.parse(torch.__version__) >= version.parse("2.1.dev"): +if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"): from torch_tensorrt.dynamo import fx_ts_compat from .backend import compile diff --git a/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py b/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py index 431eaf4469..8cb79f4958 100644 --- a/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py +++ b/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py @@ -11,6 +11,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase from torch_tensorrt.fx.passes.lower_basic_pass import fix_reshape_batch_dim from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer +from torch_tensorrt._utils import sanitized_torch_version _LOGGER = logging.getLogger(__name__) @@ -43,7 +44,9 @@ def forward(self, x, y): %reshape : [num_users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.reshape](args = (), kwargs = {input: %y, acc_out_ty: ((%getitem_1, -1, 3), None, None, None, None, None, None)}) return reshape """ - if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"): + if version.parse(sanitized_torch_version()) < version.parse( + "2.1.0.dev20230620" + ): expected_graph = expected_graph.replace("num_users", "#users") assert ( diff --git a/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py b/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py index 5dc7d8572c..2cc97c46be 100644 --- a/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py +++ b/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py @@ -8,6 +8,8 @@ import torch.nn as nn import torch_tensorrt.fx.passes.remove_duplicate_output_args as dedup +from torch_tensorrt._utils import sanitized_torch_version + from torch.testing._internal.common_utils import run_tests, TestCase _LOGGER = logging.getLogger(__name__) @@ -57,7 +59,9 @@ def is_leaf_module(self, m, qn): return add """.strip() - if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"): + if version.parse(sanitized_torch_version()) < version.parse( + "2.1.0.dev20230620" + ): ttop_graph_expected = ttop_graph_expected.replace("num_users", "#users") assert ( @@ -71,7 +75,9 @@ def is_leaf_module(self, m, qn): return (x,) """.strip() - if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"): + if version.parse(sanitized_torch_version()) < version.parse( + "2.1.0.dev20230620" + ): ttop_a_graph_expected = ttop_a_graph_expected.replace("num_users", "#users") assert ( diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py index 217aee973e..a23e824a66 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py @@ -3,10 +3,11 @@ from contextlib import contextmanager from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union from packaging import version +from torch_tensorrt._utils import sanitized_torch_version import torch -if version.parse(torch.__version__) >= version.parse("2.dev"): +if version.parse(sanitized_torch_version()) >= version.parse("2.dev"): import torch._dynamo as torchdynamo from torch.fx.passes.infra.pass_base import PassResult diff --git a/py/torch_tensorrt/fx/utils.py b/py/torch_tensorrt/fx/utils.py index e70fc862d0..ea622ac476 100644 --- a/py/torch_tensorrt/fx/utils.py +++ b/py/torch_tensorrt/fx/utils.py @@ -12,7 +12,7 @@ replace_op_with_indices, run_const_fold, ) - +from torch_tensorrt._utils import sanitized_torch_version from .types import Shape, TRTDataType @@ -160,7 +160,7 @@ def nested_decorator(f: Callable): def function_wrapper(*args, **kwargs): # Parse minimum and current Torch versions min_version = version.parse(min_torch_version) - current_version = version.parse(torch.__version__) + current_version = version.parse(sanitized_torch_version()) if current_version < min_version: raise AssertionError( diff --git a/py/torch_tensorrt/_util.py b/py/torch_tensorrt/ts/_util.py similarity index 100% rename from py/torch_tensorrt/_util.py rename to py/torch_tensorrt/ts/_util.py From 47115fd813adc1f33b5eaa0fe43efa315fe79155 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Mon, 17 Jul 2023 17:08:07 -0700 Subject: [PATCH 2/2] fix: Rollback change to only add version sanitizer --- py/torch_tensorrt/__init__.py | 3 +-- py/torch_tensorrt/{ts => }/_util.py | 8 ++++++++ py/torch_tensorrt/_utils.py | 9 --------- py/torch_tensorrt/dynamo/__init__.py | 2 +- .../fx/test/passes/test_fix_reshape_batch_dim.py | 2 +- .../fx/test/passes/test_remove_duplicate_output_args.py | 2 +- .../fx/tracer/dispatch_tracer/aten_tracer.py | 2 +- py/torch_tensorrt/fx/utils.py | 2 +- 8 files changed, 14 insertions(+), 16 deletions(-) rename py/torch_tensorrt/{ts => }/_util.py (81%) delete mode 100644 py/torch_tensorrt/_utils.py diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py index 5470ee8458..7743aec3f1 100644 --- a/py/torch_tensorrt/__init__.py +++ b/py/torch_tensorrt/__init__.py @@ -84,7 +84,7 @@ def _find_lib(name, paths): import torch from torch_tensorrt._compile import * -from torch_tensorrt.ts._util import * +from torch_tensorrt._util import * from torch_tensorrt import ts from torch_tensorrt import ptq from torch_tensorrt._enums import * @@ -93,7 +93,6 @@ def _find_lib(name, paths): from torch_tensorrt._Device import Device from torch_tensorrt import fx -from torch_tensorrt._utils import sanitized_torch_version if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"): from torch_tensorrt import dynamo diff --git a/py/torch_tensorrt/ts/_util.py b/py/torch_tensorrt/_util.py similarity index 81% rename from py/torch_tensorrt/ts/_util.py rename to py/torch_tensorrt/_util.py index 28de07fe63..6f7e1a6c83 100644 --- a/py/torch_tensorrt/ts/_util.py +++ b/py/torch_tensorrt/_util.py @@ -30,3 +30,11 @@ def get_build_info() -> str: def set_device(gpu_id): _C.set_device(gpu_id) + + +def sanitized_torch_version() -> str: + return ( + torch.__version__ + if ".nv" not in torch.__version__ + else torch.__version__.split(".nv")[0] + ) diff --git a/py/torch_tensorrt/_utils.py b/py/torch_tensorrt/_utils.py deleted file mode 100644 index f0ce6507d9..0000000000 --- a/py/torch_tensorrt/_utils.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch - - -def sanitized_torch_version() -> str: - return ( - torch.__version__ - if ".nv" not in torch.__version__ - else torch.__version__.split(".nv")[0] - ) diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py index 0a40b118bb..ecd4384155 100644 --- a/py/torch_tensorrt/dynamo/__init__.py +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -1,5 +1,5 @@ from packaging import version -from torch_tensorrt._utils import sanitized_torch_version +from torch_tensorrt._util import sanitized_torch_version if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"): from torch_tensorrt.dynamo import fx_ts_compat diff --git a/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py b/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py index 8cb79f4958..cd5548b30a 100644 --- a/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py +++ b/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py @@ -11,7 +11,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase from torch_tensorrt.fx.passes.lower_basic_pass import fix_reshape_batch_dim from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer -from torch_tensorrt._utils import sanitized_torch_version +from torch_tensorrt._util import sanitized_torch_version _LOGGER = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py b/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py index 2cc97c46be..2dc1c404ee 100644 --- a/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py +++ b/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py @@ -8,7 +8,7 @@ import torch.nn as nn import torch_tensorrt.fx.passes.remove_duplicate_output_args as dedup -from torch_tensorrt._utils import sanitized_torch_version +from torch_tensorrt._util import sanitized_torch_version from torch.testing._internal.common_utils import run_tests, TestCase diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py index a23e824a66..248ec3f920 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union from packaging import version -from torch_tensorrt._utils import sanitized_torch_version +from torch_tensorrt._util import sanitized_torch_version import torch diff --git a/py/torch_tensorrt/fx/utils.py b/py/torch_tensorrt/fx/utils.py index ea622ac476..859529d861 100644 --- a/py/torch_tensorrt/fx/utils.py +++ b/py/torch_tensorrt/fx/utils.py @@ -12,7 +12,7 @@ replace_op_with_indices, run_const_fold, ) -from torch_tensorrt._utils import sanitized_torch_version +from torch_tensorrt._util import sanitized_torch_version from .types import Shape, TRTDataType