@@ -261,7 +261,7 @@ GraphAndMapping ConstructFallbackGraph(
261
261
if (seg_block.target () == partitioning::SegmentedBlock::kTensorRT ) {
262
262
std::vector<ir::Input> inputs;
263
263
for (auto & shape : seg_block.in_shape ()) {
264
- inputs.push_back (ir::InputRange (shape));
264
+ inputs.push_back (ir::Input (shape));
265
265
}
266
266
// update the input ranges for each segments
267
267
convert_cfg.inputs = inputs;
@@ -332,46 +332,6 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
332
332
return mod;
333
333
}
334
334
335
- // <<<<<<< HEAD
336
- // =======
337
- // std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
338
- // // add global graph's input to old_to_new_g mapping
339
- // for (auto input : g->inputs()) {
340
- // util::getOrAddInputForValue(input, new_g, old_to_new_g);
341
- // }
342
- // for (auto& seg_block : segmented_blocks) {
343
- // std::string cur_block_target =
344
- // seg_block.target() == partitioning::SegmentedBlock::kTensorRT ? "TensorRT" : "Torch";
345
- // LOG_INFO(*seg_block.g() << "(Sub Graph" << cur_block_target << "Block)\n");
346
- // std::ostringstream trt_engine_id;
347
- // trt_engine_id << reinterpret_cast<const int*>(&seg_block);
348
- // if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
349
- // std::vector<ir::Input> inputs;
350
- // for (auto& shape : seg_block.in_shape()) {
351
- // inputs.push_back(ir::Input(shape));
352
- // }
353
- // // update the input ranges for each segments
354
- // convert_cfg.inputs = inputs;
355
- // auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
356
- // auto temp_g = std::make_shared<torch::jit::Graph>();
357
- // auto device_spec = convert_cfg.engine_settings.device;
358
- // auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
359
- // AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
360
- //
361
- // seg_block.update_graph(temp_g);
362
- // AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
363
- // } else {
364
- // AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
365
- // }
366
- // }
367
- //
368
- // for (auto& output : g->outputs()) {
369
- // new_g->registerOutput(old_to_new_g[output]);
370
- // }
371
- //
372
- // LOG_INFO(*new_g << "(FallbackGraph)\n");
373
- //
374
- // >>>>>>> master
375
335
auto new_method = new_mod._ivalue ()->compilation_unit ()->create_function (method.name (), new_g);
376
336
auto schema = util::GenerateGraphSchema (new_method->name (), new_g);
377
337
new_mod.type ()->addMethod (new_method);
0 commit comments