fix: Add new TRT 8.6 features to Dynamo compile #1971

Closed

39 changes: 33 additions & 6 deletions py/torch_tensorrt/dynamo/backend/__init__.py
@@ -4,7 +4,7 @@
 import torch_tensorrt
 from functools import partial

-from typing import Any, Sequence
+from typing import Any, Optional, Sequence
 from torch_tensorrt import EngineCapability, Device
 from torch_tensorrt.fx.utils import LowerPrecision

@@ -14,9 +14,13 @@
 from torch_tensorrt.dynamo.backend._defaults import (
     PRECISION,
     DEBUG,
-    MAX_WORKSPACE_SIZE,
+    WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
     PASS_THROUGH_BUILD_FAILURES,
+    MAX_AUX_STREAMS,
+    VERSION_COMPATIBLE,
+    OPTIMIZATION_LEVEL,
+    USE_EXPERIMENTAL_RT,
 )

@@ -35,7 +39,7 @@ def compile(
     debug=DEBUG,
     capability=EngineCapability.default,
     num_avg_timing_iters=1,
-    workspace_size=MAX_WORKSPACE_SIZE,
+    workspace_size=WORKSPACE_SIZE,
     dla_sram_size=1048576,
     dla_local_dram_size=1073741824,
     dla_global_dram_size=536870912,
@@ -45,6 +49,10 @@
     min_block_size=MIN_BLOCK_SIZE,
     torch_executed_ops=[],
     torch_executed_modules=[],
+    max_aux_streams=MAX_AUX_STREAMS,
+    version_compatible=VERSION_COMPATIBLE,
+    optimization_level=OPTIMIZATION_LEVEL,
+    use_experimental_rt=USE_EXPERIMENTAL_RT,
     **kwargs,
 ):
     if debug:
@@ -86,6 +94,10 @@ def compile(
         workspace_size=workspace_size,
         min_block_size=min_block_size,
         torch_executed_ops=torch_executed_ops,
+        max_aux_streams=max_aux_streams,
+        version_compatible=version_compatible,
+        optimization_level=optimization_level,
+        use_experimental_rt=use_experimental_rt,
         **kwargs,
     )

@@ -105,19 +117,30 @@ def compile(
 def create_backend(
     precision: LowerPrecision = PRECISION,
     debug: bool = DEBUG,
-    workspace_size: int = MAX_WORKSPACE_SIZE,
+    workspace_size: int = WORKSPACE_SIZE,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Sequence[str] = set(),
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
+    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
+    version_compatible: bool = VERSION_COMPATIBLE,
+    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
+    use_experimental_rt: bool = USE_EXPERIMENTAL_RT,
     **kwargs,
 ):
     """Create torch.compile backend given specified arguments

     Args:
-        precision:
         debug: Whether to print out verbose debugging information
-        workspace_size: Maximum workspace TRT is allowed to use for the module
+        precision: Model Layer precision
+        workspace_size: Workspace TRT is allowed to use for the module (0 is default)
         min_block_size: Minimum number of operators per TRT-Engine Block
         torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
+        pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
+        max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
+        version_compatible: Provide version forward-compatibility for engine plan files
+        optimization_level: Builder optimization 0-5, higher levels imply longer build time,
+            searching for more optimization options. TRT defaults to 3
+        use_experimental_rt: Whether to use the new experimental TRTModuleNext for TRT engines
     Returns:
         Backend for torch.compile
     """
@@ -131,6 +154,10 @@ def create_backend(
         min_block_size=min_block_size,
         torch_executed_ops=torch_executed_ops,
         pass_through_build_failures=pass_through_build_failures,
+        max_aux_streams=max_aux_streams,
+        version_compatible=version_compatible,
+        optimization_level=optimization_level,
+        use_experimental_rt=use_experimental_rt,
     )

     return partial(
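
For reference, a minimal sketch of how the new options could be passed through this API. The positional (module, inputs) call shape is an assumption here, since compile()'s leading parameters are collapsed in the hunk above, and the toy model and option values are hypothetical placeholders:

import torch
import torch_tensorrt

# Hypothetical toy model and inputs, just to exercise the new keywords.
model = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(8, 64).cuda()]

trt_model = torch_tensorrt.dynamo.backend.compile(
    model,
    inputs,
    max_aux_streams=2,          # cap auxiliary CUDA streams per TRT engine
    version_compatible=True,    # build forward-compatible engine plans
    optimization_level=5,       # longest build, broadest tactic search (TRT default is 3)
    use_experimental_rt=False,  # keep the existing FX TRTModule runtime
)
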
6 changes: 5 additions & 1 deletion py/torch_tensorrt/dynamo/backend/_defaults.py
@@ -3,6 +3,10 @@

 PRECISION = LowerPrecision.FP32
 DEBUG = False
-MAX_WORKSPACE_SIZE = 20 << 30
+WORKSPACE_SIZE = 0
 MIN_BLOCK_SIZE = 5
 PASS_THROUGH_BUILD_FAILURES = False
+MAX_AUX_STREAMS = None
+VERSION_COMPATIBLE = False
+OPTIMIZATION_LEVEL = None
+USE_EXPERIMENTAL_RT = False
14 changes: 11 additions & 3 deletions py/torch_tensorrt/dynamo/backend/_settings.py
@@ -1,21 +1,29 @@
 from dataclasses import dataclass, field
-from typing import Sequence
+from typing import Optional, Sequence

 from torch_tensorrt.fx.utils import LowerPrecision
 from torch_tensorrt.dynamo.backend._defaults import (
     PRECISION,
     DEBUG,
-    MAX_WORKSPACE_SIZE,
+    WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
     PASS_THROUGH_BUILD_FAILURES,
+    MAX_AUX_STREAMS,
+    VERSION_COMPATIBLE,
+    OPTIMIZATION_LEVEL,
+    USE_EXPERIMENTAL_RT,
 )


 @dataclass(frozen=True)
 class CompilationSettings:
     precision: LowerPrecision = PRECISION
     debug: bool = DEBUG
-    workspace_size: int = MAX_WORKSPACE_SIZE
+    workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
     torch_executed_ops: Sequence[str] = field(default_factory=set)
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
+    max_aux_streams: Optional[int] = MAX_AUX_STREAMS
+    version_compatible: bool = VERSION_COMPATIBLE
+    optimization_level: Optional[int] = OPTIMIZATION_LEVEL
+    use_experimental_rt: bool = USE_EXPERIMENTAL_RT
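
A side note on the settings object: because the dataclass is declared frozen, a configuration is immutable once constructed. A small illustrative sketch (the option values are arbitrary):

from torch_tensorrt.dynamo.backend._settings import CompilationSettings

settings = CompilationSettings(
    workspace_size=0,        # 0 lets TRT size the workspace itself
    max_aux_streams=None,    # None defers to TRT's own heuristic
    version_compatible=True,
    optimization_level=3,
)
# settings.debug = True  # would raise dataclasses.FrozenInstanceError
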
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -135,6 +135,7 @@ def _compile_module(
             submodule,
             submodule_inputs,
             settings=settings,
+            name=name,
         )

         # Replace FX Module with TRT Module
31 changes: 26 additions & 5 deletions py/torch_tensorrt/dynamo/backend/conversion.py
@@ -1,9 +1,10 @@
 from typing import Sequence, Union
 import torch
+import io
 from torch_tensorrt.fx.trt_module import TRTModule
 from torch_tensorrt import TRTModuleNext
 from torch_tensorrt.dynamo.backend._settings import CompilationSettings
-from torch_tensorrt.fx.fx2trt import (
+from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import (
     InputTensorSpec,
     TRTInterpreter,
 )
@@ -15,30 +16,50 @@ def convert_module(
     module: torch.fx.GraphModule,
     inputs: Sequence[torch.Tensor],
     settings: CompilationSettings = CompilationSettings(),
+    name: str = "",
 ) -> Union[TRTModuleNext, TRTModule]:
     """Convert an FX module to a TRT module
     Args:
         module: FX GraphModule to convert
         inputs: Sequence of Tensors representing inputs to the module
         settings: Compilation settings
+        name: TRT engine name
     Returns:
         TRTModule or TRTModuleNext
     """
-    interp = TRTInterpreter(
+    interpreter = TRTInterpreter(
         module,
         InputTensorSpec.from_tensors(inputs),
         explicit_batch_dimension=True,
         logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
     )

-    r = interp.run(
-        max_workspace_size=settings.workspace_size,
+    interpreter_result = interpreter.run(
+        workspace_size=settings.workspace_size,
         lower_precision=settings.precision,
+        profiling_verbosity=(
+            trt.ProfilingVerbosity.VERBOSE
+            if settings.debug
+            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
+        ),
+        max_aux_streams=settings.max_aux_streams,
+        version_compatible=settings.version_compatible,
+        optimization_level=settings.optimization_level,
     )

-    return TRTModule(*r)
+    if settings.use_experimental_rt:
+        with io.BytesIO() as engine_bytes:
+            engine_bytes.write(interpreter_result.engine.serialize())
+            engine_str = engine_bytes.getvalue()
+        return TRTModuleNext(
+            serialized_engine=engine_str,
+            name=name,
+            input_binding_names=interpreter_result.input_names,
+            output_binding_names=interpreter_result.output_names,
+        )
+    else:
+        return TRTModule(
+            engine=interpreter_result.engine,
+            input_names=interpreter_result.input_names,
+            output_names=interpreter_result.output_names,
+        )
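
A note on the serialization round-trip above: in the TensorRT Python API, engine.serialize() returns an IHostMemory object that supports the buffer protocol, so the BytesIO write/getvalue pair should be equivalent to a direct bytes() conversion. A minimal sketch of the same idea, under that assumption:

# Equivalent sketch, assuming interpreter_result from the run() call above:
# IHostMemory exposes the buffer protocol, so bytes() copies the plan directly.
engine_str = bytes(interpreter_result.engine.serialize())
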
5 changes: 3 additions & 2 deletions py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
@@ -1,6 +1,7 @@
 import logging
 import warnings
 from datetime import datetime
+from packaging import version
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence

 import numpy
@@ -224,14 +225,14 @@ def run(
             cache = builder_config.create_timing_cache(b"")
             builder_config.set_timing_cache(cache, False)

-        if trt.__version__ >= "8.2":
+        if version.parse(trt.__version__) >= version.parse("8.2"):
             builder_config.profiling_verbosity = (
                 profiling_verbosity
                 if profiling_verbosity
                 else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
             )

-        if trt.__version__ >= "8.6":
+        if version.parse(trt.__version__) >= version.parse("8.6"):
             if max_aux_streams is not None:
                 _LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
                 builder_config.max_aux_streams = max_aux_streams
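
The move away from raw string comparison matters because Python compares strings lexicographically, which misorders multi-digit version components (e.g. a future TRT "8.10" or "10.0"). A quick sketch of the failure mode this change avoids:

from packaging import version

# Lexicographic string comparison misorders multi-digit components:
print("8.10" >= "8.2")    # False -- incorrectly treats 8.10 as older than 8.2
print("10.0" >= "8.6")    # False -- '1' sorts before '8'

# packaging.version compares numerically, component by component:
print(version.parse("8.10") >= version.parse("8.2"))   # True
print(version.parse("10.0") >= version.parse("8.6"))   # True
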