Skip to content

Commit

Permalink
[TOSA] Add legalization for aten.diagonal (#3740)
Browse files Browse the repository at this point in the history
- Add lowering from Torch to TOSA for aten.diagonal
- Clean up some code
- Update xfail_sets.py with the new e2e results
- Update basic_mlir with the new op mlir test

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I99bed685455752d09ed96edd837c4dfbee152701

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
  • Loading branch information
justin-ngo-arm authored Sep 30, 2024
1 parent 5f74de5 commit 5eab669
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 47 deletions.
242 changes: 209 additions & 33 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -891,8 +891,6 @@ class ConvertAtenReductionOp : public OpConversionPattern<AtenOpT> {
if (!result)
return failure();

// TBD - support dtype casting.

rewriter.replaceOp(op, {result.value()});

return success();
Expand Down Expand Up @@ -5647,8 +5645,7 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
return success();
}

// Template to create support tril mask tensor for aten.tril
// legalization
// Template to create supporting tril mask tensor for aten.tril
template <typename T>
Value createTrilMask(PatternRewriter &rewriter, Operation *op,
ArrayRef<int64_t> shape, int64_t h, int64_t w,
Expand All @@ -5671,28 +5668,6 @@ Value createTrilMask(PatternRewriter &rewriter, Operation *op,
return tosa::getConstTensor<T>(rewriter, op, vec, shape).value();
}

// Function to get tril mask tensor based on input type
// for aten.tril legalization
Value getTrilMask(PatternRewriter &rewriter, Operation *op,
ArrayRef<int64_t> shape, int64_t h, int64_t w,
int64_t diagonal, Type type) {
return TypeSwitch<Type, Value>(type)
.Case<mlir::FloatType>([&](auto) {
return createTrilMask<float>(rewriter, op, shape, h, w, diagonal);
})
.Case<mlir::IntegerType>([&](auto intType) {
switch (intType.getWidth()) {
case 1:
return createTrilMask<bool>(rewriter, op, shape, h, w, diagonal);
case 32:
return createTrilMask<int32_t>(rewriter, op, shape, h, w, diagonal);
case 64:
return createTrilMask<int64_t>(rewriter, op, shape, h, w, diagonal);
}
llvm_unreachable("Invalid integer width");
});
}

// Legalization for aten.tril
template <>
LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
Expand Down Expand Up @@ -5740,21 +5715,221 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "Diagonal value is not an integer");

// Define shape for mask tensor based on rank
SmallVector<int64_t> constShape;
SmallVector<int64_t> maskShape;
for (auto i = 0; i < selfRank - 2; i++)
constShape.push_back(1);
constShape.push_back(h);
constShape.push_back(w);

Value trilMask = getTrilMask(rewriter, op, constShape, h, w, diagonal,
resultType.getElementType());
maskShape.push_back(1);
maskShape.push_back(h);
maskShape.push_back(w);

Value trilMask = TypeSwitch<Type, Value>(resultType.getElementType())
.Case<mlir::FloatType>([&](auto) {
return createTrilMask<float>(rewriter, op, maskShape,
h, w, diagonal);
})
.Case<mlir::IntegerType>([&](auto intType) {
switch (intType.getWidth()) {
case 1:
return createTrilMask<bool>(rewriter, op, maskShape,
h, w, diagonal);
case 32:
return createTrilMask<int32_t>(
rewriter, op, maskShape, h, w, diagonal);
case 64:
return createTrilMask<int64_t>(
rewriter, op, maskShape, h, w, diagonal);
}
llvm_unreachable("Invalid integer width");
});

rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, trilMask,
/*shift=*/0);

return success();
}

// Template to create supporting diagonal mask tensor for aten.diagonal
template <typename T>
Value createDiagonalMask(PatternRewriter &rewriter, Operation *op,
ArrayRef<int64_t> shape, int64_t h, int64_t w,
int64_t offset) {
SmallVector<T> vec;

for (int64_t i = 0; i < h; i++) {
for (int64_t j = 0; j < w; j++) {
// Positive offset value moves above the main diagonal, while negative
// diagonal value moves below the main diagonal.
if (i + offset == j) {
vec.push_back(static_cast<T>(1));
} else {
vec.push_back(static_cast<T>(0));
}
}
}

return tosa::getConstTensor<T>(rewriter, op, vec, shape).value();
}

// Legalization for aten.diagonal
template <>
LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
AtenDiagonalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf();

// Not a ranked tensor type
auto selfType = dyn_cast<RankedTensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types are supported");

// Rank below 2 not accepted
auto selfRank = selfType.getRank();
if (selfRank <= 1)
return rewriter.notifyMatchFailure(
op, "Rank 0 and 1 are not accepted as they cause underflow");

if (!selfType.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "Currently only static shapes are supported");

const TypeConverter *typeConverter = this->getTypeConverter();
RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));
if (!resultType)
return rewriter.notifyMatchFailure(op, "Result type cannot be empty");

auto selfElemTy = selfType.getElementType();
auto resultElemTy = resultType.getElementType();

int64_t offset, dim1, dim2;
if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset)))
offset = 0;

