fix: Reorganize Dynamo directory + backends #1928

Merged · 2 commits · May 18, 2023
22 changes: 11 additions & 11 deletions .circleci/config.yml
@@ -763,33 +763,33 @@ commands:
- store_artifacts:
path: /tmp/testlogs

test-dynamo-torch_compile-core:
description: "Test the Dynamo torch_compile path"
test-dynamo-compile-core:
description: "Test the Dynamo compile path"
steps:
- run:
name: Run Dynamo torch_compile core tests
name: Run Dynamo compile core tests
command: |
cd py/torch_tensorrt/dynamo/torch_compile
cd py/torch_tensorrt/dynamo/backend
pushd test/
pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml
pytest --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml
popd

- store_test_results:
path: /tmp/artifacts
- store_artifacts:
path: /tmp/testlogs

test-dynamo-torch_compile:
description: "Test the Dynamo torch_compile path"
test-dynamo-compile:
description: "Test the Dynamo compile path"
steps:
- run:
name: Run Dynamo torch_compile E2E tests
name: Run Dynamo compile E2E tests
command: |
cd py/torch_tensorrt/dynamo/
pushd test/
pip3 install timm
pip3 install transformers
pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml --ir torch_compile
pytest --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir dynamo_compile
popd

- store_test_results:
@@ -1051,8 +1051,8 @@ jobs:
command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl
# We install torch after torch-trt because pip automatically enforces the version constraint otherwise
- dump-test-env
- test-dynamo-torch_compile
- test-dynamo-torch_compile-core
- test-dynamo-compile
- test-dynamo-compile-core
- test-dynamo-fx_ts

package-x86_64-linux:
2 changes: 1 addition & 1 deletion py/torch_tensorrt/__init__.py
@@ -97,7 +97,7 @@ def _find_lib(name, paths):

if version.parse(torch.__version__) >= version.parse("2.dev"):
from torch_tensorrt import dynamo
from torch_tensorrt.dynamo import torch_compile
from torch_tensorrt.dynamo import backend


def _register_with_torch():
12 changes: 6 additions & 6 deletions py/torch_tensorrt/_compile.py
@@ -16,7 +16,7 @@ class _IRType(Enum):
ts = 0
fx = 1
fx_ts_compat = 2
torch_compile = 3
dynamo_compile = 3


class _ModuleType(Enum):
@@ -47,7 +47,7 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:

ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]])
ir_targets_fx = ir == "fx"
ir_targets_torch_compile = ir == "torch_compile"
ir_targets_dynamo_compile = ir == "dynamo_compile"
ir_targets_fx_ts_compat = ir == "fx_ts_compat"

if module_is_tsable and ir_targets_torchscript:
@@ -56,8 +56,8 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
return _IRType.fx
elif module_is_fxable and ir_targets_fx_ts_compat:
return _IRType.fx_ts_compat
elif module_is_fxable and ir_targets_torch_compile:
return _IRType.torch_compile
elif module_is_fxable and ir_targets_dynamo_compile:
return _IRType.dynamo_compile
else:
if ir == "default":
# Options are listed in order of preference
@@ -156,8 +156,8 @@ def compile(
dynamic_batch=False,
**kwargs,
)
elif target_ir == _IRType.torch_compile:
return torch_tensorrt.dynamo.torch_compile(
elif target_ir == _IRType.dynamo_compile:
return torch_tensorrt.dynamo.compile(
module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
)
elif target_ir == _IRType.fx_ts_compat:
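To illustrate the renamed path end-to-end, a minimal sketch of selecting the new IR through the top-level API, assuming a CUDA-enabled build of this branch (the toy model, input shape, and precision set are illustrative, not taken from the PR):

```python
# Minimal sketch: ir="dynamo_compile" routes through _IRType.dynamo_compile to
# torch_tensorrt.dynamo.compile, per the dispatch above. Assumes a CUDA build
# of this branch; the toy model and input shape are illustrative only.
import torch
import torch_tensorrt

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
    torch.nn.ReLU(),
).eval().cuda()

inputs = [torch.randn(1, 3, 224, 224).cuda()]

trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo_compile",
    inputs=inputs,
    enabled_precisions={torch.float},
)

print(trt_model(*inputs).shape)
```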
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/__init__.py
@@ -1,2 +1,2 @@
from torch_tensorrt.dynamo import fx_ts_compat
from .torch_compile import compile as torch_compile
from .backend import compile
@@ -8,10 +8,10 @@
from torch_tensorrt import EngineCapability, Device
from torch_tensorrt.fx.utils import LowerPrecision

from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings
from torch_tensorrt.dynamo.torch_compile.utils import prepare_inputs, prepare_device
from torch_tensorrt.dynamo.torch_compile.backends import tensorrt_backend
from torch_tensorrt.dynamo.torch_compile._defaults import (
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
from torch_tensorrt.dynamo.backend.utils import prepare_inputs, prepare_device
from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend
from torch_tensorrt.dynamo.backend._defaults import (
PRECISION,
DEBUG,
MAX_WORKSPACE_SIZE,
@@ -121,6 +121,6 @@ def create_backend(
)

return partial(
tensorrt_backend,
torch_tensorrt_backend,
settings=settings,
)
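For reference, a hedged sketch of driving torch.compile with the partial returned by create_backend; the import location and the debug keyword are assumptions inferred from the imports and _defaults shown in this diff, not confirmed by it:

```python
# Hedged sketch: wiring create_backend's partial into torch.compile.
# Assumptions: create_backend is importable from torch_tensorrt.dynamo.backend
# and accepts a debug= keyword (inferred from the _defaults imports above).
import torch
from torch_tensorrt.dynamo.backend import create_backend

model = torch.nn.Linear(32, 32).eval().cuda()
x = torch.randn(8, 32).cuda()

backend = create_backend(debug=True)  # returns partial(torch_tensorrt_backend, settings=...)
compiled = torch.compile(model, backend=backend)

print(compiled(x).shape)  # TRT compilation is triggered on the first call
```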
@@ -4,4 +4,4 @@
PRECISION = LowerPrecision.FP32
DEBUG = False
MAX_WORKSPACE_SIZE = 20 << 30
MAX_NUM_TRT_ENGINES = 200
MAX_NUM_TRT_ENGINES = 10
@@ -1,7 +1,7 @@
from dataclasses import dataclass

from torch_tensorrt.fx.utils import LowerPrecision
from torch_tensorrt.dynamo.torch_compile._defaults import (
from torch_tensorrt.dynamo.backend._defaults import (
PRECISION,
DEBUG,
MAX_WORKSPACE_SIZE,
Expand Down
@@ -4,30 +4,42 @@
from functools import partial
import torch._dynamo as td

from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings
from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import (
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
from torch_tensorrt.dynamo.backend.lowering._decompositions import (
get_decompositions,
)
from torch_tensorrt.dynamo.torch_compile.lowering._partition import (
from torch_tensorrt.dynamo.backend.lowering._partition import (
partition,
get_submod_inputs,
)
from torch_tensorrt.dynamo.torch_compile.conversion import convert_module
from torch_tensorrt.dynamo.backend.conversion import convert_module

from torch._dynamo.backends.common import fake_tensor_unsupported

from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler


@td.register_backend(name="tensorrt")
@td.register_backend(name="torch_tensorrt")
@fake_tensor_unsupported
def tensorrt_backend(
gm: torch.nn.Module,
def torch_tensorrt_backend(
gm: torch.fx.GraphModule,
sample_inputs: Sequence[torch.Tensor],
settings: CompilationSettings = CompilationSettings(),
):
DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend

return DEFAULT_BACKEND(gm, sample_inputs, settings=settings)


@td.register_backend(name="aot_torch_tensorrt_aten")
@fake_tensor_unsupported
def aot_torch_tensorrt_aten_backend(
gm: torch.fx.GraphModule,
sample_inputs: Sequence[torch.Tensor],
settings: CompilationSettings = CompilationSettings(),
):
custom_backend = partial(
fx_dynamo_backend,
_pretraced_backend,
settings=settings,
)

@@ -40,14 +52,12 @@ def tensorrt_backend(
)


@td.register_backend(name="fx_tensorrt")
@fake_tensor_unsupported
def fx_dynamo_backend(
def _pretraced_backend(
gm: torch.fx.GraphModule,
example_inputs: Sequence[torch.Tensor],
sample_inputs: Sequence[torch.Tensor],
settings: CompilationSettings = CompilationSettings(),
):
"""Helper function to manage translation of FX module to TRT engines
"""Helper function to manage translation of traced FX module to TRT engines

Args:
module: FX GraphModule to convert
@@ -57,9 +67,9 @@ def fx_dynamo_backend(
Compiled FX GraphModule
"""
try:
trt_compiled = compile_module(
trt_compiled = _compile_module(
gm,
example_inputs,
sample_inputs,
settings=settings,
)
return trt_compiled
Expand All @@ -72,12 +82,12 @@ def fx_dynamo_backend(
return gm.forward


def compile_module(
def _compile_module(
Collaborator (author) comment:
Added "_" to name, as the function is intended only for internal/helper use (requires a pre-traced FX GraphModule)

gm: torch.fx.GraphModule,
example_inputs: Sequence[torch.Tensor],
sample_inputs: Sequence[torch.Tensor],
settings: CompilationSettings = CompilationSettings(),
) -> torch.fx.GraphModule:
"""Compile an FX module
"""Compile a traced FX module

Includes: Partitioning + Conversion Phases

@@ -100,7 +110,7 @@ def compile_module(

# Get submodule inputs
submodule_inputs = get_submod_inputs(
partitioned_module, submodule, example_inputs
partitioned_module, submodule, sample_inputs
)

# Create TRT Module from submodule
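Because the decorators above register the backends under the names "torch_tensorrt" and "aot_torch_tensorrt_aten", either string should be usable directly with torch.compile once the module is imported; a minimal sketch, assuming a CUDA build of this branch:

```python
# Minimal sketch: invoking the backend registered above by name. Importing the
# backends module runs the @td.register_backend decorators; a CUDA build of
# this branch is assumed.
import torch
import torch_tensorrt.dynamo.backend.backends  # noqa: F401

model = torch.nn.Linear(16, 16).eval().cuda()
x = torch.randn(4, 16).cuda()

compiled = torch.compile(model, backend="torch_tensorrt")
print(compiled(x).shape)
```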
7 changes: 7 additions & 0 deletions py/torch_tensorrt/dynamo/backend/lowering/__init__.py
@@ -0,0 +1,7 @@
from torch_tensorrt.dynamo.backend.lowering._decompositions import (
get_decompositions,
)
from torch_tensorrt.dynamo.backend.lowering._partition import (
partition,
get_submod_inputs,
)
@@ -2,7 +2,7 @@

import torch

from torch_tensorrt.dynamo.torch_compile._defaults import MAX_NUM_TRT_ENGINES
from torch_tensorrt.dynamo.backend._defaults import MAX_NUM_TRT_ENGINES
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport

@@ -1,4 +1,4 @@
from torch_tensorrt.dynamo.torch_compile.utils import prepare_device, prepare_inputs
from torch_tensorrt.dynamo.backend.utils import prepare_device, prepare_inputs
from utils import same_output_format
import torch_tensorrt
import unittest
@@ -1,4 +1,4 @@
from torch_tensorrt.dynamo.torch_compile.lowering import partition
from torch_tensorrt.dynamo.backend.lowering import partition
from torch.testing._internal.common_utils import run_tests, TestCase
import torch
from copy import deepcopy
@@ -2,10 +2,10 @@
from functools import partial
from typing import List, Sequence
import torch
from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import (
from torch_tensorrt.dynamo.backend.lowering._decompositions import (
get_decompositions,
)
from torch_tensorrt.dynamo.torch_compile.lowering._partition import (
from torch_tensorrt.dynamo.backend.lowering._partition import (
partition,
)

@@ -45,7 +45,7 @@ def prepare_inputs(

else:
raise ValueError(
f"Invalid input type {type(inputs)} encountered in the torch_compile input parsing. "
f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. "
+ "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}"
)

2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/test/conftest.py
@@ -9,7 +9,7 @@ def pytest_addoption(parser):
type=str,
required=True,
help="IR to compile with",
choices=["torch_compile", "fx_ts_compat"],
choices=["dynamo_compile", "fx_ts_compat"],
)

