Skip to content

chore/fix: Update TRTInterpreter impl in Dynamo compile [1 / x] #2002

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions py/torch_tensorrt/dynamo/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch_tensorrt.dynamo.backend._defaults import (
PRECISION,
DEBUG,
MAX_WORKSPACE_SIZE,
WORKSPACE_SIZE,
MIN_BLOCK_SIZE,
PASS_THROUGH_BUILD_FAILURES,
)
Expand All @@ -35,7 +35,7 @@ def compile(
debug=DEBUG,
capability=EngineCapability.default,
num_avg_timing_iters=1,
workspace_size=MAX_WORKSPACE_SIZE,
workspace_size=WORKSPACE_SIZE,
dla_sram_size=1048576,
dla_local_dram_size=1073741824,
dla_global_dram_size=536870912,
Expand Down Expand Up @@ -105,7 +105,7 @@ def compile(
def create_backend(
precision: LowerPrecision = PRECISION,
debug: bool = DEBUG,
workspace_size: int = MAX_WORKSPACE_SIZE,
workspace_size: int = WORKSPACE_SIZE,
min_block_size: int = MIN_BLOCK_SIZE,
torch_executed_ops: Sequence[str] = set(),
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
Expand All @@ -114,10 +114,12 @@ def create_backend(
"""Create torch.compile backend given specified arguments

Args:
precision:
debug: Whether to print out verbose debugging information
workspace_size: Maximum workspace TRT is allowed to use for the module
precision: Model Layer precision
debug: Whether to print out verbose debugging information
workspace_size: Workspace TRT is allowed to use for the module (0 is default)
min_block_size: Minimum number of operators per TRT-Engine Block
torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
Returns:
Backend for torch.compile
"""
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/backend/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@

PRECISION = LowerPrecision.FP32
DEBUG = False
MAX_WORKSPACE_SIZE = 20 << 30
WORKSPACE_SIZE = 0
MIN_BLOCK_SIZE = 5
PASS_THROUGH_BUILD_FAILURES = False
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/backend/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch_tensorrt.dynamo.backend._defaults import (
PRECISION,
DEBUG,
MAX_WORKSPACE_SIZE,
WORKSPACE_SIZE,
MIN_BLOCK_SIZE,
PASS_THROUGH_BUILD_FAILURES,
)
Expand All @@ -15,7 +15,7 @@
class CompilationSettings:
precision: LowerPrecision = PRECISION
debug: bool = DEBUG
workspace_size: int = MAX_WORKSPACE_SIZE
workspace_size: int = WORKSPACE_SIZE
min_block_size: int = MIN_BLOCK_SIZE
torch_executed_ops: Sequence[str] = field(default_factory=set)
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
14 changes: 9 additions & 5 deletions py/torch_tensorrt/dynamo/backend/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch_tensorrt.fx.trt_module import TRTModule
from torch_tensorrt import TRTModuleNext
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
from torch_tensorrt.fx.fx2trt import (
from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import (
InputTensorSpec,
TRTInterpreter,
)
Expand All @@ -24,15 +24,15 @@ def convert_module(
Returns:
TRTModule or TRTModuleNext
"""
interp = TRTInterpreter(
interpreter = TRTInterpreter(
module,
InputTensorSpec.from_tensors(inputs),
explicit_batch_dimension=True,
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
)

r = interp.run(
max_workspace_size=settings.workspace_size,
interpreter_result = interpreter.run(
workspace_size=settings.workspace_size,
lower_precision=settings.precision,
profiling_verbosity=(
trt.ProfilingVerbosity.VERBOSE
Expand All @@ -41,4 +41,8 @@ def convert_module(
),
)

return TRTModule(*r)
return TRTModule(
engine=interpreter_result.engine,
input_names=interpreter_result.input_names,
output_names=interpreter_result.output_names,
)