diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index a9db08621a..df4d30c397 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -115,7 +115,7 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg pytorch_nodes.push_back(n); prev_non_tensor_outputs = containNonTensorOutputs(n); } else { - // If pytorch_nodes is not empty, the previous nodes were all tensorrt_nodes. Construct a + // If pytorch_nodes is not empty, the previous nodes were all pytorch_nodes. Construct a // Pytorch segmented_block and clear the pytorch_nodes list to be later used for new Pytorch segments. if (!pytorch_nodes.empty()) { new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes); @@ -132,6 +132,7 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes); } } + return std::move(new_seg_blocks); } @@ -159,6 +160,7 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shar } } + // For each non-tensor value in the usage_counts map, keep updating the produce_id to the earliest segmented block that has/produces it. for (auto& use : usage_counts) { // Set the produce_id to the segmented block index that contains/produces this non-tensor torch::jit::Value if (segmented_blocks[i].contain_raw_value(use.first)) { @@ -167,6 +169,7 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shar } } + std::unordered_set<int> updated_segments; for (auto& use : usage_counts) { auto use_info = use.second; @@ -178,9 +181,8 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shar // Segmented Blocks with non-tensor inputs will have to be re-segmented as // TRTorch doesn't support non-tensor inputs for a module. 
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]); - segmented_blocks.erase(segmented_blocks.begin() + first_torch_id); - segmented_blocks.insert( - segmented_blocks.begin() + first_torch_id, to_inject_blocks.begin(), to_inject_blocks.end()); + auto next_iter = segmented_blocks_list.erase(idx_to_iter[first_torch_id]); + segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end()); updated_segments.insert(first_torch_id); } } @@ -314,6 +316,7 @@ std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const Partit segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n}); continue; } else if (n->kind() == torch::jit::prim::Loop) { + if (!pytorch_nodes.empty()) { segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes); pytorch_nodes.clear(); @@ -347,19 +350,15 @@ std::vector<SegmentedBlock> Partition( const PartitionInfo& partition_info) { LOG_DEBUG(partition_info); // segment lowering global graph into blocks - LOG_DEBUG("Partitioning graph into PyTorch and TensorRT segmented blocks"); std::vector<SegmentedBlock> segmented_blocks = segment_graph(block, partition_info); // resolve nonTensor inputs/outputs - LOG_DEBUG("Resolving non-tensor type inputs/outputs (eg: int/float types)"); resolveNonTensorInputs(segmented_blocks); // register input/output torch::jit::Value for segmented graphs - LOG_DEBUG("Registering input/outputs for segmented blocks"); registerSegmentsOutputs(segmented_blocks, block); // run shape analysis on each segmented block - LOG_DEBUG("Running shape analysis for all the segmented blocks"); runShapeAnalysis(segmented_blocks, input_ivalues_map); return segmented_blocks; diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp index 9a2b3230dc..86d56bb5ca 100644 --- a/core/partitioning/shape_analysis.cpp +++ b/core/partitioning/shape_analysis.cpp @@ -56,7 +56,7 @@ void getSegmentsOutputByRunning( for (auto& input : seg_block.raw_inputs()) { TRTORCH_CHECK( 
ivalues_maps.count(input), - "Could not find torch::jit::Value* " << input->debugName() << " in lowering graph for mini graph input.\n"); + "Could not find torch::jit::Value* " << input->debugName() << " produced from " << util::node_info(input->node()) << " in lowering graph for mini graph input.\n"); if (input->node()->kind() == torch::jit::prim::Param) { jit_inputs_ivalues.push_back(ivalues_maps[input]); } else if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) { @@ -108,10 +108,8 @@ void runShapeAnalysis( std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) { // register every segment's input shape, and it's running output IValues for (auto& seg_block : segmented_blocks) { - LOG_DEBUG("Segmented graph: " << *seg_block.g()); torch::jit::ConstantPooling(seg_block.g()); getSegmentsOutputByRunning(seg_block, ivalues_maps); - LOG_DEBUG("================="); } return; } diff --git a/tests/core/partitioning/test_loop_fallback.cpp b/tests/core/partitioning/test_loop_fallback.cpp index 23b0f5d1e4..d80e0beb60 100644 --- a/tests/core/partitioning/test_loop_fallback.cpp +++ b/tests/core/partitioning/test_loop_fallback.cpp @@ -33,30 +33,30 @@ TEST(Partitioning, CheckLoopFallbackEvalCompilesCorrectly) { ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6)); } -// TEST(Partitioning, CheckLoopFallbackNoEvalCompilesCorrectly) { -// torch::jit::script::Module mod; -// try { -// mod = torch::jit::load("tests/modules/loop_fallback_no_eval_scripted.jit.pt"); -// } catch (const c10::Error& e) { -// std::cerr << "error loading the model\n"; -// return; -// } -// -// const std::vector<std::vector<int64_t>> input_shapes = {{1, 10}}; -// std::vector<torch::jit::IValue> jit_inputs_ivalues; -// std::vector<torch::jit::IValue> trt_inputs_ivalues; -// for (auto in_shape : input_shapes) { -// auto in = at::randint(5, in_shape, {at::kCUDA}); -// jit_inputs_ivalues.push_back(in.clone()); -// trt_inputs_ivalues.push_back(in.clone()); -// } -// -// std::vector<trtorch::core::ir::Input> input_ranges{trtorch::core::ir::Input({1, 10})}; -// 
trtorch::core::CompileSpec cfg(input_ranges); -// cfg.partition_info.enabled = true; -// -// auto jit_results = mod.forward(jit_inputs_ivalues).toTensor(); -// auto trt_mod = trtorch::core::CompileGraph(mod, cfg); -// auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor(); -// ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6)); -// } +TEST(Partitioning, CheckLoopFallbackNoEvalCompilesCorrectly) { + torch::jit::script::Module mod; + try { + mod = torch::jit::load("tests/modules/loop_fallback_no_eval_scripted.jit.pt"); + } catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + return; + } + + const std::vector<std::vector<int64_t>> input_shapes = {{1, 10}}; + std::vector<torch::jit::IValue> jit_inputs_ivalues; + std::vector<torch::jit::IValue> trt_inputs_ivalues; + for (auto in_shape : input_shapes) { + auto in = at::randint(5, in_shape, {at::kCUDA}); + jit_inputs_ivalues.push_back(in.clone()); + trt_inputs_ivalues.push_back(in.clone()); + } + + std::vector<trtorch::core::ir::Input> input_ranges{trtorch::core::ir::Input({1, 10})}; + trtorch::core::CompileSpec cfg(input_ranges); + cfg.partition_info.enabled = true; + + auto jit_results = mod.forward(jit_inputs_ivalues).toTensor(); + auto trt_mod = trtorch::core::CompileGraph(mod, cfg); + auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor(); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6)); +}