44import torch_tensorrt
55from functools import partial
66
7- from typing import Any , Sequence
7+ from typing import Any , Optional , Sequence
88from torch_tensorrt import EngineCapability , Device
99from torch_tensorrt .fx .utils import LowerPrecision
1010
1414from torch_tensorrt .dynamo .backend ._defaults import (
1515 PRECISION ,
1616 DEBUG ,
17- MAX_WORKSPACE_SIZE ,
17+ WORKSPACE_SIZE ,
1818 MIN_BLOCK_SIZE ,
1919 PASS_THROUGH_BUILD_FAILURES ,
20+ MAX_AUX_STREAMS ,
21+ VERSION_COMPATIBLE ,
22+ OPTIMIZATION_LEVEL ,
23+ USE_EXPERIMENTAL_RT ,
2024)
2125
2226
@@ -35,7 +39,7 @@ def compile(
3539 debug = DEBUG ,
3640 capability = EngineCapability .default ,
3741 num_avg_timing_iters = 1 ,
38- workspace_size = MAX_WORKSPACE_SIZE ,
42+ workspace_size = WORKSPACE_SIZE ,
3943 dla_sram_size = 1048576 ,
4044 dla_local_dram_size = 1073741824 ,
4145 dla_global_dram_size = 536870912 ,
@@ -45,6 +49,10 @@ def compile(
4549 min_block_size = MIN_BLOCK_SIZE ,
4650 torch_executed_ops = [],
4751 torch_executed_modules = [],
52+ max_aux_streams = MAX_AUX_STREAMS ,
53+ version_compatible = VERSION_COMPATIBLE ,
54+ optimization_level = OPTIMIZATION_LEVEL ,
55+ use_experimental_rt = USE_EXPERIMENTAL_RT ,
4856 ** kwargs ,
4957):
5058 if debug :
@@ -86,6 +94,10 @@ def compile(
8694 workspace_size = workspace_size ,
8795 min_block_size = min_block_size ,
8896 torch_executed_ops = torch_executed_ops ,
97+ max_aux_streams = max_aux_streams ,
98+ version_compatible = version_compatible ,
99+ optimization_level = optimization_level ,
100+ use_experimental_rt = use_experimental_rt ,
89101 ** kwargs ,
90102 )
91103
@@ -105,19 +117,30 @@ def compile(
105117def create_backend (
106118 precision : LowerPrecision = PRECISION ,
107119 debug : bool = DEBUG ,
108- workspace_size : int = MAX_WORKSPACE_SIZE ,
120+ workspace_size : int = WORKSPACE_SIZE ,
109121 min_block_size : int = MIN_BLOCK_SIZE ,
110122 torch_executed_ops : Sequence [str ] = set (),
111123 pass_through_build_failures : bool = PASS_THROUGH_BUILD_FAILURES ,
124+ max_aux_streams : Optional [int ] = MAX_AUX_STREAMS ,
125+ version_compatible : bool = VERSION_COMPATIBLE ,
126+ optimization_level : Optional [int ] = OPTIMIZATION_LEVEL ,
127+ use_experimental_rt : bool = USE_EXPERIMENTAL_RT ,
112128 ** kwargs ,
113129):
114130 """Create torch.compile backend given specified arguments
115131
116132 Args:
117133 precision:
118134 debug: Whether to print out verbose debugging information
119- workspace_size: Maximum workspace TRT is allowed to use for the module
120- precision: Model Layer precision
135+ workspace_size: Workspace TRT is allowed to use for the module (0 is default)
136+ min_block_size: Minimum number of operators per TRT-Engine Block
137+ torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
138+ pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
139+ max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
140+ version_compatible: Provide version forward-compatibility for engine plan files
141+ optimization_level: Builder optimization 0-5, higher levels imply longer build time,
142+ searching for more optimization options. TRT defaults to 3
143+ use_experimental_rt: Whether to use the new experimental TRTModuleNext for TRT engines
121144 Returns:
122145 Backend for torch.compile
123146 """
@@ -131,6 +154,10 @@ def create_backend(
131154 min_block_size = min_block_size ,
132155 torch_executed_ops = torch_executed_ops ,
133156 pass_through_build_failures = pass_through_build_failures ,
157+ max_aux_streams = max_aux_streams ,
158+ version_compatible = version_compatible ,
159+ optimization_level = optimization_level ,
160+ use_experimental_rt = use_experimental_rt ,
134161 )
135162
136163 return partial (
0 commit comments