Skip to content

Commit

Permalink
Merge branch 'master' into pd_enable_sin_and_cos
Browse files Browse the repository at this point in the history
  • Loading branch information
xczhai authored Aug 28, 2023
2 parents 80cbcc5 + 94c21b5 commit ba6314d
Show file tree
Hide file tree
Showing 6 changed files with 1,526 additions and 1,298 deletions.
63 changes: 24 additions & 39 deletions src/core/shape_inference/include/nms_shape_inference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ std::vector<TRShape> shape_infer(const Node* op,
const auto& boxes_shape = input_shapes[0];
const auto& scores_shape = input_shapes[1];

auto output_shapes = std::vector<TRShape>{TRShape{TDim(-1), 3}};
auto output_shapes = std::vector<TRShape>{TRShape{TDim(dim::inf_bound), 3}};
if (boxes_shape.rank().is_static()) {
const auto max_out_boxes_per_class = get_input_const_data_as<TRShape, int64_t>(op, 2, ta);
auto max_out_class_boxes =
Expand Down Expand Up @@ -147,6 +147,7 @@ std::vector<TRShape> shape_infer(const Node* op,
NODE_VALIDATION_CHECK(op, cmp::Between<size_t>(1, 7)(inputs_size));
using TDim = typename TRShape::value_type;
using V = typename TDim::value_type;
using namespace ov::util;

nms::validate::boxes_shape(op, input_shapes);
nms::validate::scores_shape(op, input_shapes);
Expand Down Expand Up @@ -180,44 +181,31 @@ std::vector<TRShape> shape_infer(const Node* op,
}

const auto& boxes_shape = input_shapes[0];
const auto& scores_shape = input_shapes[1];
const auto boxes_rank = boxes_shape.rank();
const auto scores_rank = scores_shape.rank();

auto out_shape = TRShape{TDim(-1), 3};
if (boxes_rank.is_static()) {
int64_t max_out_boxes_per_class_val;
if (const auto max_out_boxes_per_class = get_input_const_data_as<TRShape, int64_t>(op, 2, ta)) {
max_out_boxes_per_class_val = max_out_boxes_per_class->front();
} else {
max_out_boxes_per_class_val = -1;
}
const auto& num_boxes = boxes_shape[1];
auto& selected_boxes = out_shape[0];
if (num_boxes.is_static()) {
const auto min_selected_boxes =
std::min(num_boxes.get_length(), static_cast<V>(max_out_boxes_per_class_val));
selected_boxes = static_output ? TDim{min_selected_boxes} : TDim{0, min_selected_boxes};
} else if (scores_rank.is_static() && num_boxes.get_max_length() != -1 &&
scores_shape[0].get_max_length() != -1 && scores_shape[1].get_max_length() != -1) {
const auto min_selected_boxes =
std::min(num_boxes.get_max_length(), static_cast<V>(max_out_boxes_per_class_val));
selected_boxes = static_output ? TDim{min_selected_boxes} : TDim{0, min_selected_boxes};
}

if (scores_rank.is_static()) {
auto out_shape = TRShape{TDim(dim::inf_bound), 3};
if (boxes_shape.rank().is_static()) {
const auto& scores_shape = input_shapes[1];

if (scores_shape.rank().is_static()) {
nms::validate::num_batches(op, input_shapes);
nms::validate::num_boxes(op, input_shapes);

auto& selected_boxes = out_shape[0];
if (const auto max_out_boxes_per_class = get_input_const_data_as<TRShape, int64_t>(op, 2, ta)) {
const auto& num_boxes = boxes_shape[1];
const auto min_selected_boxes =
std::min(num_boxes.get_max_length(), static_cast<V>(max_out_boxes_per_class->front()));
selected_boxes = static_output ? TDim{min_selected_boxes} : TDim{0, min_selected_boxes};
}

selected_boxes *= scores_shape[0].get_max_length();
selected_boxes *= scores_shape[1].get_max_length();
}

nms::validate::boxes_last_dim(op, input_shapes);
}

auto output_shapes = std::vector<TRShape>(2, out_shape);
output_shapes.emplace_back(std::initializer_list<typename TRShape::value_type>{1});
output_shapes.emplace_back(std::initializer_list<V>{1});
return output_shapes;
}
} // namespace nms
Expand Down Expand Up @@ -257,20 +245,17 @@ std::vector<TRShape> shape_infer(const NonMaxSuppression* op,
const auto& boxes_shape = input_shapes[0];
const auto& scores_shape = input_shapes[1];

auto output_shapes = std::vector<TRShape>{TRShape{TDim(-1), 3}};
if (boxes_shape.rank().is_static()) {
const auto max_out_boxes_per_class = get_input_const_data_as<TRShape, int64_t>(op, 2, ta);
const auto max_out_class_boxes = max_out_boxes_per_class ? max_out_boxes_per_class->front() : dim::inf_bound;
auto output_shapes = std::vector<TRShape>{TRShape{TDim(dim::inf_bound), 3}};

if (boxes_shape.rank().is_static() && scores_shape.rank().is_static()) {
const auto& num_boxes = boxes_shape[1];
auto& selected_boxes = output_shapes[0][0];
if (num_boxes.is_static()) {
selected_boxes = std::min(num_boxes.get_length(), static_cast<V>(max_out_class_boxes));
}

if (scores_shape.rank().is_static()) {
selected_boxes *= scores_shape[0].get_max_length();
selected_boxes *= scores_shape[1].get_max_length();
if (const auto max_out_boxes_per_class = get_input_const_data_as<TRShape, int64_t>(op, 2, ta)) {
auto& selected_boxes = output_shapes[0][0];
selected_boxes = std::min(num_boxes.get_length(), static_cast<V>(max_out_boxes_per_class->front()));
selected_boxes *= scores_shape[0].get_max_length();
selected_boxes *= scores_shape[1].get_max_length();
}
}
}

Expand Down
Loading

0 comments on commit ba6314d

Please sign in to comment.