Skip to content

Commit

Permalink
Fix concat order of result
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchen62 committed Nov 26, 2024
1 parent 7e69602 commit e0be6e4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 14 deletions.
19 changes: 7 additions & 12 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3717,10 +3717,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
boxes = squeezedBoxes.value();
scores = squeezedScores.value();

// TODO: Add support for handling score_threshold arg.
// If score_threshold > min(scores) then the op can't be lowered since
// the torchvision::nms op doesn't have support for handling the
// score_threshold arg.
// TODO: Support score_threshold input
// Filter out the boxes if the score < score_threshold
if (operands.size() == 5) {
Value scoreThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
Expand All @@ -3742,6 +3740,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
"unimplemented: score_threshold should be <= min(scores)"));
}

// TODO: Support default iou_threshold
Value iouThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[3]);
auto nmsTy = Torch::ValueTensorType::get(
Expand Down Expand Up @@ -3796,14 +3795,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
/*optionalDtype=*/nullptr);
Type listType = Torch::ListType::get(listElemType);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
binder.op->getLoc(), listType, SmallVector<Value>{result, zeros});

// TODO: Add support for handling max_output_boxes_per_class arg.
// If numOutputBoxes (N) > max_output_boxes_per_class then the op can't
// be lowered since the torchvision::nms op doesn't have support for
// handling the max_output_boxes_per_class arg. Also, we have already
// constrained the number of classes to be 1 above, so the number of
// output boxes inferred from the result is num_output_boxes_per_class.
binder.getLoc(), listType, SmallVector<Value>{zeros, result});

// TODO: Support max_output_boxes_per_class input
// Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
Value maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), operands[2]);
Value boxesCond = rewriter.create<Torch::AtenLeIntOp>(
Expand Down
4 changes: 2 additions & 2 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2068,7 +2068,7 @@ func.func @test_nonmaxsuppression_identical_boxes(%arg0: !torch.vtensor<[1,10,4]
// CHECK: %[[VAL_31:.*]] = torch.prim.ListConstruct %[[VAL_29]], %[[VAL_30]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_32:.*]] = torch.constant.none
// CHECK: %[[VAL_33:.*]] = torch.aten.zeros %[[VAL_31]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
// CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_27]], %[[VAL_33]] : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[1,2],si64>) -> !torch.list<vtensor>
// CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_33]], %[[VAL_27]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
// CHECK: %[[VAL_35:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[VAL_36:.*]] = torch.aten.le.int %[[VAL_29]], %[[VAL_35]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_36]], "unimplemented: number of output boxes per class should be <= max_output_boxes_per_class"
Expand Down Expand Up @@ -2120,7 +2120,7 @@ func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>,
// CHECK: %[[VAL_31:.*]] = torch.prim.ListConstruct %[[VAL_29]], %[[VAL_30]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_32:.*]] = torch.constant.none
// CHECK: %[[VAL_33:.*]] = torch.aten.zeros %[[VAL_31]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
// CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_27]], %[[VAL_33]] : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[1,2],si64>) -> !torch.list<vtensor>
// CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_33]], %[[VAL_27]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
// CHECK: %[[VAL_35:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[VAL_36:.*]] = torch.aten.le.int %[[VAL_29]], %[[VAL_35]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_36]], "unimplemented: number of output boxes per class should be <= max_output_boxes_per_class"
Expand Down

0 comments on commit e0be6e4

Please sign in to comment.