Skip to content

Commit

Permalink
[CIR][NFC] Fix bug during fp16 unary op CIRGen (llvm#706)
Browse files Browse the repository at this point in the history
This PR fixes a bug during the CIRGen of fp16 unary operations. Before
this patch, for the expression `-x` where `x` is a fp16 value, CIRGen
emits the code like the following:

```mlir
%0 = cir.cast float_to_float %x : !cir.f16 -> !cir.float
%1 = cir.cast float_to_float %0 : !cir.float -> !cir.f16
%2 = cir.unary minus %1 : !cir.fp16
```

The expected CIRGen should instead be:

```mlir
%0 = cir.cast float_to_float %x : !cir.f16 -> !cir.float
%1 = cir.unary minus %0 : !cir.float
%2 = cir.cast float_to_float %1 : !cir.float -> !cir.f16
```

This PR fixes this issue.
  • Loading branch information
Lancern authored and smeenai committed Oct 9, 2024
1 parent 8286a86 commit 73da7b8
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 24 deletions.
42 changes: 26 additions & 16 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -604,37 +604,47 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
: PromotionType;
auto result = VisitPlus(E, promotionTy);
if (result && !promotionTy.isNull())
result = buildUnPromotedValue(result, E->getType());
return buildUnaryOp(E, mlir::cir::UnaryOpKind::Plus, result);
return buildUnPromotedValue(result, E->getType());
return result;
}

mlir::Value VisitPlus(const UnaryOperator *E, QualType PromotionType) {
mlir::Value VisitPlus(const UnaryOperator *E,
QualType PromotionType = QualType()) {
// This differs from gcc, though, most likely due to a bug in gcc.
TestAndClearIgnoreResultAssign();

mlir::Value operand;
if (!PromotionType.isNull())
return CGF.buildPromotedScalarExpr(E->getSubExpr(), PromotionType);
return Visit(E->getSubExpr());
operand = CGF.buildPromotedScalarExpr(E->getSubExpr(), PromotionType);
else
operand = Visit(E->getSubExpr());

return buildUnaryOp(E, mlir::cir::UnaryOpKind::Plus, operand);
}

mlir::Value VisitUnaryMinus(const UnaryOperator *E) {
// NOTE(cir): QualType function parameter still not used, so don´t replicate
// it here yet.
QualType promotionTy = getPromotionType(E->getSubExpr()->getType());
mlir::Value VisitUnaryMinus(const UnaryOperator *E,
QualType PromotionType = QualType()) {
QualType promotionTy = PromotionType.isNull()
? getPromotionType(E->getSubExpr()->getType())
: PromotionType;
auto result = VisitMinus(E, promotionTy);
if (result && !promotionTy.isNull())
result = buildUnPromotedValue(result, E->getType());
return buildUnaryOp(E, mlir::cir::UnaryOpKind::Minus, result);
return buildUnPromotedValue(result, E->getType());
return result;
}

mlir::Value VisitMinus(const UnaryOperator *E, QualType PromotionType) {
TestAndClearIgnoreResultAssign();

mlir::Value operand;
if (!PromotionType.isNull())
return CGF.buildPromotedScalarExpr(E->getSubExpr(), PromotionType);
operand = CGF.buildPromotedScalarExpr(E->getSubExpr(), PromotionType);
else
operand = Visit(E->getSubExpr());

// NOTE: LLVM codegen will lower this directly to either a FNeg
// or a Sub instruction. In CIR this will be handled later in LowerToLLVM.

return Visit(E->getSubExpr());
return buildUnaryOp(E, mlir::cir::UnaryOpKind::Minus, operand);
}

mlir::Value VisitUnaryNot(const UnaryOperator *E) {
Expand All @@ -660,8 +670,8 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
mlir::Value buildUnaryOp(const UnaryOperator *E, mlir::cir::UnaryOpKind kind,
mlir::Value input) {
return Builder.create<mlir::cir::UnaryOp>(
CGF.getLoc(E->getSourceRange().getBegin()),
CGF.getCIRType(E->getType()), kind, input);
CGF.getLoc(E->getSourceRange().getBegin()), input.getType(), kind,
input);
}

// C++
Expand Down
8 changes: 4 additions & 4 deletions clang/test/CIR/CodeGen/bf16-ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,17 @@ void foo(void) {

h1 = -h1;
// NONATIVE: %[[#A:]] = cir.cast(floating, %{{.+}} : !cir.bf16), !cir.float
// NONATIVE-NEXT: %[[#B:]] = cir.cast(floating, %[[#A]] : !cir.float), !cir.bf16
// NONATIVE-NEXT: %{{.+}} = cir.unary(minus, %[[#B]]) : !cir.bf16, !cir.bf16
// NONATIVE-NEXT: %[[#B:]] = cir.unary(minus, %[[#A]]) : !cir.float, !cir.float
// NONATIVE-NEXT: %{{.+}} = cir.cast(floating, %[[#B]] : !cir.float), !cir.bf16

// NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.bf16), !cir.float
// NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.float), !cir.bf16
// NATIVE: %{{.+}} = cir.unary(minus, %{{.+}}) : !cir.bf16, !cir.bf16

h1 = +h1;
// NONATIVE: %[[#A:]] = cir.cast(floating, %{{.+}} : !cir.bf16), !cir.float
// NONATIVE-NEXT: %[[#B:]] = cir.cast(floating, %[[#A]] : !cir.float), !cir.bf16
// NONATIVE-NEXT: %{{.+}} = cir.unary(plus, %[[#B]]) : !cir.bf16, !cir.bf16
// NONATIVE-NEXT: %[[#B:]] = cir.unary(plus, %[[#A]]) : !cir.float, !cir.float
// NONATIVE-NEXT: %{{.+}} = cir.cast(floating, %[[#B]] : !cir.float), !cir.bf16

// NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.bf16), !cir.float
// NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.float), !cir.bf16
Expand Down
8 changes: 4 additions & 4 deletions clang/test/CIR/CodeGen/fp16-ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,17 @@ void foo(void) {

h1 = -h1;
// NONATIVE: %[[#A:]] = cir.cast(floating, %{{.+}} : !cir.f16), !cir.float
// NONATIVE-NEXT: %[[#B:]] = cir.cast(floating, %[[#A]] : !cir.float), !cir.f16
// NONATIVE-NEXT: %{{.+}} = cir.unary(minus, %[[#B]]) : !cir.f16, !cir.f16
// NONATIVE-NEXT: %[[#B:]] = cir.unary(minus, %[[#A]]) : !cir.float, !cir.float
// NONATIVE-NEXT: %{{.+}} = cir.cast(floating, %[[#B]] : !cir.float), !cir.f16

// NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.f16), !cir.float
// NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.float), !cir.f16
// NATIVE: %{{.+}} = cir.unary(minus, %{{.+}}) : !cir.f16, !cir.f16

h1 = +h1;
// NONATIVE: %[[#A:]] = cir.cast(floating, %{{.+}} : !cir.f16), !cir.float
// NONATIVE-NEXT: %[[#B:]] = cir.cast(floating, %[[#A]] : !cir.float), !cir.f16
// NONATIVE-NEXT: %{{.+}} = cir.unary(plus, %[[#B]]) : !cir.f16, !cir.f16
// NONATIVE-NEXT: %[[#B:]] = cir.unary(plus, %[[#A]]) : !cir.float, !cir.float
// NONATIVE-NEXT: %{{.+}} = cir.cast(floating, %[[#B]] : !cir.float), !cir.f16

// NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.f16), !cir.float
// NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.float), !cir.f16
Expand Down

0 comments on commit 73da7b8

Please sign in to comment.