@@ -7,7 +7,6 @@
 import torch
 import torch_tensorrt
 from torch.fx.passes.pass_manager import PassManager
-from torch.fx.passes.splitter_base import SplitResult
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import (  # TODO: Should probably be the TRT EngineCapability Enum
     EngineCapability,
@@ -21,18 +20,17 @@
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
     TRUNCATE_LONG_AND_DOUBLE,
+    USE_FAST_PARTITIONER,
     USE_PYTHON_RUNTIME,
     VERSION_COMPATIBLE,
     WORKSPACE_SIZE,
 )
 from torch_tensorrt.dynamo.backend.backends import _compile_module
-from torch_tensorrt.dynamo.conversion import convert_module
 from torch_tensorrt.dynamo.lowering._fusers import (
     fuse_permute_linear,
     fuse_permute_matmul,
 )
 from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
-from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting

 logger = logging.getLogger(__name__)

@@ -64,6 +62,7 @@ def compile(
     version_compatible: bool = VERSION_COMPATIBLE,
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
     use_python_runtime: bool = USE_PYTHON_RUNTIME,
+    use_fast_partitioner: bool = USE_FAST_PARTITIONER,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     if debug:
@@ -75,7 +74,7 @@ def compile(
7574 "The Dynamo backend is an experimental feature, for which only the "
7675 + "following arguments are supported: "
7776 + "{enabled_precisions, debug, workspace_size, min_block_size, "
78- + "torch_executed_ops, pass_through_build_failures}"
77+ + "torch_executed_ops, pass_through_build_failures, use_fast_partitioner }"
7978 )
8079
8180 if not isinstance (inputs , collections .abc .Sequence ):
@@ -115,55 +114,12 @@ def compile(
115114 "optimization_level" : optimization_level ,
116115 "use_python_runtime" : use_python_runtime ,
117116 "truncate_long_and_double" : truncate_long_and_double ,
117+ "use_fast_partitioner" : use_fast_partitioner ,
118118 }
119119
120120 settings = CompilationSettings (** compilation_options )
121- if kwargs .get ("use_capability_partitioner" , None ):
122- model = lower_model (gm , torch_inputs )
123- return _compile_module (model , torch_inputs , settings )
124- else :
125- split_result = lower_model_using_trt_splitter (gm , torch_inputs )
126- trt_module = _compile_graph (split_result , torch_inputs , settings )
127-
128- return trt_module
129121
130-
131- def _compile_graph (
132- split_result : SplitResult ,
133- inputs : Any ,
134- settings : CompilationSettings = CompilationSettings (),
135- ** kwargs : Any ,
136- ) -> torch .fx .GraphModule :
137- for submod_name , submod_inputs in split_result .submodule_inputs .items ():
138- submod = getattr (split_result .split_module , submod_name )
139- # Only acc submodules will be lowered.
140- if not submod_name .startswith (split_result .non_acc_submodule_prefix ):
141- # Create TRT Module from submodule
142- trt_mod = convert_module (
143- submod ,
144- submod_inputs ,
145- settings = settings ,
146- name = submod_name ,
147- )
148- setattr (split_result .split_module , submod_name , trt_mod )
149-
150- return split_result .split_module
151-
152-
153- def lower_model_using_trt_splitter (
154- model : torch .nn .Module , inputs : Any , ** kwargs : Any
155- ) -> SplitResult :
156- # Perform basic lowering
157- model = lower_model (model , inputs )
158- splitter_setting = TRTSplitterSetting ()
159- splitter_setting .use_implicit_batch_dim = False
160- splitter_setting .min_acc_module_size = 1
161- splitter_setting .use_experimental_rt = False
162- splitter = TRTSplitter (model , inputs , settings = splitter_setting )
163- splitter .node_support_preview ()
164- split_result = splitter .generate_split_results ()
165-
166- return split_result
122+ return _compile_module (gm , torch_inputs , settings )
167123
168124
169125def lower_model (
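
Taken together, the hunks above drop the TRTSplitter-based lowering path and route every compilation through _compile_module, with the partitioner choice exposed as the new use_fast_partitioner setting. A minimal call sketch follows, assuming this patched compile function is exposed as torch_tensorrt.dynamo.compile; the model, input shape, and graph-capture step are illustrative only, not part of this change.

import torch
import torch_tensorrt

# Illustrative model and example inputs (not from the diff).
model = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU()).cuda().eval()
example_inputs = [torch.randn(8, 64, device="cuda")]

# The patched compile() takes an FX GraphModule (`gm` in the diff); symbolic_trace
# is used here only as a stand-in for whatever graph capture precedes the call.
gm = torch.fx.symbolic_trace(model)

# Assumption: the function shown in the diff is reachable as torch_tensorrt.dynamo.compile.
trt_gm = torch_tensorrt.dynamo.compile(
    gm,
    inputs=example_inputs,
    enabled_precisions={torch.float},
    min_block_size=1,
    # New in this change: partitioning always happens inside _compile_module, and this
    # flag selects the fast partitioner over the global capability-based one.
    use_fast_partitioner=True,
)

The flag is simply threaded into CompilationSettings via compilation_options, so downstream partitioning code reads it from the settings object rather than from **kwargs.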