[Torch] Add aten.bilinear op decomposition #3931

Open · wants to merge 2 commits into base: main
26 changes: 26 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6478,6 +6478,32 @@ def Torch_AtenLinearOp : Torch_Op<"aten.linear", [
}];
}

def Torch_AtenBilinearOp : Torch_Op<"aten.bilinear", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::bilinear : (Tensor, Tensor, Tensor, Tensor?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input1,
AnyTorchTensorType:$input2,
AnyTorchTensorType:$weight,
AnyTorchOptionalTensorType:$bias
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBilinearOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenBilinearOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}

def Torch_AtenMmOp : Torch_Op<"aten.mm", [
AllowsTypeRefinement,
HasValueSemantics,
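For reference, `aten::bilinear` has the same semantics as `torch.nn.functional.bilinear`: each output feature contracts both inputs against one slice of the rank-3 weight, then the optional bias is added. A minimal sketch via einsum (shapes chosen only for illustration):

import torch
import torch.nn.functional as F

x1 = torch.randn(4, 5)    # input1: [N, in1]
x2 = torch.randn(4, 6)    # input2: [N, in2]
w = torch.randn(3, 5, 6)  # weight: [out, in1, in2]
b = torch.randn(3)        # bias:   [out]

# out[n, k] = sum_{i,j} x1[n, i] * w[k, i, j] * x2[n, j] + b[k]
ref = torch.einsum("ni,kij,nj->nk", x1, w, x2) + b
assert torch.allclose(F.bilinear(x1, x2, w, b), ref, atol=1e-5)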
103 changes: 103 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -8877,6 +8877,100 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bilinear\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int3 = torch.constant.int 3\n"
" %int1 = torch.constant.int 1\n"
" %int2 = torch.constant.int 2\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %2 = torch.aten.eq.int %0, %1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
" %4 = torch.aten.eq.int %3, %int3 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %6 = torch.aten.sub.int %5, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %7 = torch.aten.__getitem__.t %arg0, %6 : !torch.list<int>, !torch.int -> !torch.int\n"
" %8 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %9 = torch.aten.eq.int %7, %8 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %9 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %10 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %11 = torch.aten.sub.int %10, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %12 = torch.aten.__getitem__.t %arg1, %11 : !torch.list<int>, !torch.int -> !torch.int\n"
" %13 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %14 = torch.aten.eq.int %12, %13 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %14 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %15 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %16 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %17 = torch.aten.sub.int %16, %int1 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.Loop %17, %true, init() {\n"
" ^bb0(%arg4: !torch.int):\n"
" %22 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list<int>, !torch.int -> !torch.int\n"
" %23 = torch.aten.__getitem__.t %arg1, %arg4 : !torch.list<int>, !torch.int -> !torch.int\n"
" %24 = torch.aten.eq.int %22, %23 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %24 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %25 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list<int>, !torch.int -> !torch.int\n"
" %26 = torch.aten.append.t %15, %25 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %18 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %19 = torch.aten.append.t %15, %18 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" %20 = torch.aten.__isnot__ %arg3, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
" %21 = torch.prim.If %20 -> (!torch.list<int>) {\n"
" %22 = torch.prim.unchecked_cast %arg3 : !torch.optional<list<int>> -> !torch.list<int>\n"
" %23 = torch.aten.len.t %22 : !torch.list<int> -> !torch.int\n"
" %24 = torch.aten.eq.int %23, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %25 = torch.prim.If %24 -> (!torch.bool) {\n"
" %27 = torch.aten.__getitem__.t %22, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %28 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %29 = torch.aten.eq.int %27, %28 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %29 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %25 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %26 = func.call @__torch__.torch.jit._shape_functions.broadcast(%15, %22) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" torch.prim.If.yield %26 : !torch.list<int>\n"
" } else {\n"
" torch.prim.If.yield %15 : !torch.list<int>\n"
" }\n"
" return %21 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._trilinear\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.list<int>, %arg7: !torch.int) -> !torch.list<int> {\n"
" %int3 = torch.constant.int 3\n"
" %int-1 = torch.constant.int -1\n"
@@ -15747,6 +15841,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bilinear\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.optional<tuple<int, int>>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._trilinear\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.list<int>, %arg7: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
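The two generated functions above are emitted from Python definitions in the abstract-interp library generator; a rough Python equivalent of the shape logic (with `broadcast` standing in for `torch.jit._shape_functions.broadcast`):

def bilinear_shape(input1, input2, weight, bias=None):
    # Both inputs must share a rank and the weight must be rank 3.
    assert len(input1) == len(input2)
    assert len(weight) == 3
    # The contracted input dims must match the weight's feature dims.
    assert input1[-1] == weight[1]
    assert input2[-1] == weight[2]
    out = []
    for i in range(len(input1) - 1):
        assert input1[i] == input2[i]  # leading (batch) dims must agree
        out.append(input1[i])
    out.append(weight[0])
    if bias is not None:
        assert len(bias) == 1 and bias[0] == weight[0]
        out = broadcast(out, bias)
    return out

The dtype function simply promotes across the two inputs and the weight; sketched with the generator's `promote_dtypes` helper, where each argument is a (rank, dtype) tuple:

def bilinear_dtype(input1_rank_dtype, input2_rank_dtype, weight_rank_dtype, bias_rank_dtype=None):
    ranks = [input1_rank_dtype[0], input2_rank_dtype[0], weight_rank_dtype[0]]
    dtypes = [input1_rank_dtype[1], input2_rank_dtype[1], weight_rank_dtype[1]]
    return promote_dtypes(ranks, dtypes)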
104 changes: 100 additions & 4 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2108,11 +2108,34 @@ class DecomposeAten_TrilinearOp : public OpRewritePattern<Aten_TrilinearOp> {
input3 = *unsqueezeTensor(rewriter, op, input3, expandDim);
}

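// Helper computing the static result type of an elementwise multiplication,
// assuming both operands already share a rank: per dimension, a size of 1
// broadcasts to the other operand's size and an unknown size stays unknown.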
auto mulType = [&](Value input1, Value input2) -> Type {
BaseTensorType inputType1 = cast<BaseTensorType>(input1.getType());
BaseTensorType inputType2 = cast<BaseTensorType>(input2.getType());
Type elementType = inputType1.getOptionalDtype();
if (inputType1.hasSizes() && inputType2.hasSizes()) {
SmallVector<int64_t> mulShape;
ArrayRef<int64_t> inputSize1 = inputType1.getSizes();
ArrayRef<int64_t> inputSize2 = inputType2.getSizes();
for (unsigned i = 0; i < inputSize1.size(); i++) {
int64_t size1 = inputSize1[i];
int64_t size2 = inputSize2[i];
if (size1 == kUnknownSize || size2 == kUnknownSize) {
mulShape.push_back(kUnknownSize);
} else {
mulShape.push_back(size1 == 1 ? size2 : size1);
}
}
return inputType1.getWithSizesAndDtype(mulShape, elementType);
}
return inputType1.getWithSizesAndDtype(std::nullopt, elementType);
};

// Apply multiplication operation.
auto mul1 =
rewriter.create<AtenMulTensorOp>(loc, op.getType(), input1, input2);
auto mul2 =
rewriter.create<AtenMulTensorOp>(loc, op.getType(), mul1, input3);
BaseTensorType opType = cast<BaseTensorType>(op.getType());
auto type = opType.hasSizes() ? mulType(input1, input2) : opType;
auto mul1 = rewriter.create<AtenMulTensorOp>(loc, type, input1, input2);
type = opType.hasSizes() ? mulType(mul1, input3) : opType;
auto mul2 = rewriter.create<AtenMulTensorOp>(loc, type, mul1, input3);

// Apply sum operation.
// Parse sumDim in descending order to avoid any issues with the
@@ -7655,6 +7678,78 @@ class DecomposeAtenLinearOp : public OpRewritePattern<AtenLinearOp> {
};
} // namespace

namespace {
// Decompose `aten.bilinear` op into `aten._trilinear` and `aten.add` ops.
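//
// Schematically, for 2-D inputs (input1: [N, in1], input2: [N, in2],
// weight: [out, in1, in2], bias: [out]) the emitted ops are:
//
//   %tri = aten._trilinear(%input1, %weight, %input2,
//                          expand1=[1, 3], expandw=[0], expand2=[1, 2],
//                          sumdim=[2, 3], unroll_dim=1)
//   %out = aten.add.Tensor(%tri, %bias, alpha=1)
//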
class DecomposeAtenBilinearOp : public OpRewritePattern<AtenBilinearOp> {
Review comment (Contributor): Can you please add a code snippet as a comment here to show what the decomposition looks like?

public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenBilinearOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input1 = op.getInput1();
Value input2 = op.getInput2();
Value weight = op.getWeight();
Value bias = op.getBias();

BaseTensorType inputType1 = cast<BaseTensorType>(input1.getType());
BaseTensorType inputType2 = cast<BaseTensorType>(input2.getType());
if (!inputType1.hasSizes() || !inputType2.hasSizes())
return rewriter.notifyMatchFailure(op, "expected input to have sizes");

BaseTensorType weightType = cast<BaseTensorType>(weight.getType());
if (!weightType.hasSizes())
return rewriter.notifyMatchFailure(op, "expected weight to have sizes");
// `weight` must be a rank-3 tensor.
ArrayRef<int64_t> weightSizes = weightType.getSizes();
if (weightSizes.size() != 3)
return rewriter.notifyMatchFailure(op, "expected weight to be a rank 3");

// Generate the `aten._trilinear` op.
unsigned n = inputType1.getSizes().size() - 1;
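// The unsqueezed operands of `aten._trilinear` line up as
// [batch dims 0..n-1, out = n, in1 = n+1, in2 = n+2]: input1 is expanded at
// dims {n, n+2}, input2 at {n, n+1}, the weight at the batch dims 0..n-1,
// and the product is summed over {n+1, n+2}.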
Type listOfInt =
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>());
Value dimOut =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(n));
Value dimIn1 =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(n + 1));
Value dimIn2 =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(n + 2));
Value expand1 = rewriter.create<PrimListConstructOp>(
loc, listOfInt, SmallVector<Value>{dimOut, dimIn2});
Value expand2 = rewriter.create<PrimListConstructOp>(
loc, listOfInt, SmallVector<Value>{dimOut, dimIn1});
SmallVector<Value> expandWeightValue;
for (unsigned i = 0; i < n; i++) {
Value value =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
expandWeightValue.push_back(value);
}
Value expandw =
rewriter.create<PrimListConstructOp>(loc, listOfInt, expandWeightValue);
Value sumdim = rewriter.create<PrimListConstructOp>(
loc, listOfInt, SmallVector<Value>{dimIn1, dimIn2});
Value constOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value trilinear = rewriter.create<Aten_TrilinearOp>(
loc, op.getType(), input1, weight, input2, expand1, expandw, expand2,
sumdim, constOne);

if (isa<Torch::NoneType>(bias.getType())) {
rewriter.replaceOp(op, trilinear);
return success();
} else {
BaseTensorType biasType = cast<BaseTensorType>(bias.getType());
if (!biasType.hasSizes() || biasType.getSizes().size() != 1)
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
// Generate an `aten.add` op to apply the bias.
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), trilinear,
bias, constOne);
return success();
}
}
};
} // namespace

namespace {
// Decompose `aten.mish` op into `aten.tanh` and `aten.softplus` ops.
// Mish(x) = x * Tanh(Softplus(x))
@@ -11612,6 +11707,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenHardtanhOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFullOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinearOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBilinearOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNewFullOp>(patterns);
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -485,6 +485,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenHardtanhOp>();
target.addIllegalOp<AtenFullOp>();
target.addIllegalOp<AtenLinearOp>();
target.addIllegalOp<AtenBilinearOp>();
target.addIllegalOp<AtenMishOp>();
target.addIllegalOp<AtenFullLikeOp>();
target.addIllegalOp<AtenNewFullOp>();
31 changes: 19 additions & 12 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -399,7 +399,6 @@
"AtenIntBoolOpModule_basic",
"AtenIntMM_basic",
"AtenNonzero1DDynamicModule_basic", # no lowering for torch.aten.sym_constrain_range_for_size
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"QuantizedReluInt32_basic",
"QuantizedReluInt8_basic",
@@ -514,9 +513,6 @@
"_SoftmaxModule_basic",
"UpSampleNearest2dDynamicFactor_basic",
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleSumAllDims_basic",
"Aten_TrilinearModuleSumdims_basic",
}

FX_IMPORTER_STABLEHLO_XFAIL_SET = {
@@ -622,7 +618,6 @@
"AtenTopKModule_basic",
"AtenTopKSmallestModule_basic",
"Aten_EmbeddingBagExample_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"AvgPool2dDivisorOverrideModule_basic",
"BernoulliTensorModule_basic",
@@ -942,9 +937,6 @@
# materialization callback produced value of incorrect type failed
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"Aten_TrilinearModuleSumdims_basic",
"Aten_TrilinearModuleSumAllDims_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"CrossEntropyLossModule_basic",
"CrossEntropyLossNoReductionModule_basic",
}
@@ -977,6 +969,14 @@
"AtenLinearMatVec_basic",
"AtenLinearVecMatBias_basic",
"AtenLinearVecMat_basic",
"Aten_BilinearModule1D_basic",
"Aten_BilinearModuleND_basic",
"Aten_BilinearModule_basic",
"Aten_TrilinearModuleSumAllDims_basic",
"Aten_TrilinearModuleSumdims_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModule_basic",
"ReduceAminSingleDim_basic",
"AtenDotModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
@@ -1676,9 +1676,6 @@
}

FX_IMPORTER_TOSA_CRASHING_SET = {
"Aten_TrilinearModuleSumAllDims_basic",
"Aten_TrilinearModuleSumdims_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"CrossEntropyLossModule_basic",
"CrossEntropyLossNoReductionModule_basic",
"ScatterSrcModule_basic",
@@ -1775,6 +1772,9 @@
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModule_basic",
"Aten_BilinearModule1D_basic",
"Aten_BilinearModuleND_basic",
"Aten_BilinearModule_basic",
"ElementwiseAddBoolModule_basic",
"Exp2StaticModule_basic",
"CosineSimilarityStaticBroadcastModule_basic",
@@ -3350,6 +3350,10 @@
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"Aten_BilinearModule1D_basic",
"Aten_BilinearModuleDynamic_basic",
"Aten_BilinearModuleND_basic",
"Aten_BilinearModule_basic",
}

if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@@ -3438,7 +3442,6 @@
"Threshold3dIntModule_basic",
"ElementwiseCopysignModule_basic",
"ElementwiseSignbitModule_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"MaxPool3dEmptyStrideStaticModule_basic",
"MaxPool3dLargeDatadModule_basic",
@@ -4110,6 +4113,10 @@
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"Aten_BilinearModule1D_basic",
"Aten_BilinearModuleDynamic_basic",
"Aten_BilinearModuleND_basic",
"Aten_BilinearModule_basic",
"AtenTrilModule_basic",
"AtenTrilWithNegDiagonalModule_basic",
"AtenTrilWithPosDiagonalModule_basic",