diff --git a/.circleci/config.yml b/.circleci/config.yml index 7204e13df0..dab13dcb88 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -763,15 +763,15 @@ commands: - store_artifacts: path: /tmp/testlogs - test-dynamo-torch_compile-core: - description: "Test the Dynamo torch_compile path" + test-dynamo-compile-core: + description: "Test the Dynamo compile path" steps: - run: - name: Run Dynamo torch_compile core tests + name: Run Dynamo compile core tests command: | - cd py/torch_tensorrt/dynamo/torch_compile + cd py/torch_tensorrt/dynamo/backend pushd test/ - pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml + pytest --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml popd - store_test_results: @@ -779,17 +779,17 @@ commands: - store_artifacts: path: /tmp/testlogs - test-dynamo-torch_compile: - description: "Test the Dynamo torch_compile path" + test-dynamo-compile: + description: "Test the Dynamo compile path" steps: - run: - name: Run Dynamo torch_compile E2E tests + name: Run Dynamo compile E2E tests command: | cd py/torch_tensorrt/dynamo/ pushd test/ pip3 install timm pip3 install transformers - pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml --ir torch_compile + pytest --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir dynamo_compile popd - store_test_results: @@ -1051,8 +1051,8 @@ jobs: command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl # We install torch after torch-trt because pip automatically enforces the version constraint otherwise - dump-test-env - - test-dynamo-torch_compile - - test-dynamo-torch_compile-core + - test-dynamo-compile + - test-dynamo-compile-core - test-dynamo-fx_ts package-x86_64-linux: diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py index b8e4fd0d9d..f92b29aa86 100644 --- a/py/torch_tensorrt/__init__.py +++ b/py/torch_tensorrt/__init__.py @@ -97,7 +97,7 @@ def _find_lib(name, paths): if version.parse(torch.__version__) >= version.parse("2.dev"): from torch_tensorrt import dynamo - from torch_tensorrt.dynamo import torch_compile + from torch_tensorrt.dynamo import backend def _register_with_torch(): diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index e300669fd5..de0aeb5308 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -16,7 +16,7 @@ class _IRType(Enum): ts = 0 fx = 1 fx_ts_compat = 2 - torch_compile = 3 + dynamo_compile = 3 class _ModuleType(Enum): @@ -47,7 +47,7 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]]) ir_targets_fx = ir == "fx" - ir_targets_torch_compile = ir == "torch_compile" + ir_targets_dynamo_compile = ir == "dynamo_compile" ir_targets_fx_ts_compat = ir == "fx_ts_compat" if module_is_tsable and ir_targets_torchscript: @@ -56,8 +56,8 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: return _IRType.fx elif module_is_fxable and ir_targets_fx_ts_compat: return _IRType.fx_ts_compat - elif module_is_fxable and ir_targets_torch_compile: - return _IRType.torch_compile + elif module_is_fxable and ir_targets_dynamo_compile: + return _IRType.dynamo_compile else: if ir == "default": # Options are listed in order of preference @@ -156,8 +156,8 @@ def compile( dynamic_batch=False, **kwargs, ) - elif target_ir == _IRType.torch_compile: - return torch_tensorrt.dynamo.torch_compile( + elif target_ir == _IRType.dynamo_compile: + return torch_tensorrt.dynamo.compile( module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs ) elif target_ir == _IRType.fx_ts_compat: diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py index 26e8b7aa3e..ea1778edfe 100644 --- a/py/torch_tensorrt/dynamo/__init__.py +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -1,2 +1,2 @@ from torch_tensorrt.dynamo import fx_ts_compat -from .torch_compile import compile as torch_compile +from .backend import compile diff --git a/py/torch_tensorrt/dynamo/torch_compile/__init__.py b/py/torch_tensorrt/dynamo/backend/__init__.py similarity index 90% rename from py/torch_tensorrt/dynamo/torch_compile/__init__.py rename to py/torch_tensorrt/dynamo/backend/__init__.py index 32e5567c51..eba389ecec 100644 --- a/py/torch_tensorrt/dynamo/torch_compile/__init__.py +++ b/py/torch_tensorrt/dynamo/backend/__init__.py @@ -8,10 +8,10 @@ from torch_tensorrt import EngineCapability, Device from torch_tensorrt.fx.utils import LowerPrecision -from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings -from torch_tensorrt.dynamo.torch_compile.utils import prepare_inputs, prepare_device -from torch_tensorrt.dynamo.torch_compile.backends import tensorrt_backend -from torch_tensorrt.dynamo.torch_compile._defaults import ( +from torch_tensorrt.dynamo.backend._settings import CompilationSettings +from torch_tensorrt.dynamo.backend.utils import prepare_inputs, prepare_device +from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend +from torch_tensorrt.dynamo.backend._defaults import ( PRECISION, DEBUG, MAX_WORKSPACE_SIZE, @@ -121,6 +121,6 @@ def create_backend( ) return partial( - tensorrt_backend, + torch_tensorrt_backend, settings=settings, ) diff --git a/py/torch_tensorrt/dynamo/torch_compile/_defaults.py b/py/torch_tensorrt/dynamo/backend/_defaults.py similarity index 83% rename from py/torch_tensorrt/dynamo/torch_compile/_defaults.py rename to py/torch_tensorrt/dynamo/backend/_defaults.py index 48c9a26f9e..814331e158 100644 --- a/py/torch_tensorrt/dynamo/torch_compile/_defaults.py +++ b/py/torch_tensorrt/dynamo/backend/_defaults.py @@ -4,4 +4,4 @@ PRECISION = LowerPrecision.FP32 DEBUG = False MAX_WORKSPACE_SIZE = 20 << 30 -MAX_NUM_TRT_ENGINES = 200 +MAX_NUM_TRT_ENGINES = 10 diff --git a/py/torch_tensorrt/dynamo/torch_compile/_settings.py b/py/torch_tensorrt/dynamo/backend/_settings.py similarity index 86% rename from py/torch_tensorrt/dynamo/torch_compile/_settings.py rename to py/torch_tensorrt/dynamo/backend/_settings.py index 276b8742ff..7677b1bd57 100644 --- a/py/torch_tensorrt/dynamo/torch_compile/_settings.py +++ b/py/torch_tensorrt/dynamo/backend/_settings.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from torch_tensorrt.fx.utils import LowerPrecision -from torch_tensorrt.dynamo.torch_compile._defaults import ( +from torch_tensorrt.dynamo.backend._defaults import ( PRECISION, DEBUG, MAX_WORKSPACE_SIZE, diff --git a/py/torch_tensorrt/dynamo/torch_compile/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py similarity index 70% rename from py/torch_tensorrt/dynamo/torch_compile/backends.py rename to py/torch_tensorrt/dynamo/backend/backends.py index 9ceab947f0..9df3f1c686 100644 --- a/py/torch_tensorrt/dynamo/torch_compile/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -4,30 +4,42 @@ from functools import partial import torch._dynamo as td -from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings -from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import ( +from torch_tensorrt.dynamo.backend._settings import CompilationSettings +from torch_tensorrt.dynamo.backend.lowering._decompositions import ( get_decompositions, ) -from torch_tensorrt.dynamo.torch_compile.lowering._partition import ( +from torch_tensorrt.dynamo.backend.lowering._partition import ( partition, get_submod_inputs, ) -from torch_tensorrt.dynamo.torch_compile.conversion import convert_module +from torch_tensorrt.dynamo.backend.conversion import convert_module from torch._dynamo.backends.common import fake_tensor_unsupported from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler -@td.register_backend(name="tensorrt") +@td.register_backend(name="torch_tensorrt") @fake_tensor_unsupported -def tensorrt_backend( - gm: torch.nn.Module, +def torch_tensorrt_backend( + gm: torch.fx.GraphModule, + sample_inputs: Sequence[torch.Tensor], + settings: CompilationSettings = CompilationSettings(), +): + DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend + + return DEFAULT_BACKEND(gm, sample_inputs, settings=settings) + + +@td.register_backend(name="aot_torch_tensorrt_aten") +@fake_tensor_unsupported +def aot_torch_tensorrt_aten_backend( + gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], settings: CompilationSettings = CompilationSettings(), ): custom_backend = partial( - fx_dynamo_backend, + _pretraced_backend, settings=settings, ) @@ -40,14 +52,12 @@ def tensorrt_backend( ) -@td.register_backend(name="fx_tensorrt") -@fake_tensor_unsupported -def fx_dynamo_backend( +def _pretraced_backend( gm: torch.fx.GraphModule, - example_inputs: Sequence[torch.Tensor], + sample_inputs: Sequence[torch.Tensor], settings: CompilationSettings = CompilationSettings(), ): - """Helper function to manage translation of FX module to TRT engines + """Helper function to manage translation of traced FX module to TRT engines Args: module: FX GraphModule to convert @@ -57,9 +67,9 @@ def fx_dynamo_backend( Compiled FX GraphModule """ try: - trt_compiled = compile_module( + trt_compiled = _compile_module( gm, - example_inputs, + sample_inputs, settings=settings, ) return trt_compiled @@ -72,12 +82,12 @@ def fx_dynamo_backend( return gm.forward -def compile_module( +def _compile_module( gm: torch.fx.GraphModule, - example_inputs: Sequence[torch.Tensor], + sample_inputs: Sequence[torch.Tensor], settings: CompilationSettings = CompilationSettings(), ) -> torch.fx.GraphModule: - """Compile an FX module + """Compile a traced FX module Includes: Partitioning + Conversion Phases @@ -100,7 +110,7 @@ def compile_module( # Get submodule inputs submodule_inputs = get_submod_inputs( - partitioned_module, submodule, example_inputs + partitioned_module, submodule, sample_inputs ) # Create TRT Module from submodule diff --git a/py/torch_tensorrt/dynamo/torch_compile/conversion.py b/py/torch_tensorrt/dynamo/backend/conversion.py similarity index 100% rename from py/torch_tensorrt/dynamo/torch_compile/conversion.py rename to py/torch_tensorrt/dynamo/backend/conversion.py diff --git a/py/torch_tensorrt/dynamo/backend/lowering/__init__.py b/py/torch_tensorrt/dynamo/backend/lowering/__init__.py new file mode 100644 index 0000000000..01b20cef6d --- /dev/null +++ b/py/torch_tensorrt/dynamo/backend/lowering/__init__.py @@ -0,0 +1,7 @@ +from torch_tensorrt.dynamo.backend.lowering._decompositions import ( + get_decompositions, +) +from torch_tensorrt.dynamo.backend.lowering._partition import ( + partition, + get_submod_inputs, +) diff --git a/py/torch_tensorrt/dynamo/torch_compile/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py similarity index 100% rename from py/torch_tensorrt/dynamo/torch_compile/lowering/_decompositions.py rename to py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py diff --git a/py/torch_tensorrt/dynamo/torch_compile/lowering/_partition.py b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py similarity index 98% rename from py/torch_tensorrt/dynamo/torch_compile/lowering/_partition.py rename to py/torch_tensorrt/dynamo/backend/lowering/_partition.py index 1dd38e0bd9..1885d18705 100644 --- a/py/torch_tensorrt/dynamo/torch_compile/lowering/_partition.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py @@ -2,7 +2,7 @@ import torch -from torch_tensorrt.dynamo.torch_compile._defaults import MAX_NUM_TRT_ENGINES +from torch_tensorrt.dynamo.backend._defaults import MAX_NUM_TRT_ENGINES from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupport diff --git a/py/torch_tensorrt/dynamo/torch_compile/test/test_compiler_utils.py b/py/torch_tensorrt/dynamo/backend/test/test_compiler_utils.py similarity index 95% rename from py/torch_tensorrt/dynamo/torch_compile/test/test_compiler_utils.py rename to py/torch_tensorrt/dynamo/backend/test/test_compiler_utils.py index da7157c3e5..947a277ddd 100644 --- a/py/torch_tensorrt/dynamo/torch_compile/test/test_compiler_utils.py +++ b/py/torch_tensorrt/dynamo/backend/test/test_compiler_utils.py @@ -1,4 +1,4 @@ -from torch_tensorrt.dynamo.torch_compile.utils import prepare_device, prepare_inputs +from torch_tensorrt.dynamo.backend.utils import prepare_device, prepare_inputs from utils import same_output_format import torch_tensorrt import unittest diff --git a/py/torch_tensorrt/dynamo/torch_compile/test/test_lowering.py b/py/torch_tensorrt/dynamo/backend/test/test_lowering.py similarity index 100% rename from py/torch_tensorrt/dynamo/torch_compile/test/test_lowering.py rename to py/torch_tensorrt/dynamo/backend/test/test_lowering.py diff --git a/py/torch_tensorrt/dynamo/torch_compile/test/test_partitioning.py b/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py similarity index 97% rename from py/torch_tensorrt/dynamo/torch_compile/test/test_partitioning.py rename to py/torch_tensorrt/dynamo/backend/test/test_partitioning.py index b068f9c413..fccdd3c32e 100644 --- a/py/torch_tensorrt/dynamo/torch_compile/test/test_partitioning.py +++ b/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py @@ -1,4 +1,4 @@ -from torch_tensorrt.dynamo.torch_compile.lowering import partition +from torch_tensorrt.dynamo.backend.lowering import partition from torch.testing._internal.common_utils import run_tests, TestCase import torch from copy import deepcopy diff --git a/py/torch_tensorrt/dynamo/torch_compile/test/utils.py b/py/torch_tensorrt/dynamo/backend/test/utils.py similarity index 95% rename from py/torch_tensorrt/dynamo/torch_compile/test/utils.py rename to py/torch_tensorrt/dynamo/backend/test/utils.py index bdcbbfcc4a..466a600db8 100644 --- a/py/torch_tensorrt/dynamo/torch_compile/test/utils.py +++ b/py/torch_tensorrt/dynamo/backend/test/utils.py @@ -2,10 +2,10 @@ from functools import partial from typing import List, Sequence import torch -from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import ( +from torch_tensorrt.dynamo.backend.lowering._decompositions import ( get_decompositions, ) -from torch_tensorrt.dynamo.torch_compile.lowering._partition import ( +from torch_tensorrt.dynamo.backend.lowering._partition import ( partition, ) diff --git a/py/torch_tensorrt/dynamo/torch_compile/utils.py b/py/torch_tensorrt/dynamo/backend/utils.py similarity index 98% rename from py/torch_tensorrt/dynamo/torch_compile/utils.py rename to py/torch_tensorrt/dynamo/backend/utils.py index ba76536338..e6e22d5f96 100644 --- a/py/torch_tensorrt/dynamo/torch_compile/utils.py +++ b/py/torch_tensorrt/dynamo/backend/utils.py @@ -45,7 +45,7 @@ def prepare_inputs( else: raise ValueError( - f"Invalid input type {type(inputs)} encountered in the torch_compile input parsing. " + f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. " + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}" ) diff --git a/py/torch_tensorrt/dynamo/test/conftest.py b/py/torch_tensorrt/dynamo/test/conftest.py index 98be643435..7218d5335b 100644 --- a/py/torch_tensorrt/dynamo/test/conftest.py +++ b/py/torch_tensorrt/dynamo/test/conftest.py @@ -9,7 +9,7 @@ def pytest_addoption(parser): type=str, required=True, help="IR to compile with", - choices=["torch_compile", "fx_ts_compat"], + choices=["dynamo_compile", "fx_ts_compat"], ) diff --git a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py index 4852f033bd..531d0cc317 100644 --- a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py +++ b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py @@ -24,6 +24,7 @@ def test_resnet18(ir): "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, "ir": ir, + "max_num_trt_engines": 200, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -33,6 +34,12 @@ def test_resnet18(ir): f"Resnet50 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + # Clean up model env + torch._dynamo.reset() + + with torch.no_grad(): + torch.cuda.empty_cache() + @pytest.mark.unit def test_mobilenet_v2(ir): @@ -48,6 +55,7 @@ def test_mobilenet_v2(ir): "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, "ir": ir, + "max_num_trt_engines": 200, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -57,6 +65,12 @@ def test_mobilenet_v2(ir): f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + # Clean up model env + torch._dynamo.reset() + + with torch.no_grad(): + torch.cuda.empty_cache() + @pytest.mark.unit def test_efficientnet_b0(ir): @@ -72,6 +86,7 @@ def test_efficientnet_b0(ir): "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, "ir": ir, + "max_num_trt_engines": 200, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -81,6 +96,12 @@ def test_efficientnet_b0(ir): f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + # Clean up model env + torch._dynamo.reset() + + with torch.no_grad(): + torch.cuda.empty_cache() + @pytest.mark.unit def test_bert_base_uncased(ir): @@ -104,8 +125,8 @@ def test_bert_base_uncased(ir): "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, "truncate_long_and_double": True, - "debug": True, "ir": ir, + "max_num_trt_engines": 200, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -119,6 +140,12 @@ def test_bert_base_uncased(ir): f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + # Clean up model env + torch._dynamo.reset() + + with torch.no_grad(): + torch.cuda.empty_cache() + @pytest.mark.unit def test_resnet18_half(ir): @@ -142,3 +169,9 @@ def test_resnet18_half(ir): cos_sim > COSINE_THRESHOLD, f"Resnet50 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + + # Clean up model env + torch._dynamo.reset() + + with torch.no_grad(): + torch.cuda.empty_cache() diff --git a/py/torch_tensorrt/dynamo/torch_compile/lowering/__init__.py b/py/torch_tensorrt/dynamo/torch_compile/lowering/__init__.py deleted file mode 100644 index e0a41df755..0000000000 --- a/py/torch_tensorrt/dynamo/torch_compile/lowering/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import ( - get_decompositions, -) -from torch_tensorrt.dynamo.torch_compile.lowering._partition import ( - partition, - get_submod_inputs, -)