diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp index efb9fd6b6b..513124eb3d 100644 --- a/core/partitioning/shape_analysis.cpp +++ b/core/partitioning/shape_analysis.cpp @@ -64,7 +64,9 @@ void getSegmentsOutputByRunning( // set inputs ivalues, now supports Tensor/Int to pass argumentes between different segments for (auto& input : seg_block.raw_inputs()) { TRTORCH_CHECK(ivalues_maps.count(input), "Could not find mini graph input IValue " << input->debugName()); - if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) { + if (input->node()->kind() == torch::jit::prim::Param) { + jit_inputs_ivalues.push_back(ivalues_maps[input]); + } else if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) { jit_inputs_ivalues.push_back(ivalues_maps[input].toTensor()); } else if (input->type()->isSubtypeOf(torch::jit::IntType::get())) { jit_inputs_ivalues.push_back(ivalues_maps[input].toInt());