if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) {
dim1 = 0;
} else {
dim1 = toPositiveDim(dim1, selfRank);
}

if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2))) {
dim2 = 1;
} else {
dim2 = toPositiveDim(dim2, selfRank);
}

auto selfShape = makeShapeTorchCompatible(selfType.getShape());
int64_t h = selfShape[dim1];
int64_t w = selfShape[dim2];

// Overflowing offset not supported
if ((offset < 0 && std::abs(offset) >= h) || (offset >= 0 && offset >= w))
return rewriter.notifyMatchFailure(
op, "Offset greater or equal than shape not supported");

int64_t targetDim1 = selfRank - 2;
int64_t targetDim2 = selfRank - 1;

Value selfTransposed = self;
SmallVector<int64_t> transposedInputShape = selfShape;
RankedTensorType transposedInputType = selfType;

// If (dim1, dim2) != (rank - 2, rank - 1), transpose the input tensor
// so that dim1 and dim2 become rank - 2 and rank - 1. We do this so that
// we can consistently create the diagonal mask tensor.
if (!(dim1 == targetDim1 && dim2 == targetDim2)) {
SmallVector<int32_t> transposedDims;
transposedInputShape.clear();

for (int64_t i = 0; i < selfRank; ++i) {
if (i == dim1 || i == dim2)
continue;
transposedDims.push_back(i);
}
transposedDims.push_back(dim1);
transposedDims.push_back(dim2);

auto transposedDimsConst = tosa::getConstTensor<int32_t>(
rewriter, op,
/*vec=*/transposedDims,
/*shape=*/{static_cast<int32_t>(selfRank)});

for (auto &dim : transposedDims)
transposedInputShape.push_back(selfShape[dim]);

transposedInputType = RankedTensorType::get(
makeShapeLLVMCompatible(transposedInputShape), selfElemTy);

selfTransposed = rewriter.create<tosa::TransposeOp>(
op->getLoc(), transposedInputType, self, transposedDimsConst.value());
}

// Define shape for mask tensor based on rank
SmallVector<int64_t> maskShape;
for (auto i = 0; i < selfRank - 2; i++)
maskShape.push_back(1);
maskShape.push_back(h);
maskShape.push_back(w);

Value diagonalMask =
TypeSwitch<Type, Value>(resultElemTy)
.Case<mlir::FloatType>([&](auto) {
return createDiagonalMask<float>(rewriter, op, maskShape, h, w,
offset);
})
.Case<mlir::IntegerType>([&](auto intType) {
switch (intType.getWidth()) {
case 1:
return createDiagonalMask<bool>(rewriter, op, maskShape, h, w,
offset);
case 32:
return createDiagonalMask<int32_t>(rewriter, op, maskShape, h, w,
offset);
case 64:
return createDiagonalMask<int64_t>(rewriter, op, maskShape, h, w,
offset);
}
llvm_unreachable("Invalid integer width");
});

Value diagonalTensor = rewriter.create<tosa::MulOp>(
op->getLoc(), transposedInputType, selfTransposed, diagonalMask,
/*shift=*/0);

auto resultShape = makeShapeTorchCompatible(resultType.getShape());
auto targetReduceDim = resultShape[resultType.getRank() - 1];

// If transposedInputShape[targetDim1] (or h) is greater than the innermost
// dim of the result, we won't get the correct shape when we reduce sum along
// the innermost dim to get the result. Therefore, we have to slice the
// transposed tensor so that transposedInputShape[targetDim1] ==
// targetReduceDim.
if (h > targetReduceDim) {
transposedInputShape[targetDim1] = targetReduceDim;
transposedInputType = RankedTensorType::get(
makeShapeLLVMCompatible(transposedInputShape), selfElemTy);
SmallVector<int64_t> startSlice(selfRank, 0);
SmallVector<int64_t> sizeSlice =
llvm::to_vector(makeShapeTorchCompatible(transposedInputShape));
if (offset < 0)
startSlice[targetDim1] = std::abs(offset);
diagonalTensor = rewriter.create<tosa::SliceOp>(
op->getLoc(), transposedInputType, diagonalTensor,
rewriter.getDenseI64ArrayAttr(startSlice),
rewriter.getDenseI64ArrayAttr(sizeSlice));
}

// Apply Reduce Sum to get the result
auto reduceDimType = RankedTensorType::get({1}, rewriter.getI64Type());
auto reduceDimAttr =
DenseIntElementsAttr::get(reduceDimType, llvm::ArrayRef({targetDim2}));
auto result =
mlir::tosa::convertReduceSumOp(rewriter, op, resultType, diagonalTensor,
reduceDimAttr, /*keep_dims=*/false);

rewriter.replaceOp(op, result.value());

return success();
}
} // namespace

// -----------------------------------------------------------------------------
Expand Down Expand Up @@ -6060,6 +6235,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenIscloseOp);
INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
INSERT_ATENOP_PATTERN(AtenTrilOp);
INSERT_ATENOP_PATTERN(AtenDiagonalOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
Expand Down
18 changes: 4 additions & 14 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1663,6 +1663,8 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"DiagonalWithStaticShapeModule_basic",
"EinsumStaticDiagonalDimensionModule_basic",
"ElementwiseAtenFloorDivideBroadcastModule_basic",
"ElementwiseAtenFloorDivideScalarModule_basic",
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
Expand Down Expand Up @@ -3190,6 +3192,7 @@
}

