@@ -143,19 +143,14 @@ partitioning::GraphAndMapping BuildHybridGraph(
143143 auto convert_info = cfg.convert_info ;
144144 auto partitioning_info = cfg.partitioning_info ;
145145
146- // Any nonzero block size is valid if full compilation to TRT is desired
147- if (expect_full_compilation) {
148- partitioning_info.min_block_size = 1 ;
149- }
150-
151146 auto partitioning_ctx = partitioning::PartitioningCtx (block, partitioning_info);
152147 partitioning_ctx.input_types_map = first_use_types;
153148
154149 // Generate a dictionary of input torch::jit::Value's to their min, opt, max tensors and store in ctx
155150 // TODO: Combine this within partition call
156151 partitioning::populateInputIValues (&partitioning_ctx);
157152
158- partitioning::partition (&partitioning_ctx);
153+ partitioning::partition (&partitioning_ctx, expect_full_compilation );
159154
160155 for (auto & partitioned_block : partitioning_ctx.partitioned_blocks ) {
161156 partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second ;
@@ -197,9 +192,11 @@ partitioning::GraphAndMapping BuildHybridGraph(
197192 if (expect_full_compilation) {
198193 for (auto torch_node : seg_block.block ()->nodes ()) {
199194 if (partitioning::CollectionNodeKinds.find (torch_node->kind ()) == partitioning::CollectionNodeKinds.end ()) {
200- LOG_ERROR (
201- " Full compilation specified but node " << torch_node->kind ().toQualString ()
202- << " was executed in Torch." );
195+ TORCHTRT_THROW_ERROR (
196+ " Full compilation specified but node "
197+ << *torch_node
198+ << " is set to run in PyTorch due to either lack of support in TensorRT or graph partitioning rules."
199+ << " Try recompiling with require_full_compilation=False." );
203200 }
204201 }
205202 }
@@ -209,10 +206,9 @@ partitioning::GraphAndMapping BuildHybridGraph(
209206 // If full compilation is expected, cannot have more than 2 Torch segments
210207 // (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
211208 if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1 )) {
212- LOG_ERROR (
213- " Full compilation specified but number of torch segments was "
214- << num_torch_segments << " and number of trt segments was " << num_trt_segments
215- << " . Was expecting at most 2 Torch segments and 1 TRT segment." );
209+ TORCHTRT_THROW_ERROR (
210+ " Full compilation was requested but unable to convert all operations to TensorRT."
211+ << " Try recompiling with require_full_compilation=False." );
216212 }
217213 }
218214
@@ -224,7 +220,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
224220 std::shared_ptr<torch::jit::Graph>& g,
225221 ir::StaticParams& static_params,
226222 ir::CollectionTypeMap& first_use_type_map,
227- bool expect_full_compilation = false ) {
223+ bool requires_collection_handling = false ) {
228224 cfg.convert_info .collection_input_spec_map =
229225 std::move (ir::associate_specs_with_collection_inputs (g, cfg.graph_inputs , static_params));
230226 cfg.partitioning_info .collection_input_spec_map =
@@ -259,7 +255,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
259255 " Cannot infer input type from calcuations in graph for input "
260256 << in->debugName () << " . Assuming it is Float32. If not, specify input type explicity" );
261257 spec[i].dtype = at::kFloat ;
262- } else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info .enabled || expect_full_compilation )) {
258+ } else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info .enabled || requires_collection_handling )) {
263259 if (!est_type_opt[i]) {
264260 LOG_INFO (" Cannot infer input tensor dtype in graph, compiler is going to use the user setting" );
265261 std::stringstream ss;
@@ -330,6 +326,11 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
330326 return engine;
331327}
332328
329+ bool userRequestedFallback (CompileSpec& cfg) {
330+ return cfg.lower_info .forced_fallback_modules .size () != 0 ||
331+ cfg.partitioning_info .forced_fallback_operators .size () != 0 ;
332+ }
333+
333334torch::jit::Module CompileGraph (const torch::jit::Module& mod, CompileSpec cfg) {
334335 torch::jit::Module new_mod (mod._ivalue ()->name () + " _trt" );
335336
@@ -352,10 +353,13 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
352353 // whether full compilation can be expected
353354 auto isBlockConvertible = conversion::VerifyConverterSupportForBlock (g->block (), true );
354355 auto outputIsCollection = conversion::OutputIsCollection (g->block ());
355- auto nearly_full_compilation = (isBlockConvertible && outputIsCollection);
356+ auto requires_collection_handling = (isBlockConvertible && outputIsCollection);
357+
358+ // Determine whether user specifications necessitate partitioning
359+ auto isFallbackRequested = userRequestedFallback (cfg);
356360
357361 // Extract map of IValue to DType
358- auto type_map = MapInputsAndDetermineDTypes (cfg, g, static_params, first_use_types, nearly_full_compilation );
362+ auto type_map = MapInputsAndDetermineDTypes (cfg, g, static_params, first_use_types, requires_collection_handling );
359363
360364 // Check whether any of the input types are Long
361365 bool user_requested_long = false ;
@@ -369,21 +373,26 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
369373 user_requested_long &= (casts_inserted > 0 );
370374 }
371375
372- if (cfg.partitioning_info .enabled && !user_requested_long &&
373- (cfg.lower_info .forced_fallback_modules .size () == 0 &&
374- cfg.partitioning_info .forced_fallback_operators .size () == 0 && isBlockConvertible) &&
375- !outputIsCollection) {
376+ // Partitioning is required if:
377+ // 1. User requested some modules/operators fallback
378+ // 2. The block (graph) cannot be converted due to operator coverage
379+ // 3. The output of the graph is a collection
380+ // 4. The user requested a non-TRT data type input
381+ auto isPartitioningRequired =
382+ (isFallbackRequested || !isBlockConvertible || outputIsCollection || user_requested_long);
383+
384+ // The user did not require full compilation, but the model can be fully compiled
385+ if (cfg.partitioning_info .enabled && !isPartitioningRequired) {
376386 LOG_INFO (" Skipping partitioning since model is fully supported" );
377387 }
378388
379- if ((cfg.partitioning_info .enabled &&
380- (!(cfg.lower_info .forced_fallback_modules .size () == 0 &&
381- cfg.partitioning_info .forced_fallback_operators .size () == 0 && isBlockConvertible) ||
382- outputIsCollection || user_requested_long)) ||
383- nearly_full_compilation) {
 389+ // The user did not require full compilation, but the model cannot be fully compiled
 390+ // or, the user required full compilation but the I/O of the graph uses collections
391+ if ((cfg.partitioning_info .enabled && isPartitioningRequired) || requires_collection_handling) {
384392 // If the model is fully-compilable and the user has specified full compilation, run partitioning
385393 // to generate collection-processing code in Torch
386- auto expect_full_compilation = (nearly_full_compilation && !cfg.partitioning_info .enabled );
394+ auto expect_full_compilation = (requires_collection_handling && !cfg.partitioning_info .enabled );
395+
387396 auto graph_and_mapping =
388397 BuildHybridGraph (new_mod, g->block (), cfg, static_params, first_use_types, expect_full_compilation);
389398 new_g = graph_and_mapping.first ;
0 commit comments