Skip to content

Commit

Permalink
Support max_output_boxes_per_class attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchen62 committed Dec 17, 2024
1 parent 36a3cf8 commit 9852aac
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 57 deletions.
69 changes: 45 additions & 24 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3703,6 +3703,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, "unimplemented: expected center_point_box "
"attribute value to be 0");

// TODO: Support multiple batches and classes
// Squeeze the boxes and scores tensor.
// In Onnx, the shape of boxes is [BxNx4] while the
// torchvision expects it to be of shape [Nx4]. Similarly, for
Expand All @@ -3719,8 +3720,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
if (failed(squeezedScores))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze scores tensor");
squeezedScores = Torch::squeezeTensor(
rewriter, binder.op, loc, 0, squeezedScores.value());
squeezedScores = Torch::squeezeTensor(rewriter, binder.op, loc, 0,
squeezedScores.value());
if (failed(squeezedScores))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze scores tensor");
Expand Down Expand Up @@ -3750,9 +3751,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
}

// Get max_output_boxes_per_class and iou_threshold
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
Value cst0 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value maxOutputBoxesPerClass = cstZero;
Value cst1 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value maxOutputBoxesPerClass = cst0;
Value iouThreshold = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(0.0));
if (operands.size() > 3 &&
Expand All @@ -3767,27 +3770,57 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
}

auto nmsTy = Torch::ValueTensorType::get(
binder.op->getContext(),
SmallVector<int64_t>{resultType.getSizes()[0]},
binder.op->getContext(), SmallVector<int64_t>{-1},
rewriter.getIntegerType(64, /*signed=*/true));
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
loc, nmsTy, boxes, scores, iouThreshold);

// Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
Value numOutputBoxes =
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
Value boxesCond = rewriter.create<Torch::AtenGtIntOp>(
loc, numOutputBoxes, maxOutputBoxesPerClass);

auto nmsResultTy = Torch::ValueTensorType::get(
binder.op->getContext(),
SmallVector<int64_t>{resultType.getSizes()[0]},
rewriter.getIntegerType(64, /*signed=*/true));
auto ifSlice = rewriter.create<Torch::PrimIfOp>(
loc, TypeRange({nmsResultTy}), boxesCond);
{
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.createBlock(&ifSlice.getThenRegion(),
ifSlice.getThenRegion().begin());

Value curResult = rewriter.create<Torch::AtenSliceTensorOp>(
loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0,
/*end=*/maxOutputBoxesPerClass, /*step=*/cst1);
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
}
{
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.createBlock(&ifSlice.getElseRegion(),
ifSlice.getElseRegion().begin());

Value curResult = rewriter.create<Torch::TensorStaticInfoCastOp>(
loc, nmsResultTy, result);
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
}
result = ifSlice.getResult(0);

// The result generated by torchvision.nms op is of shape [n], while the
// onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
// and make it of shape [n, 1] and then concatenate it with a zero
// tensor of shape [n, 2] to make it of shape [n, 3].
Value dim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
FailureOr<Value> unsqueezedResult =
Torch::unsqueezeTensor(rewriter, binder.op, result, dim);
Torch::unsqueezeTensor(rewriter, binder.op, result, cst1);
if (failed(unsqueezedResult))
return rewriter.notifyMatchFailure(
binder.op, "failed to unsqueeze result tensor");
result = unsqueezedResult.value();

Value numOutputBoxes = rewriter.create<Torch::AtenSizeIntOp>(
loc, result, cstZero);
numOutputBoxes =
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
SmallVector<Value> zerosShapeValues{numOutputBoxes};
zerosShapeValues.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(2)));
Expand All @@ -3796,7 +3829,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
zerosShapeValues);

std::optional<ArrayRef<int64_t>> resultShape =
cast<Torch::ValueTensorType>(result.getType()).getOptionalSizes();
if (!resultShape.has_value())
Expand All @@ -3816,19 +3848,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Type listType = Torch::ListType::get(listElemType);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
loc, listType, SmallVector<Value>{zeros, result});

