From 73da7b8ce5f4adc2c61e060b541715d4f25fe1aa Mon Sep 17 00:00:00 2001 From: Sirui Mu Date: Tue, 2 Jul 2024 03:23:23 +0800 Subject: [PATCH] [CIR][NFC] Fix bug during fp16 unary op CIRGen (#706) This PR fixes a bug during the CIRGen of fp16 unary operations. Before this patch, for the expression `-x` where `x` is a fp16 value, CIRGen emits the code like the following: ```mlir %0 = cir.cast float_to_float %x : !cir.f16 -> !cir.float %1 = cir.cast float_to_float %0 : !cir.float -> !cir.f16 %2 = cir.unary minus %1 : !cir.fp16 ``` The expected CIRGen should instead be: ```mlir %0 = cir.cast float_to_float %x : !cir.f16 -> !cir.float %1 = cir.unary minus %0 : !cir.float %2 = cir.cast float_to_float %1 : !cir.float -> !cir.f16 ``` This PR fixes this issue. --- clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp | 42 +++++++++++++--------- clang/test/CIR/CodeGen/bf16-ops.c | 8 ++--- clang/test/CIR/CodeGen/fp16-ops.c | 8 ++--- 3 files changed, 34 insertions(+), 24 deletions(-) diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index 019705c247d0..c7289d62aa14 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -604,37 +604,47 @@ class ScalarExprEmitter : public StmtVisitor { : PromotionType; auto result = VisitPlus(E, promotionTy); if (result && !promotionTy.isNull()) - result = buildUnPromotedValue(result, E->getType()); - return buildUnaryOp(E, mlir::cir::UnaryOpKind::Plus, result); + return buildUnPromotedValue(result, E->getType()); + return result; } - mlir::Value VisitPlus(const UnaryOperator *E, QualType PromotionType) { + mlir::Value VisitPlus(const UnaryOperator *E, + QualType PromotionType = QualType()) { // This differs from gcc, though, most likely due to a bug in gcc. TestAndClearIgnoreResultAssign(); + + mlir::Value operand; if (!PromotionType.isNull()) - return CGF.buildPromotedScalarExpr(E->getSubExpr(), PromotionType); - return Visit(E->getSubExpr()); + operand = CGF.buildPromotedScalarExpr(E->getSubExpr(), PromotionType); + else + operand = Visit(E->getSubExpr()); + + return buildUnaryOp(E, mlir::cir::UnaryOpKind::Plus, operand); } - mlir::Value VisitUnaryMinus(const UnaryOperator *E) { - // NOTE(cir): QualType function parameter still not used, so donĀ“t replicate - // it here yet. - QualType promotionTy = getPromotionType(E->getSubExpr()->getType()); + mlir::Value VisitUnaryMinus(const UnaryOperator *E, + QualType PromotionType = QualType()) { + QualType promotionTy = PromotionType.isNull() + ? getPromotionType(E->getSubExpr()->getType()) + : PromotionType; auto result = VisitMinus(E, promotionTy); if (result && !promotionTy.isNull()) - result = buildUnPromotedValue(result, E->getType()); - return buildUnaryOp(E, mlir::cir::UnaryOpKind::Minus, result); + return buildUnPromotedValue(result, E->getType()); + return result; } mlir::Value VisitMinus(const UnaryOperator *E, QualType PromotionType) { TestAndClearIgnoreResultAssign(); + + mlir::Value operand; if (!PromotionType.isNull()) - return CGF.buildPromotedScalarExpr(E->getSubExpr(), PromotionType); + operand = CGF.buildPromotedScalarExpr(E->getSubExpr(), PromotionType); + else + operand = Visit(E->getSubExpr()); // NOTE: LLVM codegen will lower this directly to either a FNeg // or a Sub instruction. In CIR this will be handled later in LowerToLLVM. - - return Visit(E->getSubExpr()); + return buildUnaryOp(E, mlir::cir::UnaryOpKind::Minus, operand); } mlir::Value VisitUnaryNot(const UnaryOperator *E) { @@ -660,8 +670,8 @@ class ScalarExprEmitter : public StmtVisitor { mlir::Value buildUnaryOp(const UnaryOperator *E, mlir::cir::UnaryOpKind kind, mlir::Value input) { return Builder.create( - CGF.getLoc(E->getSourceRange().getBegin()), - CGF.getCIRType(E->getType()), kind, input); + CGF.getLoc(E->getSourceRange().getBegin()), input.getType(), kind, + input); } // C++ diff --git a/clang/test/CIR/CodeGen/bf16-ops.c b/clang/test/CIR/CodeGen/bf16-ops.c index 6a55e9acfe09..7812e03b129b 100644 --- a/clang/test/CIR/CodeGen/bf16-ops.c +++ b/clang/test/CIR/CodeGen/bf16-ops.c @@ -30,8 +30,8 @@ void foo(void) { h1 = -h1; // NONATIVE: %[[#A:]] = cir.cast(floating, %{{.+}} : !cir.bf16), !cir.float - // NONATIVE-NEXT: %[[#B:]] = cir.cast(floating, %[[#A]] : !cir.float), !cir.bf16 - // NONATIVE-NEXT: %{{.+}} = cir.unary(minus, %[[#B]]) : !cir.bf16, !cir.bf16 + // NONATIVE-NEXT: %[[#B:]] = cir.unary(minus, %[[#A]]) : !cir.float, !cir.float + // NONATIVE-NEXT: %{{.+}} = cir.cast(floating, %[[#B]] : !cir.float), !cir.bf16 // NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.bf16), !cir.float // NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.float), !cir.bf16 @@ -39,8 +39,8 @@ void foo(void) { h1 = +h1; // NONATIVE: %[[#A:]] = cir.cast(floating, %{{.+}} : !cir.bf16), !cir.float - // NONATIVE-NEXT: %[[#B:]] = cir.cast(floating, %[[#A]] : !cir.float), !cir.bf16 - // NONATIVE-NEXT: %{{.+}} = cir.unary(plus, %[[#B]]) : !cir.bf16, !cir.bf16 + // NONATIVE-NEXT: %[[#B:]] = cir.unary(plus, %[[#A]]) : !cir.float, !cir.float + // NONATIVE-NEXT: %{{.+}} = cir.cast(floating, %[[#B]] : !cir.float), !cir.bf16 // NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.bf16), !cir.float // NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.float), !cir.bf16 diff --git a/clang/test/CIR/CodeGen/fp16-ops.c b/clang/test/CIR/CodeGen/fp16-ops.c index e39b4fd4e9a9..46a410793a0e 100644 --- a/clang/test/CIR/CodeGen/fp16-ops.c +++ b/clang/test/CIR/CodeGen/fp16-ops.c @@ -30,8 +30,8 @@ void foo(void) { h1 = -h1; // NONATIVE: %[[#A:]] = cir.cast(floating, %{{.+}} : !cir.f16), !cir.float - // NONATIVE-NEXT: %[[#B:]] = cir.cast(floating, %[[#A]] : !cir.float), !cir.f16 - // NONATIVE-NEXT: %{{.+}} = cir.unary(minus, %[[#B]]) : !cir.f16, !cir.f16 + // NONATIVE-NEXT: %[[#B:]] = cir.unary(minus, %[[#A]]) : !cir.float, !cir.float + // NONATIVE-NEXT: %{{.+}} = cir.cast(floating, %[[#B]] : !cir.float), !cir.f16 // NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.f16), !cir.float // NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.float), !cir.f16 @@ -39,8 +39,8 @@ void foo(void) { h1 = +h1; // NONATIVE: %[[#A:]] = cir.cast(floating, %{{.+}} : !cir.f16), !cir.float - // NONATIVE-NEXT: %[[#B:]] = cir.cast(floating, %[[#A]] : !cir.float), !cir.f16 - // NONATIVE-NEXT: %{{.+}} = cir.unary(plus, %[[#B]]) : !cir.f16, !cir.f16 + // NONATIVE-NEXT: %[[#B:]] = cir.unary(plus, %[[#A]]) : !cir.float, !cir.float + // NONATIVE-NEXT: %{{.+}} = cir.cast(floating, %[[#B]] : !cir.float), !cir.f16 // NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.f16), !cir.float // NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.float), !cir.f16