diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index fcd4c5991cbc..e9f7dbd5c465 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1809,10 +1809,16 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           flattenedIndices = rewriter.create<Torch::AtenUnsqueezeOp>(
               loc, flattenIndicesTy, reshapedIndices, constZero);
         } else if (indicesRank > 1) {
-          Value endDim = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(indicesRank - 2));
-          flattenedIndices = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
-              loc, flattenIndicesTy, reshapedIndices, batchDimCountVal, endDim);
+          if (batchDimCount > indicesRank - 2) {
+            flattenedIndices = rewriter.create<Torch::AtenUnsqueezeOp>(
+                loc, flattenIndicesTy, reshapedIndices, batchDimCountVal);
+          } else {
+            Value endDim = rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(indicesRank - 2));
+            flattenedIndices = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
+                loc, flattenIndicesTy, reshapedIndices, batchDimCountVal,
+                endDim);
+          }
         }
 
         // step 8. Expand `r-b-indices_shape[-1]` dims of flattened indices.
@@ -1834,8 +1840,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         Value endDim = rewriter.create<Torch::ConstantIntOp>(
             loc,
             rewriter.getI64IntegerAttr(batchDimCount + indicesLastDim - 1));
-        Value flattenedData = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
-            loc, flattenDataTy, data, batchDimCountVal, endDim);
+        Value flattenedData = data;
+
+        if (indicesLastDim != 1) {
+          flattenedData = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
+              loc, flattenDataTy, data, batchDimCountVal, endDim);
+        }
 
         // step 10. Now we have flattenedData and expandedIndices of same rank
         // to perform gather operation.
@@ -1851,6 +1861,13 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               binder.op, resultType, gather, /*dim=*/constZero);
           return success();
         }
+
+        if (unflattenIndicesDims.empty()) {
+          rewriter.replaceOpWithNewOp<Torch::AtenSqueezeDimOp>(
+              binder.op, resultType, gather, /*dim=*/batchDimCountVal);
+          return success();
+        }
+
         Value unflattenSizeList = rewriter.create<Torch::PrimListConstructOp>(
             loc, intListTy, unflattenIndicesDims);
         rewriter.replaceOpWithNewOp<Torch::AtenUnflattenIntOp>(
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index b4ba9b93861d..59f82964a02b 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -180,6 +180,41 @@ func.func @test_gather_nd_1D_indices(%arg0: !torch.vtensor<[2,6,8,5],f32>, %arg1
 
 // -----
 
+// CHECK-LABEL: func.func @test_gathernd_example_int32_batch_dim1
+func.func @test_gathernd_example_int32_batch_dim1(%arg0: !torch.vtensor<[2,2,2],si32>, %arg1: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,2],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64} {
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[DIM0:.+]] = torch.aten.size.int %arg0, %[[INT0]]
+  // CHECK: %[[INT1:.+]] = torch.constant.int 1
+  // CHECK: %[[DIM1:.+]] = torch.aten.size.int %arg0, %[[INT1]]
+  // CHECK: %[[INT2:.+]] = torch.constant.int 2
+  // CHECK: %[[DIM2:.+]] = torch.aten.size.int %arg0, %[[INT2]]
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT1_1:.+]] = torch.constant.int 1
+  // CHECK: %[[INT0_2:.+]] = torch.constant.int 0
+  // CHECK: %[[B0:.+]] = torch.aten.size.int %arg1, %[[INT0_2]]
+  // CHECK: %[[INT1_3:.+]] = torch.constant.int 1
+  // CHECK: %[[INT1_4:.+]] = torch.constant.int 1
+  // CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %arg1, %[[INT1_3]], %[[INT0_0]], %[[INT1_4]], %[[INT1_1]]
+  // CHECK: %[[LT:.+]] = torch.aten.lt.Scalar %[[SLICE]], %[[INT0_0]]
+  // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[SLICE]], %[[DIM1]], %[[INT1_1]]
+  // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %[[SLICE]]
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[B0]], %[[INT1_1]]
+  // CHECK: %[[VIEW:.+]] = torch.aten.view %[[WHERE]], %[[LIST]]
+  // CHECK: %[[INT1_5:.+]] = torch.constant.int 1
+  // CHECK: %[[UNSQ:.+]] = torch.aten.unsqueeze %[[VIEW]], %[[INT1_5]]
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[DIM0]], %[[INT1_1]], %[[DIM2]]
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[UNSQ]], %[[LIST]], %[[FALSE]]
+  // CHECK: %[[INT1_6:.+]] = torch.constant.int 1
+  // CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT1_5]], %[[EXPAND]], %[[FALSE]]
+  // CHECK: %[[SQ:.+]] = torch.aten.squeeze.dim %[[GATHER]], %[[INT1_5]]
+  %none = torch.constant.none
+  %0 = torch.operator "onnx.GatherND"(%arg0, %arg1) {torch.onnx.batch_dims = 1 : si64} : (!torch.vtensor<[2,2,2],si32>, !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,2],si32>
+  return %0 : !torch.vtensor<[2,2],si32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_gather_elements
 func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
   // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0