diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index ff1ffd7e2b62..96132e38fe63 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6478,6 +6478,32 @@ def Torch_AtenLinearOp : Torch_Op<"aten.linear", [
   }];
 }
 
+def Torch_AtenBilinearOp : Torch_Op<"aten.bilinear", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::bilinear : (Tensor, Tensor, Tensor, Tensor?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input1,
+    AnyTorchTensorType:$input2,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBilinearOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenBilinearOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenMmOp : Torch_Op<"aten.mm", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 5fd05708961c..bf4ba9894083 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -8877,6 +8877,100 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.bilinear\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>) -> !torch.list<int> {\n"
+"    %false = torch.constant.bool false\n"
+"    %true = torch.constant.bool true\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int3 = torch.constant.int 3\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int2 = torch.constant.int 2\n"
+"    %int0 = torch.constant.int 0\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+"    %2 = torch.aten.eq.int %0, %1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %3 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
+"    %4 = torch.aten.eq.int %3, %int3 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %4 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %6 = torch.aten.sub.int %5, %int1 : !torch.int, !torch.int -> !torch.int\n"
+"    %7 = torch.aten.__getitem__.t %arg0, %6 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %8 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %9 = torch.aten.eq.int %7, %8 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %9 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %10 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+" %11 = torch.aten.sub.int %10, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %12 = torch.aten.__getitem__.t %arg1, %11 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.eq.int %12, %13 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %14 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %15 = torch.prim.ListConstruct : () -> !torch.list\n" +" %16 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %17 = torch.aten.sub.int %16, %int1 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %17, %true, init() {\n" +" ^bb0(%arg4: !torch.int):\n" +" %22 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list, !torch.int -> !torch.int\n" +" %23 = torch.aten.__getitem__.t %arg1, %arg4 : !torch.list, !torch.int -> !torch.int\n" +" %24 = torch.aten.eq.int %22, %23 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %24 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %25 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list, !torch.int -> !torch.int\n" +" %26 = torch.aten.append.t %15, %25 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %18 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %19 = torch.aten.append.t %15, %18 : !torch.list, !torch.int -> !torch.list\n" +" %20 = torch.aten.__isnot__ %arg3, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %21 = torch.prim.If %20 -> (!torch.list) {\n" +" %22 = torch.prim.unchecked_cast %arg3 : !torch.optional> -> !torch.list\n" +" %23 = torch.aten.len.t %22 : !torch.list -> !torch.int\n" +" %24 = torch.aten.eq.int %23, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %25 = torch.prim.If %24 -> (!torch.bool) {\n" +" %27 = torch.aten.__getitem__.t %22, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %28 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %29 = torch.aten.eq.int %27, %28 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %29 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %25 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %26 = func.call @__torch__.torch.jit._shape_functions.broadcast(%15, %22) : (!torch.list, !torch.list) -> !torch.list\n" +" torch.prim.If.yield %26 : !torch.list\n" +" } else {\n" +" torch.prim.If.yield %15 : !torch.list\n" +" }\n" +" return %21 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten._trilinear\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.list, %arg7: !torch.int) -> !torch.list {\n" " %int3 = torch.constant.int 3\n" " %int-1 = torch.constant.int -1\n" @@ -15747,6 +15841,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bilinear\"(%arg0: !torch.tuple, %arg1: 
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"    %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %5 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten._trilinear\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.list<int>, %arg7: !torch.int) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 9c2a80187c93..e5b24a05f250 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2108,11 +2108,34 @@ class DecomposeAten_TrilinearOp : public OpRewritePattern<Aten_TrilinearOp> {
       input3 = *unsqueezeTensor(rewriter, op, input3, expandDim);
     }
 
+    auto mulType = [&](Value input1, Value input2) -> Type {
+      BaseTensorType inputType1 = cast<BaseTensorType>(input1.getType());
+      BaseTensorType inputType2 = cast<BaseTensorType>(input2.getType());
+      Type elementType = inputType1.getOptionalDtype();
+      if (inputType1.hasSizes() && inputType2.hasSizes()) {
+        SmallVector<int64_t> mulShape;
+        ArrayRef<int64_t> inputSize1 = inputType1.getSizes();
+        ArrayRef<int64_t> inputSize2 = inputType2.getSizes();
+        for (unsigned i = 0; i < inputSize1.size(); i++) {
+          int64_t size1 = inputSize1[i];
+          int64_t size2 = inputSize2[i];
+          if (size1 == kUnknownSize || size2 == kUnknownSize) {
+            mulShape.push_back(kUnknownSize);
+          } else {
+            mulShape.push_back(size1 == 1 ? size2 : size1);
+          }
+        }
+        return inputType1.getWithSizesAndDtype(mulShape, elementType);
+      }
+      return inputType1.getWithSizesAndDtype(std::nullopt, elementType);
+    };
+
     // Apply multiplication operation.
-    auto mul1 =
-        rewriter.create<AtenMulTensorOp>(loc, op.getType(), input1, input2);
-    auto mul2 =
-        rewriter.create<AtenMulTensorOp>(loc, op.getType(), mul1, input3);
+    BaseTensorType opType = cast<BaseTensorType>(op.getType());
+    auto type = opType.hasSizes() ? mulType(input1, input2) : opType;
+    auto mul1 = rewriter.create<AtenMulTensorOp>(loc, type, input1, input2);
+    type = opType.hasSizes() ? mulType(mul1, input3) : opType;
+    auto mul2 = rewriter.create<AtenMulTensorOp>(loc, type, mul1, input3);
 
     // Apply sum operation.
     // Parse sumDim in descending order to avoid any issues with the
@@ -7655,6 +7678,78 @@ class DecomposeAtenLinearOp : public OpRewritePattern<AtenLinearOp> {
 };
 } // namespace
 
+namespace {
+// Decompose `aten.bilinear` op into `aten._trilinear` and `aten.add` ops.
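+// For every output feature k, `aten.bilinear` computes
+//   output[..., k] = sum_{i,j} input1[..., i] * weight[k, i, j] * input2[..., j] (+ bias[k]),
+// which maps onto `aten._trilinear` by unsqueezing input1 to [..., 1, in1, 1],
+// weight to [1, ..., 1, out, in1, in2], and input2 to [..., 1, 1, in2], then
+// summing the elementwise product over the two trailing (in1, in2) dimensions.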
+class DecomposeAtenBilinearOp : public OpRewritePattern<AtenBilinearOp> {
+public:
+  using OpRewritePattern<AtenBilinearOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenBilinearOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input1 = op.getInput1();
+    Value input2 = op.getInput2();
+    Value weight = op.getWeight();
+    Value bias = op.getBias();
+
+    BaseTensorType inputType1 = cast<BaseTensorType>(input1.getType());
+    BaseTensorType inputType2 = cast<BaseTensorType>(input2.getType());
+    if (!inputType1.hasSizes() || !inputType2.hasSizes())
+      return rewriter.notifyMatchFailure(op, "expected input to have sizes");
+
+    BaseTensorType weightType = cast<BaseTensorType>(weight.getType());
+    if (!weightType.hasSizes())
+      return rewriter.notifyMatchFailure(op, "expected weight to have sizes");
+    // `weight` must be a rank 3 tensor.
+    ArrayRef<int64_t> weightSizes = weightType.getSizes();
+    if (weightSizes.size() != 3)
+      return rewriter.notifyMatchFailure(op, "expected weight to be a rank 3 tensor");
+
+    // Generate the `aten._trilinear` op.
+    unsigned n = inputType1.getSizes().size() - 1;
+    Type listOfInt =
+        rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>());
+    Value zero =
+        rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(n));
+    Value one =
+        rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(n + 1));
+    Value two =
+        rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(n + 2));
+    Value expand1 = rewriter.create<Torch::PrimListConstructOp>(
+        loc, listOfInt, SmallVector<Value>{zero, two});
+    Value expand2 = rewriter.create<Torch::PrimListConstructOp>(
+        loc, listOfInt, SmallVector<Value>{zero, one});
+    SmallVector<Value> expandWeightValue;
+    for (unsigned i = 0; i < n; i++) {
+      Value value =
+          rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
+      expandWeightValue.push_back(value);
+    }
+    Value expandw =
+        rewriter.create<Torch::PrimListConstructOp>(loc, listOfInt, expandWeightValue);
+    Value sumdim = rewriter.create<Torch::PrimListConstructOp>(
+        loc, listOfInt, SmallVector<Value>{one, two});
+    Value constOne =
+        rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value trilinear = rewriter.create<Aten_TrilinearOp>(
+        loc, op.getType(), input1, weight, input2, expand1, expandw, expand2,
+        sumdim, constOne);
+
+    if (isa<Torch::NoneType>(bias.getType())) {
+      rewriter.replaceOp(op, trilinear);
+      return success();
+    } else {
+      BaseTensorType biasType = cast<BaseTensorType>(bias.getType());
+      if (!biasType.hasSizes() || biasType.getSizes().size() != 1)
+        return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
+      // Generate an `aten.add` op for the bias.
+      rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), trilinear,
+                                                   bias, constOne);
+      return success();
+    }
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.mish` op into `aten.tanh` and `aten.softplus` ops.
 // Mish(x) = x * Tanh(Softplus(x))
@@ -11612,6 +11707,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenBilinearOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index f15911e2b5ba..15195cbe9632 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -485,6 +485,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenBilinearOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index bb8f3a029b1d..e312c5b3a7fc 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -399,7 +399,6 @@
     "AtenIntBoolOpModule_basic",
     "AtenIntMM_basic",
     "AtenNonzero1DDynamicModule_basic",  # no lowering for torch.aten.sym_constrain_range_for_size
-    "Aten_TrilinearModuleVaryingRanks_basic",
     "Aten_TrilinearModuleZerodDimBug_basic",
     "QuantizedReluInt32_basic",
     "QuantizedReluInt8_basic",
@@ -514,9 +513,6 @@
     "_SoftmaxModule_basic",
     "UpSampleNearest2dDynamicFactor_basic",
     "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
-    "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
-    "Aten_TrilinearModuleSumAllDims_basic",
-    "Aten_TrilinearModuleSumdims_basic",
 }
 
 FX_IMPORTER_STABLEHLO_XFAIL_SET = {
@@ -622,7 +618,6 @@
     "AtenTopKModule_basic",
     "AtenTopKSmallestModule_basic",
     "Aten_EmbeddingBagExample_basic",
-    "Aten_TrilinearModuleVaryingRanks_basic",
     "Aten_TrilinearModuleZerodDimBug_basic",
     "AvgPool2dDivisorOverrideModule_basic",
     "BernoulliTensorModule_basic",
@@ -942,9 +937,6 @@
     # materialization callback produced value of incorrect type failed
     "ReduceMaxAlongDimUnsignedInt_basic",
     "ReduceMinAlongDimUnsignedInt_basic",
-    "Aten_TrilinearModuleSumdims_basic",
-    "Aten_TrilinearModuleSumAllDims_basic",
-    "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
     "CrossEntropyLossModule_basic",
     "CrossEntropyLossNoReductionModule_basic",
 }
@@ -977,6 +969,14 @@
     "AtenLinearMatVec_basic",
     "AtenLinearVecMatBias_basic",
     "AtenLinearVecMat_basic",
+    "Aten_BilinearModule1D_basic",
+    "Aten_BilinearModuleND_basic",
+    "Aten_BilinearModule_basic",
+    "Aten_TrilinearModuleSumAllDims_basic",
+    "Aten_TrilinearModuleSumdims_basic",
+    "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
+    "Aten_TrilinearModuleVaryingRanks_basic",
+    "Aten_TrilinearModule_basic",
     "ReduceAminSingleDim_basic",
     "AtenDotModule_basic",
     "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
@@ -1676,9 +1676,6 @@
 }
 
 FX_IMPORTER_TOSA_CRASHING_SET = {
-    "Aten_TrilinearModuleSumAllDims_basic",
-    "Aten_TrilinearModuleSumdims_basic",
-    "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
     "CrossEntropyLossModule_basic",
     "CrossEntropyLossNoReductionModule_basic",
     "ScatterSrcModule_basic",
@@ -1775,6 +1772,9 @@
     "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
     "Aten_TrilinearModuleVaryingRanks_basic",
     "Aten_TrilinearModule_basic",
+    "Aten_BilinearModule1D_basic",
+    "Aten_BilinearModuleND_basic",
+    "Aten_BilinearModule_basic",
     "ElementwiseAddBoolModule_basic",
     "Exp2StaticModule_basic",
"CosineSimilarityStaticBroadcastModule_basic", @@ -3350,6 +3350,10 @@ "Aten_TrilinearModuleVaryingRanks_basic", "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", "Aten_TrilinearModuleZerodDimBug_basic", + "Aten_BilinearModule1D_basic", + "Aten_BilinearModuleDynamic_basic", + "Aten_BilinearModuleND_basic", + "Aten_BilinearModule_basic", } if torch_version_for_comparison() < version.parse("2.3.0.dev"): @@ -3438,7 +3442,6 @@ "Threshold3dIntModule_basic", "ElementwiseCopysignModule_basic", "ElementwiseSignbitModule_basic", - "Aten_TrilinearModuleVaryingRanks_basic", "Aten_TrilinearModuleZerodDimBug_basic", "MaxPool3dEmptyStrideStaticModule_basic", "MaxPool3dLargeDatadModule_basic", @@ -4110,6 +4113,10 @@ "Aten_TrilinearModuleVaryingRanks_basic", "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", "Aten_TrilinearModuleZerodDimBug_basic", + "Aten_BilinearModule1D_basic", + "Aten_BilinearModuleDynamic_basic", + "Aten_BilinearModuleND_basic", + "Aten_BilinearModule_basic", "AtenTrilModule_basic", "AtenTrilWithNegDiagonalModule_basic", "AtenTrilWithPosDiagonalModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index a73d188d7168..45607ea2bf19 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1304,6 +1304,30 @@ def aten〇unflatten〇int〡shape(self: List[int], dim: int, sizes: List[int]) def aten〇linear〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None) -> List[int]: return upstream_shape_functions.linear(input, weight, bias) +@check_shape_function([ + Invocation(TensorOfShape(8, 2), TensorOfShape(8, 3), TensorOfShape(4, 2, 3), TensorOfShape(4)), # Basic case + Invocation(TensorOfShape(8, 2, 2), TensorOfShape(8, 2, 3), TensorOfShape(4, 2, 3), TensorOfShape(4)), # 3D inputs + ErrorInvocation(TensorOfShape(8, 2), TensorOfShape(8, 2, 3), TensorOfShape(4, 2, 3)), # input dimensions don't match + ErrorInvocation(TensorOfShape(8, 2), TensorOfShape(8, 3), TensorOfShape(2, 3, 4)), # weight dimensions don't match + ErrorInvocation(TensorOfShape(8, 2), TensorOfShape(8, 3), TensorOfShape(4, 2, 3), TensorOfShape(8)), # bias dimensions don't match +]) +def aten〇bilinear〡shape(input1: List[int], input2: List[int], weight: List[int], bias: Optional[List[int]] = None) -> List[int]: + assert len(input1) == len(input2) + assert len(weight) == 3 + assert input1[len(input1)-1] == weight[1] + assert input2[len(input2)-1] == weight[2] + + out_shape: List[int] = [] + for i in range(len(input1)-1): + assert input1[i] == input2[i] + out_shape.append(input1[i]) + out_shape.append(weight[0]) + + if bias is not None: + assert len(bias) == 1 and bias[0] == weight[0] + out_shape = upstream_shape_functions.broadcast(out_shape, bias) + return out_shape + @check_shape_function([ Invocation(TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), [], [], [], [], 0), # Basic case Invocation(TensorOfShape(4, 5, 6), TensorOfShape(4, 5, 6), TensorOfShape(4, 5, 6), [1], [0], [0], [], 2), # Expansions w/ Non-Zero unroll_dim @@ -5605,6 +5629,15 @@ def aten〇linear〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: promoted_dtype = promote_dtypes(ranks, dtypes) return promoted_dtype +def aten〇bilinear〡dtype(input1_rank_dtype: Tuple[int, int], input2_rank_dtype: Tuple[int, int], 
+    input1_rank, input1_dtype = input1_rank_dtype
+    input2_rank, input2_dtype = input2_rank_dtype
+    weight_rank, weight_dtype = weight_rank_dtype
+    ranks: List[Optional[int]] = [input1_rank, input2_rank, weight_rank]
+    dtypes = [input1_dtype, input2_dtype, weight_dtype]
+    promoted_dtype = promote_dtypes(ranks, dtypes)
+    return promoted_dtype
+
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(3, None, None, None, expand1 = [], expand2 = [], expand3 = [], sumdim = [], unroll_dim = 0),
 )
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 930979b3c939..d4833ff9de3b 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -564,6 +564,7 @@ def emit_with_mutating_variants(key, **kwargs):
 
     # Non-elementwise tensor compute ops
     emit("aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)")
+    emit("aten::bilinear : (Tensor, Tensor, Tensor, Tensor?) -> (Tensor)")
     emit("aten::mm : (Tensor, Tensor) -> (Tensor)")
     emit("aten::_int_mm : (Tensor, Tensor) -> (Tensor)")
     emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
index d1ddc42b39b1..39df5d4bbe4a 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -1945,3 +1945,100 @@ def forward(self, i1, i2, i3):
 @register_test_case(module_factory=lambda: Aten_TrilinearModuleZerodDimBug())
 def Aten_TrilinearModuleZerodDimBug_basic(module, tu: TestUtils):
     return module.forward(tu.rand(2, 3, 6), tu.rand(2, 3, 6), tu.rand(2, 3, 6))
+
+
+# ==============================================================================
+
+
+class Aten_BilinearModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([8, 2], torch.float32, True),
+            ([8, 3], torch.float32, True),
+            ([4, 2, 3], torch.float32, True),
+            ([4], torch.float32, True),
+        ]
+    )
+    def forward(self, input1, input2, weight, bias):
+        return torch.ops.aten.bilinear(input1, input2, weight, bias)
+
+
+@register_test_case(module_factory=lambda: Aten_BilinearModule())
+def Aten_BilinearModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(8, 2), tu.rand(8, 3), tu.rand(4, 2, 3), tu.rand(4))
+
+
+class Aten_BilinearModuleDynamic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+            ([-1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+            ([-1], torch.float32, True),
+        ]
+    )
+    def forward(self, input1, input2, weight, bias):
+        return torch.ops.aten.bilinear(input1, input2, weight, bias)
+
+
+@register_test_case(module_factory=lambda: Aten_BilinearModuleDynamic())
+def Aten_BilinearModuleDynamic_basic(module, tu: TestUtils):
+    module.forward(tu.rand(8, 2), tu.rand(8, 3), tu.rand(4, 2, 3), tu.rand(4))
+
+
+class Aten_BilinearModule1D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2], torch.float32, True),
+            ([3], torch.float32, True),
+            ([4, 2, 3], torch.float32, True),
+            ([4], torch.float32, True),
+        ]
+    )
+    def forward(self, input1, input2, weight, bias):
+        return torch.ops.aten.bilinear(input1, input2, weight, bias)
+
+
+@register_test_case(module_factory=lambda: Aten_BilinearModule1D())
+def Aten_BilinearModule1D_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2), tu.rand(3), tu.rand(4, 2, 3), tu.rand(4))
+
+
+class Aten_BilinearModuleND(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([8, 6, 12, 2], torch.float32, True),
+            ([8, 6, 12, 3], torch.float32, True),
+            ([4, 2, 3], torch.float32, True),
+            ([4], torch.float32, True),
+        ]
+    )
+    def forward(self, input1, input2, weight, bias):
+        return torch.ops.aten.bilinear(input1, input2, weight, bias)
+
+
+@register_test_case(module_factory=lambda: Aten_BilinearModuleND())
+def Aten_BilinearModuleND_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(8, 6, 12, 2), tu.rand(8, 6, 12, 3), tu.rand(4, 2, 3), tu.rand(4)
+    )