Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add attributes support for onnx.nms #3920

Merged
merged 2 commits on Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 75 additions & 52 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3688,6 +3688,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
patterns.onOp(
"NonMaxSuppression", 10,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Location loc = binder.getLoc();
Torch::ValueTensorType resultType;
SmallVector<Value> operands;
int64_t centerPointBox;
Expand All @@ -3702,96 +3703,132 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, "unimplemented: expected center_point_box "
"attribute value to be 0");

// TODO: Add support for optional arguments to be absent.
if (operands.size() < 4)
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: expected at least 4 arguments");

// TODO: Support multiple batches and classes
// Squeeze the boxes and scores tensor.
// In Onnx, the shape of boxes is [BxNx4] while the
// torchvision expects it to be of shape [Nx4]. Similarly, for
// the scores tensor shape in Onnx is [BxCxN] while the
// torchvision expects it to be of shape [N].
Value boxes = operands[0], scores = operands[1];
FailureOr<Value> squeezedBoxes = Torch::squeezeTensor(
rewriter, binder.op, binder.getLoc(), 0, boxes);
FailureOr<Value> squeezedBoxes =
Torch::squeezeTensor(rewriter, binder.op, loc, 0, boxes);
if (failed(squeezedBoxes))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze boxes tensor");

FailureOr<Value> squeezedScores = Torch::squeezeTensor(
rewriter, binder.op, binder.getLoc(), 0, scores);
FailureOr<Value> squeezedScores =
Torch::squeezeTensor(rewriter, binder.op, loc, 0, scores);
if (failed(squeezedScores))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze scores tensor");
squeezedScores = Torch::squeezeTensor(
rewriter, binder.op, binder.getLoc(), 0, squeezedScores.value());
squeezedScores = Torch::squeezeTensor(rewriter, binder.op, loc, 0,
squeezedScores.value());
if (failed(squeezedScores))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze scores tensor");

boxes = squeezedBoxes.value();
scores = squeezedScores.value();

// TODO: Support score_threshold input
// Filter out the boxes if the score < score_threshold
if (operands.size() == 5) {
Value scoreThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
operands[4]);
loc, rewriter.getType<Torch::FloatType>(), operands[4]);
Value minScores = rewriter.create<Torch::AtenMinOp>(
binder.getLoc(),
loc,
Torch::ValueTensorType::get(binder.op->getContext(),
SmallVector<int64_t>{},
rewriter.getF32Type()),
scores);
minScores = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), minScores);
loc, rewriter.getType<Torch::FloatType>(), minScores);

Value scoresCond = rewriter.create<Torch::AtenGeFloatOp>(
binder.getLoc(), minScores, scoreThreshold);
loc, minScores, scoreThreshold);
rewriter.create<Torch::RuntimeAssertOp>(
binder.getLoc(), scoresCond,
loc, scoresCond,
rewriter.getStringAttr(
"unimplemented: score_threshold should be <= min(scores)"));
}

// TODO: Support default iou_threshold
Value iouThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[3]);
// Get max_output_boxes_per_class and iou_threshold
Value cst0 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value cst1 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value maxOutputBoxesPerClass = cst0;
Value iouThreshold = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(0.0));
if (operands.size() > 3 &&
!isa<Torch::NoneType>(operands[3].getType())) {
iouThreshold = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::FloatType>(), operands[3]);
}
if (operands.size() > 2 &&
!isa<Torch::NoneType>(operands[2].getType())) {
maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), operands[2]);
}

auto nmsTy = Torch::ValueTensorType::get(
binder.op->getContext(), SmallVector<int64_t>{-1},
rewriter.getIntegerType(64, /*signed=*/true));
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
loc, nmsTy, boxes, scores, iouThreshold);

// Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
Value numOutputBoxes =
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
Value boxesCond = rewriter.create<Torch::AtenGtIntOp>(
loc, numOutputBoxes, maxOutputBoxesPerClass);

auto nmsResultTy = Torch::ValueTensorType::get(
binder.op->getContext(),
SmallVector<int64_t>{resultType.getSizes()[0]},
rewriter.getIntegerType(64, /*signed=*/true));
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
binder.getLoc(), nmsTy, boxes, scores, iouThreshold);
auto ifSlice = rewriter.create<Torch::PrimIfOp>(
loc, TypeRange({nmsResultTy}), boxesCond);
{
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.createBlock(&ifSlice.getThenRegion(),
ifSlice.getThenRegion().begin());

Value curResult = rewriter.create<Torch::AtenSliceTensorOp>(
loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0,
/*end=*/maxOutputBoxesPerClass, /*step=*/cst1);
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
}
{
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.createBlock(&ifSlice.getElseRegion(),
ifSlice.getElseRegion().begin());

Value curResult = rewriter.create<Torch::TensorStaticInfoCastOp>(
loc, nmsResultTy, result);
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
}
result = ifSlice.getResult(0);

// The result generated by torchvision.nms op is of shape [n], while the
// onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
// and make it of shape [n, 1] and then concatenate it with a zero
// tensor of shape [n, 2] to make it of shape [n, 3].
Value dim = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
FailureOr<Value> unsqueezedResult =
Torch::unsqueezeTensor(rewriter, binder.op, result, dim);
Torch::unsqueezeTensor(rewriter, binder.op, result, cst1);
if (failed(unsqueezedResult))
return rewriter.notifyMatchFailure(
binder.op, "failed to unsqueeze result tensor");
result = unsqueezedResult.value();

Value numOutputBoxes = rewriter.create<Torch::AtenSizeIntOp>(
binder.getLoc(), result,
rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0)));
numOutputBoxes =
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
SmallVector<Value> zerosShapeValues{numOutputBoxes};
zerosShapeValues.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(2)));
loc, rewriter.getI64IntegerAttr(2)));
Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
loc,
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
zerosShapeValues);

std::optional<ArrayRef<int64_t>> resultShape =
cast<Torch::ValueTensorType>(result.getType()).getOptionalSizes();
if (!resultShape.has_value())
Expand All @@ -3800,33 +3837,19 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
llvm::SmallVector<int64_t> zerosShape = {resultShape->front(), 2};
auto zerosTy = Torch::ValueTensorType::get(
resultType.getContext(), zerosShape, resultType.getOptionalDtype());
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
Value zeros = rewriter.create<Torch::AtenZerosOp>(
binder.getLoc(), zerosTy, zerosShapeList, cstNone, cstNone, cstNone,
cstNone);
loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone);

Type listElemType =
cast<Torch::BaseTensorType>(resultType)
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
/*optionalDtype=*/nullptr);
Type listType = Torch::ListType::get(listElemType);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(), listType, SmallVector<Value>{zeros, result});

// TODO: Support max_output_boxes_per_class input
// Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
Value maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), operands[2]);
Value boxesCond = rewriter.create<Torch::AtenLeIntOp>(
binder.getLoc(), numOutputBoxes, maxOutputBoxesPerClass);
rewriter.create<Torch::RuntimeAssertOp>(
binder.getLoc(), boxesCond,
rewriter.getStringAttr(
"unimplemented: number of output boxes per class should be "
"<= max_output_boxes_per_class"));

loc, listType, SmallVector<Value>{zeros, result});
rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
tensorList, dim);
tensorList, cst1);
return success();
});
}
Loading
Loading