[onnx] Fix onnx.Gather for bad expansion (#3625)
A case where an unsqueeze was required was missed, causing compilation
failures.
rsuderman authored Aug 13, 2024
1 parent 9ab9343 commit 39307f0
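
For context on the failing case: the new test below exercises onnx.GatherND with batch_dims = 1 on data of shape [2,2,2] and indices of shape [2,1]. A minimal numeric sketch of those GatherND semantics, using illustrative values (not taken from the torch-mlir sources):

#include <cstdio>

// onnx.GatherND with batch_dims = 1, data [2,2,2], indices [2,1] -> out [2,2].
// With indices_shape[-1] == 1: out[b][j] = data[b][indices[b][0]][j].
int main() {
  int data[2][2][2] = {{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}};
  long indices[2][1] = {{1}, {0}}; // one lookup index per batch element

  int out[2][2];
  for (int b = 0; b < 2; ++b)
    for (int j = 0; j < 2; ++j)
      out[b][j] = data[b][indices[b][0]][j];

  for (int b = 0; b < 2; ++b)
    printf("%d %d\n", out[b][0], out[b][1]); // prints: 2 3 / 4 5
  return 0;
}

With rank-2 indices and one batch dim, the dim range the old lowering flattened, [batchDimCount, indicesRank - 2] = [1, 0], is empty, which is where the missing unsqueeze came in.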
Showing 2 changed files with 58 additions and 6 deletions.
29 changes: 23 additions & 6 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1809,10 +1809,16 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           flattenedIndices = rewriter.create<Torch::AtenUnsqueezeOp>(
               loc, flattenIndicesTy, reshapedIndices, constZero);
         } else if (indicesRank > 1) {
-          Value endDim = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(indicesRank - 2));
-          flattenedIndices = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
-              loc, flattenIndicesTy, reshapedIndices, batchDimCountVal, endDim);
+          if (batchDimCount > indicesRank - 2) {
+            flattenedIndices = rewriter.create<Torch::AtenUnsqueezeOp>(
+                loc, flattenIndicesTy, reshapedIndices, batchDimCountVal);
+          } else {
+            Value endDim = rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(indicesRank - 2));
+            flattenedIndices = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
+                loc, flattenIndicesTy, reshapedIndices, batchDimCountVal,
+                endDim);
+          }
         }
 
         // step 8. Expand `r-b-indices_shape[-1]` dims of flattened indices.
@@ -1834,8 +1840,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         Value endDim = rewriter.create<Torch::ConstantIntOp>(
             loc,
             rewriter.getI64IntegerAttr(batchDimCount + indicesLastDim - 1));
-        Value flattenedData = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
-            loc, flattenDataTy, data, batchDimCountVal, endDim);
+        Value flattenedData = data;
+
+        if (indicesLastDim != 1) {
+          flattenedData = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
+              loc, flattenDataTy, data, batchDimCountVal, endDim);
+        }
 
         // step 10. Now we have flattenedData and expandedIndices of same rank
         // to perform gather operation.
@@ -1851,6 +1861,13 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               binder.op, resultType, gather, /*dim=*/constZero);
           return success();
         }
+
+        if (unflattenIndicesDims.empty()) {
+          rewriter.replaceOpWithNewOp<Torch::AtenSqueezeDimOp>(
+              binder.op, resultType, gather, /*dim=*/batchDimCountVal);
+          return success();
+        }
+
         Value unflattenSizeList = rewriter.create<Torch::PrimListConstructOp>(
             loc, intListTy, unflattenIndicesDims);
         rewriter.replaceOpWithNewOp<Torch::AtenUnflattenIntOp>(
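
Taken together, the hunks above guard the flatten-based paths. A condensed sketch of the new conditions, assuming the meanings the diff implies for its variables (batchDimCount = the batch_dims attribute, indicesRank = rank of the reshaped indices, indicesLastDim = indices_shape[-1]); this mirrors the diff rather than reproducing the actual torch-mlir code:

// Flattening indices over dims [batchDimCount, indicesRank - 2] only makes
// sense when the range is non-empty; otherwise the lowering unsqueezes at
// dim batchDimCount instead.
bool indicesNeedUnsqueeze(int batchDimCount, int indicesRank) {
  return batchDimCount > indicesRank - 2;
}

// When each lookup consumes a single index (indices_shape[-1] == 1) there
// are no trailing data dims to merge, so the data tensor is used as-is.
bool dataNeedsFlatten(int indicesLastDim) {
  return indicesLastDim != 1;
}

The third hunk handles the matching epilogue: when no indices dims were flattened (unflattenIndicesDims is empty), the gather result is squeezed at dim batchDimCount rather than unflattened.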
35 changes: 35 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -180,6 +180,41 @@ func.func @test_gather_nd_1D_indices(%arg0: !torch.vtensor<[2,6,8,5],f32>, %arg1
 
 // -----
 
+// CHECK-LABEL: func.func @test_gathernd_example_int32_batch_dim1
+func.func @test_gathernd_example_int32_batch_dim1(%arg0: !torch.vtensor<[2,2,2],si32>, %arg1: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,2],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64} {
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[DIM0:.+]] = torch.aten.size.int %arg0, %[[INT0]]
+  // CHECK: %[[INT1:.+]] = torch.constant.int 1
+  // CHECK: %[[DIM1:.+]] = torch.aten.size.int %arg0, %[[INT1]]
+  // CHECK: %[[INT2:.+]] = torch.constant.int 2
+  // CHECK: %[[DIM2:.+]] = torch.aten.size.int %arg0, %[[INT2]]
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT1_1:.+]] = torch.constant.int 1
+  // CHECK: %[[INT0_2:.+]] = torch.constant.int 0
+  // CHECK: %[[B0:.+]] = torch.aten.size.int %arg1, %[[INT0_2]]
+  // CHECK: %[[INT1_3:.+]] = torch.constant.int 1
+  // CHECK: %[[INT1_4:.+]] = torch.constant.int 1
+  // CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %arg1, %[[INT1_3]], %[[INT0_0]], %[[INT1_4]], %[[INT1_1]]
+  // CHECK: %[[LT:.+]] = torch.aten.lt.Scalar %[[SLICE]], %[[INT0_0]]
+  // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[SLICE]], %[[DIM1]], %[[INT1_1]]
+  // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %[[SLICE]]
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[B0]], %[[INT1_1]]
+  // CHECK: %[[VIEW:.+]] = torch.aten.view %[[WHERE]], %[[LIST]]
+  // CHECK: %[[INT1_5:.+]] = torch.constant.int 1
+  // CHECK: %[[UNSQ:.+]] = torch.aten.unsqueeze %[[VIEW]], %[[INT1_5]]
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[DIM0]], %[[INT1_1]], %[[DIM2]]
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[UNSQ]], %[[LIST]], %[[FALSE]]
+  // CHECK: %[[INT1_6:.+]] = torch.constant.int 1
+  // CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT1_5]], %[[EXPAND]], %[[FALSE]]
+  // CHECK: %[[SQ:.+]] = torch.aten.squeeze.dim %[[GATHER]], %[[INT1_5]]
+  %none = torch.constant.none
+  %0 = torch.operator "onnx.GatherND"(%arg0, %arg1) {torch.onnx.batch_dims = 1 : si64} : (!torch.vtensor<[2,2,2],si32>, !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,2],si32>
+  return %0 : !torch.vtensor<[2,2],si32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_gather_elements
 func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
 // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
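
As a sanity check on the CHECK sequence above, the lowered chain (slice → view → unsqueeze → expand → gather → squeeze.dim) produces the same values as direct GatherND with batch_dims = 1. A numeric mimic for the test's shapes (illustrative code, not emitted by the converter):

#include <cassert>

// Mimic of the lowered ops for data [2,2,2], indices [2,1], batch_dims = 1:
// unsqueeze to [2,1,1], expand to [2,1,2], gather along dim 1, squeeze dim 1.
int main() {
  int data[2][2][2] = {{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}};
  long indices[2][1] = {{1}, {0}};

  // unsqueeze + expand: broadcast each batch's index along the last dim
  long expanded[2][1][2];
  for (int b = 0; b < 2; ++b)
    for (int j = 0; j < 2; ++j)
      expanded[b][0][j] = indices[b][0];

  // torch.aten.gather(dim=1): g[b][i][j] = data[b][expanded[b][i][j]][j],
  // then squeeze dim 1 to get out[b][j]
  int out[2][2];
  for (int b = 0; b < 2; ++b)
    for (int j = 0; j < 2; ++j)
      out[b][j] = data[b][expanded[b][0][j]][j];

  // Agrees with direct GatherND: out[b][j] == data[b][indices[b][0]][j]
  for (int b = 0; b < 2; ++b)
    for (int j = 0; j < 2; ++j)
      assert(out[b][j] == data[b][indices[b][0]][j]);
  return 0;
}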
