diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index a3bad0e0423b..54cfdfc7e00f 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -16007,6 +16007,31 @@ def Torch_AtenAddFloatIntOp : Torch_Op<"aten.add.float_int", [
   let hasFolder = 1;
 }
 
+def Torch_AtenMulFloatIntOp : Torch_Op<"aten.mul.float_int", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::mul.float_int : (float, int) -> (float)`";
+  let arguments = (ins
+    Torch_FloatType:$a,
+    Torch_IntType:$b
+  );
+  let results = (outs
+    Torch_FloatType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMulFloatIntOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenMulFloatIntOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
 def Torch_AtenSubFloatOp : Torch_Op<"aten.sub.float", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp
index 833397f41bb8..baed74fed6dc 100644
--- a/lib/Conversion/TorchToArith/TorchToArith.cpp
+++ b/lib/Conversion/TorchToArith/TorchToArith.cpp
@@ -74,7 +74,8 @@ class ConvertAtenBinaryOp : public OpConversionPattern<AtenOp> {
                   ConversionPatternRewriter &rewriter) const override {
     Value a = adaptor.getA();
     Value b = adaptor.getB();
-    if (llvm::is_one_of<AtenOp, AtenAddFloatIntOp>::value)
+    if (llvm::is_one_of<AtenOp, AtenAddFloatIntOp>::value ||
+        llvm::is_one_of<AtenOp, AtenMulFloatIntOp>::value)
       b = convertScalarToDtype(rewriter, op.getLoc(), b, a.getType());
     if (llvm::is_one_of<AtenOp, AtenMulIntFloatOp>::value)
       a = convertScalarToDtype(rewriter, op.getLoc(), a, b.getType());
@@ -492,7 +493,8 @@ class ConvertTorchToArith
     target.addIllegalOp<AtenAddOp>();
     patterns.add<ConvertAtenAddOp>(typeConverter, context);
     target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
-                        AtenMulIntOp, AtenRemainderIntOp, AtenMulIntFloatOp>();
+                        AtenMulIntOp, AtenRemainderIntOp, AtenMulIntFloatOp,
+                        AtenMulFloatIntOp>();
     patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
         typeConverter, context);
     patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
         typeConverter, context);
@@ -505,6 +507,8 @@ class ConvertTorchToArith
         typeConverter, context);
     patterns.add<ConvertAtenBinaryOp<AtenMulIntFloatOp, arith::MulFOp>>(
         typeConverter, context);
+    patterns.add<ConvertAtenBinaryOp<AtenMulFloatIntOp, arith::MulFOp>>(
+        typeConverter, context);
     target.addIllegalOp<AtenSubFloatOp>();
     patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
         typeConverter, context);
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 87d1464e245c..eafbe14162cc 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -4278,6 +4278,18 @@ OpFoldResult AtenAddFloatIntOp::fold(FoldAdaptor adaptor) {
       adaptor.getOperands(), [](double a, double b) { return a + b; });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenMulFloatIntOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenMulFloatIntOp::fold(FoldAdaptor adaptor) {
+  if (!adaptor.getA() || !adaptor.getB()) {
+    return nullptr;
+  }
+  return atenBinaryFloatOperatorFoldHelper(
+      adaptor.getOperands(), [](double a, double b) { return a * b; });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenPowIntFloatOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index a8ce5ed20c6b..c4850fac18fc 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6908,12 +6908,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    %11 = torch.aten.__getitem__.t %9, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
 "    %12 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
-"    %13 = torch.operator \"aten.mul.float_int\"(%11, %12) : (!torch.float, !torch.int) -> !torch.float \n"
+"    %13 = torch.aten.mul.float_int %11, %12 : !torch.float, !torch.int -> !torch.float\n"
 "    %14 = torch.aten.Int.float %13 : !torch.float -> !torch.int\n"
 "    %15 = torch.aten.append.t %3, %14 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
 "    %16 = torch.aten.__getitem__.t %9, %int1 : !torch.list<float>, !torch.int -> !torch.float\n"
 "    %17 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
-"    %18 = torch.operator \"aten.mul.float_int\"(%16, %17) : (!torch.float, !torch.int) -> !torch.float \n"
+"    %18 = torch.aten.mul.float_int %16, %17 : !torch.float, !torch.int -> !torch.float\n"
 "    %19 = torch.aten.Int.float %18 : !torch.float -> !torch.int\n"
 "    %20 = torch.aten.append.t %3, %19 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
 "    torch.prim.If.yield %true, %3 : !torch.bool, !torch.list<int>\n"
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 07029d0894ee..db9c2c9bfd2a 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -1123,6 +1123,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::neg.int : (int) -> (int)", has_folder=True)
     emit("aten::log.int : (int) -> (float)")
     emit("aten::add.float_int : (float, int) -> (float)", has_folder=True)
+    emit("aten::mul.float_int : (float, int) -> (float)", has_folder=True)
     emit("aten::sub.float : (float, float) -> (float)", has_folder=True)
     emit("aten::mul.float : (float, float) -> (float)", has_folder=True)
     emit("aten::div.float : (float, float) -> (float)", has_folder=True)
diff --git a/test/Conversion/TorchToArith/basic.mlir b/test/Conversion/TorchToArith/basic.mlir
index 88d08d695f8c..8fa13b47e588 100644
--- a/test/Conversion/TorchToArith/basic.mlir
+++ b/test/Conversion/TorchToArith/basic.mlir
@@ -250,6 +250,20 @@ func.func @torch.aten.mul.int_float(%arg0: !torch.int, %arg1: !torch.float) -> !torch.float {
   return %0 : !torch.float
 }
 
+// CHECK-LABEL:   func.func @torch.aten.mul.float_int(
+// CHECK-SAME:      %[[LHS:.*]]: !torch.float,
+// CHECK-SAME:      %[[RHS:.*]]: !torch.int) -> !torch.float {
+// CHECK-DAG:       %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
+// CHECK-DAG:       %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
+// CHECK:           %[[RHS_F64:.*]] = arith.sitofp %[[RHS_I64]] : i64 to f64
+// CHECK:           %[[MUL:.*]] = arith.mulf %[[LHS_F64]], %[[RHS_F64]] : f64
+// CHECK:           %[[OUT:.*]] = torch_c.from_f64 %[[MUL]]
+// CHECK:           return %[[OUT]] : !torch.float
+func.func @torch.aten.mul.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.float {
+  %0 = torch.aten.mul.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.float
+  return %0 : !torch.float
+}
+
 // CHECK-LABEL:   func.func @torch.aten.div.float(
 // CHECK-SAME:      %[[LHS:.*]]: !torch.float,
 // CHECK-SAME:      %[[RHS:.*]]: !torch.float) -> !torch.float {
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index 12778f4017e8..d4afd67d65db 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -1255,6 +1255,16 @@ func.func @torch.aten.mul.float() -> !torch.float {
   return %ret : !torch.float
 }
 
+// CHECK-LABEL:   func.func @torch.aten.mul.float_int() -> !torch.float {
+// CHECK:           %[[CST6:.*]] = torch.constant.float 6.000000e+00
+// CHECK:           return %[[CST6]] : !torch.float
+func.func @torch.aten.mul.float_int() -> !torch.float {
+  %cst2 = torch.constant.float 2.0
+  %cst3 = torch.constant.int 3
+  %ret = torch.aten.mul.float_int %cst2, %cst3: !torch.float, !torch.int -> !torch.float
+  return %ret : !torch.float
+}
+
 // CHECK-LABEL:   func.func @torch.aten.neg.float() -> !torch.float {
 // CHECK:           %[[CST_6:.*]] = torch.constant.float -6.000000e+00
 // CHECK:           return %[[CST_6]] : !torch.float