diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 81601ce51f431..4c9dad9e2c173 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -26,57 +26,29 @@ namespace mlir { using namespace mlir; namespace { -// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780. struct AbsOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto loc = op.getLoc(); + auto type = op.getType(); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); - Type elementType = op.getType(); - Value arg = adaptor.getComplex(); - - Value zero = - b.create(elementType, b.getZeroAttr(elementType)); - Value one = b.create(elementType, - b.getFloatAttr(elementType, 1.0)); - - Value real = b.create(elementType, arg); - Value imag = b.create(elementType, arg); - - Value realIsZero = - b.create(arith::CmpFPredicate::OEQ, real, zero); - Value imagIsZero = - b.create(arith::CmpFPredicate::OEQ, imag, zero); - - // Real > Imag - Value imagDivReal = b.create(imag, real, fmf.getValue()); - Value imagSq = - b.create(imagDivReal, imagDivReal, fmf.getValue()); - Value imagSqPlusOne = b.create(imagSq, one, fmf.getValue()); - Value imagSqrt = b.create(imagSqPlusOne, fmf.getValue()); - Value absImag = b.create(imagSqrt, real, fmf.getValue()); - - // Real <= Imag - Value realDivImag = b.create(real, imag, fmf.getValue()); - Value realSq = - b.create(realDivImag, realDivImag, fmf.getValue()); - Value realSqPlusOne = b.create(realSq, one, fmf.getValue()); - Value realSqrt = b.create(realSqPlusOne, fmf.getValue()); - Value absReal = b.create(realSqrt, imag, fmf.getValue()); - - rewriter.replaceOpWithNewOp( - op, realIsZero, imag, - b.create( - imagIsZero, real, - b.create( - b.create(arith::CmpFPredicate::OGT, real, imag), - absImag, absReal))); - + Value real = + rewriter.create(loc, type, adaptor.getComplex()); + Value imag = + rewriter.create(loc, type, adaptor.getComplex()); + Value realSqr = + rewriter.create(loc, real, real, fmf.getValue()); + Value imagSqr = + rewriter.create(loc, imag, imag, fmf.getValue()); + Value sqNorm = + rewriter.create(loc, realSqr, imagSqr, fmf.getValue()); + + rewriter.replaceOpWithNewOp(op, sqNorm); return success(); } }; diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index d5f83e0af4184..8fa29ea43854a 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -7,28 +7,13 @@ func.func @complex_abs(%arg: complex) -> f32 { %abs = complex.abs %arg: complex return %abs : f32 } - -// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex -// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 -// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 -// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32 -// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32 -// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32 -// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32 -// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] : f32 -// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32 -// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32 -// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32 -// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32 -// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] : f32 -// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32 -// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 -// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32 -// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32 -// CHECK: return %[[ABS3]] : f32 +// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32 +// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32 +// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] : f32 +// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 +// CHECK: return %[[NORM]] : f32 // ----- @@ -256,26 +241,12 @@ func.func @complex_log(%arg: complex) -> complex { %log = complex.log %arg: complex return %log : complex } -// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex -// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 -// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 -// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32 -// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32 -// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32 -// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32 -// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] : f32 -// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32 -// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32 -// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32 -// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32 -// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] : f32 -// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32 -// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 -// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32 -// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32 +// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32 +// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32 +// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32 +// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 // CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex @@ -498,26 +469,12 @@ func.func @complex_sign(%arg: complex) -> complex { // CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 // CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 // CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1 -// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex -// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL2]], %[[ZERO]] : f32 -// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG2]], %[[ZERO]] : f32 -// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG2]], %[[REAL2]] : f32 -// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32 -// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32 -// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32 -// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL2]] : f32 -// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL2]], %[[IMAG2]] : f32 -// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32 -// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32 -// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32 -// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG2]] : f32 -// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL2]], %[[IMAG2]] : f32 -// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 -// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL2]], %[[ABS1]] : f32 -// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG2]], %[[ABS2]] : f32 +// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL2]], %[[REAL2]] : f32 +// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG2]], %[[IMAG2]] : f32 +// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32 +// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 // CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] : f32 // CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] : f32 // CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex @@ -759,27 +716,13 @@ func.func @complex_abs_with_fmf(%arg: complex) -> f32 { %abs = complex.abs %arg fastmath : complex return %abs : f32 } -// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex -// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 -// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 -// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath : f32 -// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath : f32 -// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath : f32 -// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath : f32 -// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] fastmath : f32 -// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath : f32 -// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath : f32 -// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath : f32 -// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath : f32 -// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] fastmath : f32 -// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32 -// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 -// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32 -// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32 -// CHECK: return %[[ABS3]] : f32 +// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath : f32 +// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath : f32 +// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] fastmath : f32 +// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 +// CHECK: return %[[NORM]] : f32 // ----- @@ -864,26 +807,12 @@ func.func @complex_log_with_fmf(%arg: complex) -> complex { %log = complex.log %arg fastmath : complex return %log : complex } -// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex -// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 -// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 -// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath : f32 -// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath : f32 -// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath : f32 -// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath : f32 -// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] fastmath : f32 -// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath : f32 -// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath : f32 -// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath : f32 -// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath : f32 -// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] fastmath : f32 -// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32 -// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 -// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32 -// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32 +// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath : f32 +// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath : f32 +// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] fastmath : f32 +// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 // CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] fastmath : f32 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex diff --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir index d710dc8e1adeb..9983dd46f0943 100644 --- a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir +++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir @@ -6,29 +6,12 @@ func.func @complex_abs(%arg: complex) -> f32 { %abs = complex.abs %arg: complex return %abs : f32 } -// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 -// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 // CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]] // CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]] -// CHECK: %[[REAL_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[REAL]], %[[ZERO]] : f32 -// CHECK: %[[IMAG_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[IMAG]], %[[ZERO]] : f32 - -// CHECK: %[[IMAG_DIV_REAL:.*]] = llvm.fdiv %[[IMAG]], %[[REAL]] : f32 -// CHECK: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32 -// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = llvm.fadd %[[IMAG_SQ]], %[[ONE]] : f32 -// CHECK: %[[IMAG_SQRT:.*]] = llvm.intr.sqrt(%[[IMAG_SQ_PLUS_ONE]]) : (f32) -> f32 -// CHECK: %[[ABS_IMAG:.*]] = llvm.fmul %[[IMAG_SQRT]], %[[REAL]] : f32 - -// CHECK: %[[REAL_DIV_IMAG:.*]] = llvm.fdiv %[[REAL]], %[[IMAG]] : f32 -// CHECK: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32 -// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = llvm.fadd %[[REAL_SQ]], %[[ONE]] : f32 -// CHECK: %[[REAL_SQRT:.*]] = llvm.intr.sqrt(%[[REAL_SQ_PLUS_ONE]]) : (f32) -> f32 -// CHECK: %[[ABS_REAL:.*]] = llvm.fmul %[[REAL_SQRT]], %[[IMAG]] : f32 - -// CHECK: %[[REAL_GT_IMAG:.*]] = llvm.fcmp "ogt" %[[REAL]], %[[IMAG]] : f32 -// CHECK: %[[ABS1:.*]] = llvm.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : i1, f32 -// CHECK: %[[ABS2:.*]] = llvm.select %[[IMAG_IS_ZERO]], %[[REAL]], %[[ABS1]] : i1, f32 -// CHECK: %[[NORM:.*]] = llvm.select %[[REAL_IS_ZERO]], %[[IMAG]], %[[ABS2]] : i1, f32 +// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]] : f32 +// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]] : f32 +// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]] : f32 +// CHECK: %[[NORM:.*]] = llvm.intr.sqrt(%[[SQ_NORM]]) : (f32) -> f32 // CHECK: llvm.return %[[NORM]] : f32 // CHECK-LABEL: llvm.func @complex_eq diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir index c8327e94def8a..349b92a7aefa2 100644 --- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir +++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir @@ -106,27 +106,6 @@ func.func @angle(%arg: complex) -> f32 { func.return %angle : f32 } -func.func @test_element_f64(%input: tensor>, - %func: (complex) -> f64) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %size = tensor.dim %input, %c0: tensor> - - scf.for %i = %c0 to %size step %c1 { - %elem = tensor.extract %input[%i]: tensor> - - %val = func.call_indirect %func(%elem) : (complex) -> f64 - vector.print %val : f64 - scf.yield - } - func.return -} - -func.func @abs(%arg: complex) -> f64 { - %abs = complex.abs %arg : complex - func.return %abs : f64 -} - func.func @entry() { // complex.sqrt test %sqrt_test = arith.constant dense<[ @@ -321,28 +300,5 @@ func.func @entry() { call @test_element(%angle_test_cast, %angle_func) : (tensor>, (complex) -> f32) -> () - // complex.abs test - %abs_test = arith.constant dense<[ - (1.0, 1.0), - // CHECK: 1.414 - (1.0e300, 1.0e300), - // CHECK-NEXT: 1.41421e+300 - (1.0e-300, 1.0e-300), - // CHECK-NEXT: 1.41421e-300 - (5.0, 0.0), - // CHECK-NEXT: 5 - (0.0, 6.0), - // CHECK-NEXT: 6 - (7.0, 8.0) - // CHECK-NEXT: 10.6301 - ]> : tensor<6xcomplex> - %abs_test_cast = tensor.cast %abs_test - : tensor<6xcomplex> to tensor> - - %abs_func = func.constant @abs : (complex) -> f64 - - call @test_element_f64(%abs_test_cast, %abs_func) - : (tensor>, (complex) -> f64) -> () - func.return }