FX_IMPORTER_TOSA_XFAIL_SET = {
"AdaptiveMaxPool1dDimOneStatic_basic",
"AtenPolarDoubleModule_basic",
"AtenPolarFloatModule_basic",
"HstackBasicComplexModule_basic",
Expand All @@ -3213,7 +3216,6 @@
"Conv_Transpose2dStaticModule_basic",
"Conv_Transpose3dModule_basic",
"Conv_Transpose3dStaticModule_basic",
"EinsumStaticDiagonalDimensionModule_basic",
"ElementwiseFloatTensorGtIntTensorModule_basic",
"ElementwiseIntTensorLtFloatTensorModule_basic",
"ElementwiseRreluEvalModule_basic",
Expand Down Expand Up @@ -3384,14 +3386,6 @@
"DeterminantBatchedModule_F32",
"DeterminantDynamicModule_F32",
"DeterminantModule_F32",
"DiagonalModule_basic",
"DiagonalModule_nonsquare",
"DiagonalModule_transposed",
"DiagonalModule_with_dims",
"DiagonalModule_with_dims_and_offset",
"DiagonalModule_with_negative_dims",
"DiagonalModule_with_offset",
"DiagonalWithStaticShapeModule_basic",
"DivFloatModule_basic",
"DivIntModule_basic",
"DropoutTrainModule_basic",
Expand Down Expand Up @@ -3805,11 +3799,7 @@
"ToCopyWithDTypeModule_basic",
"TorchPrimLoopForLikeModule_basic",
"TorchPrimLoopWhileLikeModule_basic",
"TraceModule_basic",
"TraceModule_empty",
"TraceModule_nonsquare",
"TraceSignedIntModule_basic",
"TraceUnsignedIntModule_basic",
"TraceUnsignedIntModule_empty",
"TypeConversionI1ToF64Module_basic",
"TypeConversionI1ToI32Module_basic",
Expand Down Expand Up @@ -3845,6 +3835,7 @@
}

ONNX_TOSA_XFAIL_SET = {
"AdaptiveMaxPool1dDimOneStatic_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"HstackBasicComplexModule_basic",
"HstackBasicFloatModule_basic",
Expand Down Expand Up @@ -3874,7 +3865,6 @@
"Conv_Transpose2dStaticModule_basic",
"Conv_Transpose3dModule_basic",
"Conv_Transpose3dStaticModule_basic",
"EinsumStaticDiagonalDimensionModule_basic",
"EinsumStaticModule_basic",
"ElementwiseFmaxModule_basic",
"ElementwiseFminModule_basic",
Expand Down
26 changes: 26 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1859,3 +1859,29 @@ func.func @torch.aten.bitwise_right_shift.Tensor$basic(%arg0: !torch.vtensor<[?,
%0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
return %0: !torch.vtensor<[?,?],si32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.diagonal$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5,6],si32>) -> !torch.vtensor<[5,6,2],si32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5,6],si32> -> tensor<3x4x5x6xi32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = torch.constant.int 0
// CHECK: %[[VAL_4:.*]] = torch.constant.int -2
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<[2, 3, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_6:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_5]] : (tensor<3x4x5x6xi32>, tensor<4xi32>) -> tensor<5x6x4x3xi32>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 0]]]]> : tensor<1x1x4x3xi32>}> : () -> tensor<1x1x4x3xi32>
// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_6]], %[[VAL_7]] {shift = 0 : i8} : (tensor<5x6x4x3xi32>, tensor<1x1x4x3xi32>) -> tensor<5x6x4x3xi32>
// CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]] {size = array<i64: 5, 6, 2, 3>, start = array<i64: 0, 0, 2, 0>} : (tensor<5x6x4x3xi32>) -> tensor<5x6x2x3xi32>
// CHECK: %[[VAL_10:.*]] = tosa.reduce_sum %[[VAL_9]] {axis = 3 : i32} : (tensor<5x6x2x3xi32>) -> tensor<5x6x2x1xi32>
// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array<i64: 5, 6, 2>} : (tensor<5x6x2x1xi32>) -> tensor<5x6x2xi32>
// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<5x6x2xi32> -> !torch.vtensor<[5,6,2],si32>
// CHECK: return %[[VAL_12]] : !torch.vtensor<[5,6,2],si32>
// CHECK: }
func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) -> !torch.vtensor<[5,6,2], si32> {
%dim1 = torch.constant.int 1
%dim2 = torch.constant.int 0
%offset = torch.constant.int -2
%0 = torch.aten.diagonal %arg0, %offset, %dim1, %dim2 : !torch.vtensor<[3,4,5,6],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[5,6,2],si32>
return %0 : !torch.vtensor<[5,6,2],si32>
}

0 comments on commit 5eab669

Please sign in to comment.