Skip to content

Commit

Permalink
feat: handle scalar type of size [] in shape_analysis
Browse files Browse the repository at this point in the history
Signed-off-by: inocsin <vcheungyi@163.com>
  • Loading branch information
inocsin committed Oct 20, 2021
1 parent 606d4de commit fca53ce
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion core/partitioning/shape_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,12 @@ void getSegmentsOutputByRunning(
if (dtype == c10::nullopt) {
TRTORCH_THROW_ERROR("Unsupported input data type " << ivalues_maps[i].toTensor().dtype());
}
input_shapes.push_back(util::toVec(util::toDims(ivalues_maps[i].toTensor().sizes())));
if (ivalues_maps[i].toTensor().sizes().size() == 0) {
// handle Scalar types, which has sizes of []
input_shapes.push_back(util::toVec(util::toDims(c10::List<long int>({1}))));
} else {
input_shapes.push_back(util::toVec(util::toDims(ivalues_maps[i].toTensor().sizes())));
}
input_types.push_back(ivalues_maps[i].toTensor().scalar_type());
}
}
Expand Down

0 comments on commit fca53ce

Please sign in to comment.