
Commit c7af89a

refactor: Move runtime, unify compile scripts, remove input_tensor_spec

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>

1 parent: 1db0582

8 files changed: +340 -191 lines

py/torch_tensorrt/_compile.py (+6 -6)

@@ -154,13 +154,13 @@ def compile(
             dynamic_batch=False,
             **kwargs,
         )
-    elif target_ir == _IRType.dynamo:
+    elif target_ir == _IRType.dynamo or target_ir == _IRType.torch_compile:
         return torch_tensorrt.dynamo.compile(
-            module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
-        )
-    elif target_ir == _IRType.torch_compile:
-        return torch_tensorrt.dynamo.backend.compile(
-            module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
+            module,
+            inputs=inputs,
+            enabled_precisions=enabled_precisions,
+            ir=target_ir.name,
+            **kwargs,
         )
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")
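With this change, ir="dynamo" and ir="torch_compile" both land in torch_tensorrt.dynamo.compile, which receives the chosen IR via the forwarded ir kwarg. A minimal usage sketch of the unified front end (the model, shapes, and precision set below are illustrative placeholders assuming a CUDA-enabled torch_tensorrt build, not part of this commit):

import torch
import torch_tensorrt

# Placeholder model and inputs, for illustration only.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

# Both calls now route through torch_tensorrt.dynamo.compile; the ir
# string is forwarded so the dynamo front end can pick the right path.
trt_aot = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs,
                                 enabled_precisions={torch.float})
trt_jit = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs,
                                 enabled_precisions={torch.float})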
+2 -167

@@ -1,167 +1,2 @@
-import torch
-import logging
-import collections.abc
-import torch_tensorrt
-from functools import partial
-
-from typing import Any, Optional, Sequence
-from torch_tensorrt import EngineCapability, Device
-from torch_tensorrt.fx.utils import LowerPrecision
-
-from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device
-from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend
-from torch_tensorrt.dynamo._defaults import (
-    PRECISION,
-    DEBUG,
-    WORKSPACE_SIZE,
-    MIN_BLOCK_SIZE,
-    PASS_THROUGH_BUILD_FAILURES,
-    MAX_AUX_STREAMS,
-    VERSION_COMPATIBLE,
-    OPTIMIZATION_LEVEL,
-    USE_PYTHON_RUNTIME,
-)
-
-
-logger = logging.getLogger(__name__)
-
-
-def compile(
-    gm: torch.nn.Module,
-    inputs: Any,
-    *,
-    device=Device._current_device(),
-    disable_tf32=False,
-    sparse_weights=False,
-    enabled_precisions=set(),
-    refit=False,
-    debug=DEBUG,
-    capability=EngineCapability.default,
-    num_avg_timing_iters=1,
-    workspace_size=WORKSPACE_SIZE,
-    dla_sram_size=1048576,
-    dla_local_dram_size=1073741824,
-    dla_global_dram_size=536870912,
-    calibrator=None,
-    truncate_long_and_double=False,
-    require_full_compilation=False,
-    min_block_size=MIN_BLOCK_SIZE,
-    torch_executed_ops=[],
-    torch_executed_modules=[],
-    pass_through_build_failures=PASS_THROUGH_BUILD_FAILURES,
-    max_aux_streams=MAX_AUX_STREAMS,
-    version_compatible=VERSION_COMPATIBLE,
-    optimization_level=OPTIMIZATION_LEVEL,
-    use_python_runtime=USE_PYTHON_RUNTIME,
-    **kwargs,
-):
-    if debug:
-        logger.setLevel(logging.DEBUG)
-
-    logger.warn(
-        "The Dynamo backend is an experimental feature, for which only the "
-        + "following arguments are supported: "
-        + "{enabled_precisions, debug, workspace_size, min_block_size, "
-        + "torch_executed_ops, pass_through_build_failures}"
-    )
-
-    if not isinstance(inputs, collections.abc.Sequence):
-        inputs = [inputs]
-
-    inputs = prepare_inputs(inputs, prepare_device(device))
-
-    if not isinstance(enabled_precisions, collections.abc.Collection):
-        enabled_precisions = [enabled_precisions]
-
-    # Parse user-specified enabled precisions
-    if (
-        torch.float16 in enabled_precisions
-        or torch_tensorrt.dtype.half in enabled_precisions
-    ):
-        lower_precision = LowerPrecision.FP16
-    elif (
-        torch.float32 in enabled_precisions
-        or torch_tensorrt.dtype.float in enabled_precisions
-    ):
-        lower_precision = LowerPrecision.FP32
-    elif len(enabled_precisions) == 0:
-        logger.info(f"No precision specified, defaulting to {PRECISION}")
-        lower_precision = PRECISION
-    else:
-        raise ValueError(
-            f"Precision {enabled_precisions} not supported in the Dynamo Path"
-        )
-
-    custom_backend = create_backend(
-        precision=lower_precision,
-        debug=debug,
-        workspace_size=workspace_size,
-        min_block_size=min_block_size,
-        torch_executed_ops=torch_executed_ops,
-        pass_through_build_failures=pass_through_build_failures,
-        max_aux_streams=max_aux_streams,
-        version_compatible=version_compatible,
-        optimization_level=optimization_level,
-        use_python_runtime=use_python_runtime,
-        **kwargs,
-    )
-
-    model = torch.compile(gm, backend=custom_backend)
-
-    # Ensure compilation occurs by calling the function with provided inputs
-    model(*inputs)
-
-    return model
-
-
-from torch_tensorrt.fx.utils import LowerPrecision
-
-logger = logging.getLogger(__name__)
-
-
-def create_backend(
-    precision: LowerPrecision = PRECISION,
-    debug: bool = DEBUG,
-    workspace_size: int = WORKSPACE_SIZE,
-    min_block_size: int = MIN_BLOCK_SIZE,
-    torch_executed_ops: Sequence[str] = set(),
-    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
-    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
-    version_compatible: bool = VERSION_COMPATIBLE,
-    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
-    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
-    **kwargs,
-):
-    """Create torch.compile backend given specified arguments
-
-    Args:
-        precision: Model Layer precision
-        debug: Whether to print out verbose debugging information
-        workspace_size: Workspace TRT is allowed to use for the module (0 is default)
-        min_block_size: Minimum number of operators per TRT-Engine Block
-        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
-        pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
-        max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
-        version_compatible: Provide version forward-compatibility for engine plan files
-        optimization_level: Builder optimization 0-5, higher levels imply longer build time,
-            searching for more optimization options. TRT defaults to 3
-        use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
-            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
-            argument as None
-    Returns:
-        Backend for torch.compile
-    """
-    return partial(
-        torch_tensorrt_backend,
-        debug=debug,
-        precision=precision,
-        workspace_size=workspace_size,
-        min_block_size=min_block_size,
-        torch_executed_ops=torch_executed_ops,
-        pass_through_build_failures=pass_through_build_failures,
-        max_aux_streams=max_aux_streams,
-        version_compatible=version_compatible,
-        optimization_level=optimization_level,
-        use_python_runtime=use_python_runtime,
-        **kwargs,
-    )
+from .backends import torch_tensorrt_backend
+from .compile import compile
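The compile/create_backend pair deleted above is re-homed in py/torch_tensorrt/dynamo/compile.py (next diff); the surviving two lines simply re-export the package surface. The backend wiring itself rests on a general torch.compile contract: a backend is any callable taking (gm, example_inputs), so per-build settings can be pre-bound with functools.partial, exactly as create_backend does with torch_tensorrt_backend. A self-contained sketch of that pattern with a toy backend (settings_aware_backend is illustrative, not from this commit):

import torch
from functools import partial

def settings_aware_backend(gm: torch.fx.GraphModule, example_inputs, *, debug: bool = False):
    # A toy backend: optionally print the captured graph, then run it unmodified.
    if debug:
        gm.graph.print_tabular()
    return gm.forward

# Bind keyword settings ahead of time; dynamo later calls
# custom_backend(gm, example_inputs) with the settings already applied.
custom_backend = partial(settings_aware_backend, debug=True)

compiled = torch.compile(torch.sin, backend=custom_backend)
compiled(torch.randn(4))  # first call triggers graph capture and the backend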

py/torch_tensorrt/dynamo/compile.py (+84 -15)

@@ -17,7 +17,7 @@
 )
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device
-from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend
+from torch_tensorrt.dynamo.backend import torch_tensorrt_backend
 from torch_tensorrt.dynamo.conversion import convert_module

 from torch_tensorrt.dynamo._defaults import (
@@ -58,6 +58,7 @@ def compile(
     min_block_size=MIN_BLOCK_SIZE,
     torch_executed_ops=[],
     torch_executed_modules=[],
+    pass_through_build_failures=PASS_THROUGH_BUILD_FAILURES,
     max_aux_streams=MAX_AUX_STREAMS,
     version_compatible=VERSION_COMPATIBLE,
     optimization_level=OPTIMIZATION_LEVEL,
@@ -97,30 +98,98 @@ def compile(
             f"Precision {enabled_precisions} not supported in the Dynamo Path"
         )

-    settings = CompilationSettings(
+    if kwargs.get("ir", "dynamo") == "torch_compile":
+        custom_backend = create_backend(
+            precision=lower_precision,
+            debug=debug,
+            workspace_size=workspace_size,
+            min_block_size=min_block_size,
+            torch_executed_ops=torch_executed_ops,
+            pass_through_build_failures=pass_through_build_failures,
+            max_aux_streams=max_aux_streams,
+            version_compatible=version_compatible,
+            optimization_level=optimization_level,
+            use_python_runtime=use_python_runtime,
+            **kwargs,
+        )
+        model = torch.compile(gm, backend=custom_backend)
+        # Ensure compilation occurs by calling the function with provided inputs
+        model(*inputs)
+        return model
+
+    else:
+        settings = CompilationSettings(
+            debug=debug,
+            precision=lower_precision,
+            workspace_size=workspace_size,
+            min_block_size=min_block_size,
+            torch_executed_ops=torch_executed_ops,
+            pass_through_build_failures=pass_through_build_failures,
+            max_aux_streams=max_aux_streams,
+            version_compatible=version_compatible,
+            optimization_level=optimization_level,
+            use_python_runtime=use_python_runtime,
+        )
+
+        model = trace(gm, inputs, **kwargs)
+
+        if kwargs.get("use_capability_partitioner", None):
+            model = lower_model(model, inputs)
+            return _compile_module(model, inputs, settings)
+        else:
+            split_result = lower_model_using_trt_splitter(model, inputs)
+            trt_module = _compile_graph(split_result, inputs, settings)
+
+            return trt_module
+
+
+def create_backend(
+    precision: LowerPrecision = PRECISION,
+    debug: bool = DEBUG,
+    workspace_size: int = WORKSPACE_SIZE,
+    min_block_size: int = MIN_BLOCK_SIZE,
+    torch_executed_ops: Sequence[str] = set(),
+    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
+    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
+    version_compatible: bool = VERSION_COMPATIBLE,
+    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
+    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
+    **kwargs,
+):
+    """Create torch.compile backend given specified arguments
+
+    Args:
+        precision: Model Layer precision
+        debug: Whether to print out verbose debugging information
+        workspace_size: Workspace TRT is allowed to use for the module (0 is default)
+        min_block_size: Minimum number of operators per TRT-Engine Block
+        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
+        pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
+        max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
+        version_compatible: Provide version forward-compatibility for engine plan files
+        optimization_level: Builder optimization 0-5, higher levels imply longer build time,
+            searching for more optimization options. TRT defaults to 3
+        use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
+            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
+            argument as None
+    Returns:
+        Backend for torch.compile
+    """
+    return partial(
+        torch_tensorrt_backend,
         debug=debug,
-        precision=lower_precision,
+        precision=precision,
         workspace_size=workspace_size,
         min_block_size=min_block_size,
         torch_executed_ops=torch_executed_ops,
-        pass_through_build_failures=False,
+        pass_through_build_failures=pass_through_build_failures,
         max_aux_streams=max_aux_streams,
         version_compatible=version_compatible,
         optimization_level=optimization_level,
         use_python_runtime=use_python_runtime,
+        **kwargs,
     )

-    model = trace(gm, inputs, **kwargs)
-
-    if kwargs.get("use_capability_partitioner", None):
-        model = lower_model(model, inputs)
-        return _compile_module(model, inputs, settings)
-    else:
-        split_result = lower_model_using_trt_splitter(model, inputs)
-        trt_module = _compile_graph(split_result, inputs, settings)
-
-        return trt_module
-

 def _compile_graph(
     split_result: TRTSplitter,
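One detail worth noting in the torch_compile branch above: torch.compile is lazy, which is why compile invokes model(*inputs) once before returning, forcing graph capture (and hence engine building) up front rather than on the user's first inference call. A tiny sketch demonstrating the laziness with a toy backend (illustrative only, not from this commit):

import torch

compile_count = []

def counting_backend(gm, example_inputs):
    compile_count.append(1)  # record that the backend actually ran
    return gm.forward        # run the captured graph unmodified

model = torch.compile(torch.nn.Linear(4, 4), backend=counting_backend)
assert compile_count == []     # nothing built yet: torch.compile is lazy
model(torch.randn(2, 4))       # first call triggers capture + backend
assert len(compile_count) == 1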

py/torch_tensorrt/dynamo/conversion/conversion.py (+2 -2)

@@ -1,7 +1,7 @@
 from typing import Sequence, Union
 import torch
 import io
-from torch_tensorrt.fx.trt_module import TRTModule
+from torch_tensorrt.dynamo.runtime import TRTModule
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt import Input
 from torch_tensorrt.dynamo.conversion import TRTInterpreter
@@ -60,7 +60,7 @@ def convert_module(
         )

     else:
-        from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule
+        from torch_tensorrt.dynamo.runtime import TorchTensorRTModule

         with io.BytesIO() as engine_bytes:
             engine_bytes.write(interpreter_result.engine.serialize())
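With both runtime classes now imported from torch_tensorrt.dynamo.runtime, the surrounding logic is unchanged: the built engine is serialized to bytes via io.BytesIO before being wrapped in the runtime module. A stand-in sketch of that serialization idiom (FakeEngine is a placeholder for TensorRT's engine object, whose serialize() returns a plan buffer; not part of this commit):

import io

class FakeEngine:
    """Placeholder for a TRT engine; serialize() yields the plan bytes."""
    def serialize(self) -> bytes:
        return b"\x7fTRT-plan\x00"

engine = FakeEngine()
with io.BytesIO() as engine_bytes:
    engine_bytes.write(engine.serialize())
    serialized = engine_bytes.getvalue()  # raw bytes handed to the runtime module

assert isinstance(serialized, bytes)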

0 commit comments