// TODO: Support max_output_boxes_per_class input
// Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
Value boxesCond = rewriter.create<Torch::AtenLeIntOp>(
loc, numOutputBoxes, maxOutputBoxesPerClass);
rewriter.create<Torch::RuntimeAssertOp>(
loc, boxesCond,
rewriter.getStringAttr(
"unimplemented: number of output boxes per class should be "
"<= max_output_boxes_per_class"));

rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
tensorList, dim);
tensorList, cst1);
return success();
});
}
81 changes: 48 additions & 33 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2057,22 +2057,30 @@ func.func @test_nonmaxsuppression_identical_boxes(%arg0: !torch.vtensor<[1,10,4]
// CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)"
// CHECK: %[[VAL_24:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[VAL_25:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_24]] : !torch.vtensor<[10,4],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[1],si64>
// CHECK: %[[VAL_26:.*]] = torch.constant.int 1
// CHECK: %[[VAL_27:.*]] = torch.aten.unsqueeze %[[VAL_25]], %[[VAL_26]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
// CHECK: %[[VAL_28:.*]] = torch.constant.int 0
// CHECK: %[[VAL_29:.*]] = torch.aten.size.int %[[VAL_27]], %[[VAL_28]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
// CHECK: %[[VAL_30:.*]] = torch.constant.int 2
// CHECK: %[[VAL_31:.*]] = torch.prim.ListConstruct %[[VAL_29]], %[[VAL_30]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_32:.*]] = torch.constant.none
// CHECK: %[[VAL_33:.*]] = torch.aten.zeros %[[VAL_31]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
// CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_33]], %[[VAL_27]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
// CHECK: %[[VAL_35:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[VAL_36:.*]] = torch.aten.le.int %[[VAL_29]], %[[VAL_35]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_36]], "unimplemented: number of output boxes per class should be <= max_output_boxes_per_class"
// CHECK: %[[VAL_37:.*]] = torch.aten.cat %[[VAL_34]], %[[VAL_26]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
// CHECK: return %[[VAL_37]] : !torch.vtensor<[1,3],si64>
// CHECK: %[[VAL_24:.*]] = torch.constant.int 0
// CHECK: %[[VAL_25:.*]] = torch.constant.int 1
// CHECK: %[[VAL_26:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[VAL_27:.*]] = torch.aten.item %arg3 : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[VAL_28:.*]] = torch.aten.item %arg2 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[VAL_29:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_27]] : !torch.vtensor<[10,4],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[?],si64>
// CHECK: %[[VAL_30:.*]] = torch.aten.size.int %[[VAL_29]], %[[VAL_24]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int
// CHECK: %[[VAL_31:.*]] = torch.aten.gt.int %[[VAL_30]], %[[VAL_28]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[VAL_32:.*]] = torch.prim.If %[[VAL_31]] -> (!torch.vtensor<[1],si64>)
// CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %[[VAL_29]], %[[VAL_24]], %[[VAL_24]], %[[VAL_28]], %[[VAL_25]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.prim.If.yield %[[SLICE]] : !torch.vtensor<[1],si64>
// CHECK: } else {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[VAL_29]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64>
// CHECK: torch.prim.If.yield %[[CAST]] : !torch.vtensor<[1],si64>
// CHECK: }
// CHECK: %[[VAL_33:.*]] = torch.aten.unsqueeze %[[VAL_32]], %[[VAL_25]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
// CHECK: %[[VAL_34:.*]] = torch.aten.size.int %[[VAL_33]], %[[VAL_24]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
// CHECK: %[[VAL_35:.*]] = torch.constant.int 2
// CHECK: %[[VAL_36:.*]] = torch.prim.ListConstruct %[[VAL_34]], %[[VAL_35]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_37:.*]] = torch.constant.none
// CHECK: %[[VAL_38:.*]] = torch.aten.zeros %[[VAL_36]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
// CHECK: %[[VAL_39:.*]] = torch.prim.ListConstruct %[[VAL_38]], %[[VAL_33]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
// CHECK: %[[VAL_40:.*]] = torch.aten.cat %[[VAL_39]], %[[VAL_25]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
// CHECK: return %[[VAL_40]] : !torch.vtensor<[1,3],si64>
%0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,10,4],f32>, !torch.vtensor<[1,1,10],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
return %0 : !torch.vtensor<[1,3],si64>
}
Expand Down Expand Up @@ -2109,23 +2117,30 @@ func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>,
// CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)"
// CHECK: %[[VAL_24:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[VAL_25:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_24]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],si64>
// CHECK: %[[VAL_26:.*]] = torch.constant.int 1
// CHECK: %[[VAL_27:.*]] = torch.aten.unsqueeze %[[VAL_25]], %[[VAL_26]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
// CHECK: %[[VAL_28:.*]] = torch.constant.int 0
// CHECK: %[[VAL_29:.*]] = torch.aten.size.int %[[VAL_27]], %[[VAL_28]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
// CHECK: %[[VAL_30:.*]] = torch.constant.int 2
// CHECK: %[[VAL_31:.*]] = torch.prim.ListConstruct %[[VAL_29]], %[[VAL_30]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_32:.*]] = torch.constant.none
// CHECK: %[[VAL_33:.*]] = torch.aten.zeros %[[VAL_31]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
// CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_33]], %[[VAL_27]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
// CHECK: %[[VAL_35:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[VAL_36:.*]] = torch.aten.le.int %[[VAL_29]], %[[VAL_35]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_36]], "unimplemented: number of output boxes per class should be <= max_output_boxes_per_class"
// CHECK: %[[VAL_37:.*]] = torch.aten.cat %[[VAL_34]], %[[VAL_26]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
// CHECK: return %[[VAL_37]] : !torch.vtensor<[1,3],si64>
// CHECK: }
// CHECK: %[[VAL_24:.*]] = torch.constant.int 0
// CHECK: %[[VAL_25:.*]] = torch.constant.int 1
// CHECK: %[[VAL_26:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[VAL_27:.*]] = torch.aten.item %arg3 : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[VAL_28:.*]] = torch.aten.item %arg2 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[VAL_29:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_27]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[?],si64>
// CHECK: %[[VAL_30:.*]] = torch.aten.size.int %[[VAL_29]], %[[VAL_24]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int
// CHECK: %[[VAL_31:.*]] = torch.aten.gt.int %[[VAL_30]], %[[VAL_28]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[VAL_32:.*]] = torch.prim.If %[[VAL_31]] -> (!torch.vtensor<[1],si64>)
// CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %[[VAL_29]], %[[VAL_24]], %[[VAL_24]], %[[VAL_28]], %[[VAL_25]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.prim.If.yield %[[SLICE]] : !torch.vtensor<[1],si64>
// CHECK: } else {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[VAL_29]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64>
// CHECK: torch.prim.If.yield %[[CAST]] : !torch.vtensor<[1],si64>
// CHECK: }
// CHECK: %[[VAL_33:.*]] = torch.aten.unsqueeze %[[VAL_32]], %[[VAL_25]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
// CHECK: %[[VAL_34:.*]] = torch.aten.size.int %[[VAL_33]], %[[VAL_24]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
// CHECK: %[[VAL_35:.*]] = torch.constant.int 2
// CHECK: %[[VAL_36:.*]] = torch.prim.ListConstruct %[[VAL_34]], %[[VAL_35]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_37:.*]] = torch.constant.none
// CHECK: %[[VAL_38:.*]] = torch.aten.zeros %[[VAL_36]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
// CHECK: %[[VAL_39:.*]] = torch.prim.ListConstruct %[[VAL_38]], %[[VAL_33]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
// CHECK: %[[VAL_40:.*]] = torch.aten.cat %[[VAL_39]], %[[VAL_25]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
// CHECK: return %[[VAL_40]] : !torch.vtensor<[1,3],si64>
%0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,1,4],f32>, !torch.vtensor<[1,1,1],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
return %0 : !torch.vtensor<[1,3],si64>
}
Expand Down

0 comments on commit 9852aac

Please sign in to comment.