@@ -138,10 +138,16 @@ partitioning::GraphAndMapping BuildHybridGraph(
138138 torch::jit::Block* block,
139139 CompileSpec cfg,
140140 ir::StaticParams static_params,
141- ir::CollectionTypeMap first_use_types) {
141+ ir::CollectionTypeMap first_use_types,
142+ bool expect_full_compilation = false ) {
142143 auto convert_info = cfg.convert_info ;
143144 auto partitioning_info = cfg.partitioning_info ;
144145
146+ // Any nonzero block size is valid if full compilation to TRT is desired
147+ if (expect_full_compilation) {
148+ partitioning_info.min_block_size = 1 ;
149+ }
150+
145151 auto partitioning_ctx = partitioning::PartitioningCtx (block, partitioning_info);
146152 partitioning_ctx.input_types_map = first_use_types;
147153
@@ -153,13 +159,16 @@ partitioning::GraphAndMapping BuildHybridGraph(
153159
154160 for (auto & partitioned_block : partitioning_ctx.partitioned_blocks ) {
155161 partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second ;
162+ int num_torch_segments = 0 ;
163+ int num_trt_segments = 0 ;
156164
157165 for (auto & seg_block : segmented_blocks) {
158166 LOG_INFO (" Block segment:" << seg_block);
159167 std::ostringstream trt_engine_id;
160168 trt_engine_id << reinterpret_cast <const int *>(&seg_block);
161169
162170 if (seg_block.target () == partitioning::SegmentedBlock::kTensorRT ) {
171+ num_trt_segments++;
163172 auto inputs = seg_block.construct_inputs_spec ();
164173 // update the input ranges for each segments
165174 convert_info.inputs = ir::associate_specs_with_inputs (seg_block.g (), inputs, static_params);
@@ -180,8 +189,32 @@ partitioning::GraphAndMapping BuildHybridGraph(
180189 true );
181190
182191 seg_block.update_graph (temp_g);
192+ } else {
193+ num_torch_segments++;
194+
195+ // If full compilation is expected, ensure that all operators in Torch blocks are
196+ // for collections processing
197+ if (expect_full_compilation) {
198+ for (auto torch_node : seg_block.block ()->nodes ()) {
199+ if (partitioning::CollectionSchemas.find (torch_node->kind ().toQualString ()) ==
200+ partitioning::CollectionSchemas.end ()) {
201+ LOG_WARNING (
202+ " Full compilation specified but node " << torch_node->kind ().toQualString ()
203+ << " was executed in Torch." );
204+ }
205+ }
206+ }
183207 }
184208 }
209+
210+ // If full compilation is expected, cannot have more than 2 Torch segments
211+ // (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
212+ if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1 )) {
213+ LOG_WARNING (
214+ " Full compilation specified but number of torch segments was "
215+ << num_torch_segments << " and number of trt segments was " << num_trt_segments
216+ << " . Was expecting at most 2 Torch segments and 1 TRT segment." );
217+ }
185218 }
186219
187220 return partitioning::stitch (&partitioning_ctx, block);
@@ -191,7 +224,8 @@ ir::TypeMap MapInputsAndDetermineDTypes(
191224 CompileSpec& cfg,
192225 std::shared_ptr<torch::jit::Graph>& g,
193226 ir::StaticParams& static_params,
194- ir::CollectionTypeMap& first_use_type_map) {
227+ ir::CollectionTypeMap& first_use_type_map,
228+ bool expect_full_compilation = false ) {
195229 cfg.convert_info .collection_input_spec_map =
196230 std::move (ir::associate_specs_with_collection_inputs (g, cfg.graph_inputs , static_params));
197231 cfg.partitioning_info .collection_input_spec_map =
@@ -226,7 +260,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
226260 " Cannot infer input type from calcuations in graph for input "
227261 << in->debugName () << " . Assuming it is Float32. If not, specify input type explicity" );
228262 spec[i].dtype = at::kFloat ;
229- } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info .enabled ) {
263+ } else if (spec[i].dtype_is_user_defined && ( cfg.partitioning_info .enabled || expect_full_compilation) ) {
230264 if (!est_type_opt[i]) {
231265 LOG_INFO (" Cannot infer input tensor dtype in graph, compiler is going to use the user setting" );
232266 std::stringstream ss;
@@ -315,8 +349,14 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
315349 // Infer the type of an input from the weights of the calculation
316350 auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection (g->block ());
317351
352+ // Determine if the block is convertible/has collection output, and based on the result,
353+ // whether full compilation can be expected
354+ auto isBlockConvertible = conversion::VerifyConverterSupportForBlock (g->block (), true );
355+ auto outputIsCollection = conversion::OutputIsCollection (g->block ());
356+ auto nearly_full_compilation = (isBlockConvertible && outputIsCollection);
357+
318358 // Extract map of IValue to DType
319- auto type_map = MapInputsAndDetermineDTypes (cfg, g, static_params, first_use_types);
359+ auto type_map = MapInputsAndDetermineDTypes (cfg, g, static_params, first_use_types, nearly_full_compilation );
320360
321361 // Check whether any of the input types are Long
322362 bool user_requested_long = false ;
@@ -330,20 +370,23 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
330370 user_requested_long &= (casts_inserted > 0 );
331371 }
332372
333- auto isBlockConvertible = conversion::VerifyConverterSupportForBlock (g->block (), true );
334- auto outputIsCollection = conversion::OutputIsCollection (g->block ());
335373 if (cfg.partitioning_info .enabled && !user_requested_long &&
336374 (cfg.lower_info .forced_fallback_modules .size () == 0 &&
337375 cfg.partitioning_info .forced_fallback_operators .size () == 0 && isBlockConvertible) &&
338376 !outputIsCollection) {
339377 LOG_INFO (" Skipping partitioning since model is fully supported" );
340378 }
341379
342- if (cfg.partitioning_info .enabled &&
343- (!(cfg.lower_info .forced_fallback_modules .size () == 0 &&
344- cfg.partitioning_info .forced_fallback_operators .size () == 0 && isBlockConvertible) ||
345- outputIsCollection || user_requested_long)) {
346- auto graph_and_mapping = BuildHybridGraph (new_mod, g->block (), cfg, static_params, first_use_types);
380+ if ((cfg.partitioning_info .enabled &&
381+ (!(cfg.lower_info .forced_fallback_modules .size () == 0 &&
382+ cfg.partitioning_info .forced_fallback_operators .size () == 0 && isBlockConvertible) ||
383+ outputIsCollection || user_requested_long)) ||
384+ nearly_full_compilation) {
385+ // If the model is fully-compilable and the user has specified full compilation, run partitioning
386+ // to generate collection-processing code in Torch
387+ auto expect_full_compilation = (nearly_full_compilation && !cfg.partitioning_info .enabled );
388+ auto graph_and_mapping =
389+ BuildHybridGraph (new_mod, g->block (), cfg, static_params, first_use_types, expect_full_compilation);
347390 new_g = graph_and_mapping.first ;
348391 // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
349392 for (size_t i = 0 ; i < new_g->inputs ().size (); ++i) {
0 commit comments