From e0be6e463e90570f317c8d0115736cece4c3fc68 Mon Sep 17 00:00:00 2001 From: jinchen62 Date: Mon, 25 Nov 2024 21:13:03 -0800 Subject: [PATCH] Fix concat order of result --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 19 +++++++------------ .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 4 ++-- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index fccbbc2921f3..a67727253276 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -3717,10 +3717,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( boxes = squeezedBoxes.value(); scores = squeezedScores.value(); - // TODO: Add support for handling score_threshold arg. - // If score_threshold > min(scores) then the op can't be lowered since - // the torchvision::nms op doesn't have support for handling the - // score_threshold arg. + // TODO: Support score_threshold input + // Filter out the boxes if the score < score_threshold if (operands.size() == 5) { Value scoreThreshold = rewriter.create( binder.getLoc(), rewriter.getType(), @@ -3742,6 +3740,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( "unimplemented: score_threshold should be <= min(scores)")); } + // TODO: Support default iou_threshold Value iouThreshold = rewriter.create( binder.getLoc(), rewriter.getType(), operands[3]); auto nmsTy = Torch::ValueTensorType::get( @@ -3796,14 +3795,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); Value tensorList = rewriter.create( - binder.op->getLoc(), listType, SmallVector{result, zeros}); - - // TODO: Add support for handling max_output_boxes_per_class arg. - // If numOutputBoxes (N) > max_output_boxes_per_class then the op can't - // be lowered since the torchvision::nms op doesn't have support for - // handling the max_output_boxes_per_class arg. Also, we have already - // constrained the number of classes to be 1 above, so the number of - // output boxes inferred from the result is num_output_boxes_per_class. + binder.getLoc(), listType, SmallVector{zeros, result}); + + // TODO: Support max_output_boxes_per_class input + // Slice the result if numOutputBoxes (N) > max_output_boxes_per_class Value maxOutputBoxesPerClass = rewriter.create( binder.getLoc(), rewriter.getType(), operands[2]); Value boxesCond = rewriter.create( diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 33dc51f14cce..20f4a85b9f54 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -2068,7 +2068,7 @@ func.func @test_nonmaxsuppression_identical_boxes(%arg0: !torch.vtensor<[1,10,4] // CHECK: %[[VAL_31:.*]] = torch.prim.ListConstruct %[[VAL_29]], %[[VAL_30]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_32:.*]] = torch.constant.none // CHECK: %[[VAL_33:.*]] = torch.aten.zeros %[[VAL_31]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64> - // CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_27]], %[[VAL_33]] : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[1,2],si64>) -> !torch.list + // CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_33]], %[[VAL_27]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list // CHECK: %[[VAL_35:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int // CHECK: %[[VAL_36:.*]] = torch.aten.le.int %[[VAL_29]], %[[VAL_35]] : !torch.int, !torch.int -> !torch.bool // CHECK: torch.runtime.assert %[[VAL_36]], "unimplemented: number of output boxes per class should be <= max_output_boxes_per_class" @@ -2120,7 +2120,7 @@ func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>, // CHECK: %[[VAL_31:.*]] = torch.prim.ListConstruct %[[VAL_29]], %[[VAL_30]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_32:.*]] = torch.constant.none // CHECK: %[[VAL_33:.*]] = torch.aten.zeros %[[VAL_31]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64> - // CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_27]], %[[VAL_33]] : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[1,2],si64>) -> !torch.list + // CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_33]], %[[VAL_27]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list // CHECK: %[[VAL_35:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int // CHECK: %[[VAL_36:.*]] = torch.aten.le.int %[[VAL_29]], %[[VAL_35]] : !torch.int, !torch.int -> !torch.bool // CHECK: torch.runtime.assert %[[VAL_36]], "unimplemented: number of output boxes per class should be <= max_output_boxes_per_class"