Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][CIRGen][Builtin] Support __builtin_elementwise_abs and extend AbsOp to take vector input #1099

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4331,8 +4331,8 @@ def SqrtOp : UnaryFPToFPBuiltinOp<"sqrt", "SqrtOp">;
def TruncOp : UnaryFPToFPBuiltinOp<"trunc", "FTruncOp">;

def AbsOp : CIR_Op<"abs", [Pure, SameOperandsAndResultType]> {
let arguments = (ins PrimitiveSInt:$src, UnitAttr:$poison);
let results = (outs PrimitiveSInt:$result);
let arguments = (ins CIR_AnySignedIntOrVecOfSignedInt:$src, UnitAttr:$poison);
let results = (outs CIR_AnySignedIntOrVecOfSignedInt:$result);
let summary = [{
libc builtin equivalent abs, labs, llabs

Expand All @@ -4345,6 +4345,7 @@ def AbsOp : CIR_Op<"abs", [Pure, SameOperandsAndResultType]> {
```mlir
%0 = cir.const #cir.int<-42> : s32i
%1 = cir.abs %0 poison : s32i
%2 = cir.abs %3 : !cir.vector<!s32i x 4>
```
}];
let assemblyFormat = "$src ( `poison` $poison^ )? `:` type($src) attr-dict";
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ class StructType

bool isAnyFloatingPointType(mlir::Type t);
bool isFPOrFPVectorTy(mlir::Type);
bool isIntOrIntVectorTy(mlir::Type);
} // namespace cir

mlir::ParseResult parseAddrSpaceAttribute(mlir::AsmParser &p,
Expand Down
17 changes: 17 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def CIR_IntType : CIR_Type<"Int", "int",
bool isPrimitive() const {
return isValidPrimitiveIntBitwidth(getWidth());
}
bool isSignedPrimitive() const {
return isPrimitive() && isSigned();
}

/// Returns a minimum bitwidth of cir::IntType
static unsigned minBitwidth() { return 1; }
Expand Down Expand Up @@ -538,8 +541,22 @@ def IntegerVector : Type<
]>, "!cir.vector of !cir.int"> {
}

// Vector of signed integral type
def SignedIntegerVector : Type<
And<[
CPred<"::mlir::isa<::cir::VectorType>($_self)">,
CPred<"::mlir::isa<::cir::IntType>("
"::mlir::cast<::cir::VectorType>($_self).getEltType())">,
CPred<"::mlir::cast<::cir::IntType>("
"::mlir::cast<::cir::VectorType>($_self).getEltType())"
".isSignedPrimitive()">
]>, "!cir.vector of !cir.int"> {
}

// Constraints
def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_IntType, IntegerVector]>;
def CIR_AnySignedIntOrVecOfSignedInt: AnyTypeOf<
[PrimitiveSInt, SignedIntegerVector]>;

// Pointer to Arrays
def ArrayPtr : Type<
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/CIR/MissingFeatures.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,9 @@ struct MissingFeatures {

//-- Other missing features

// We need to extend fpUnaryOPs to support vector types.
static bool fpUnaryOPsSupportVectorType() { return false; }

// We need to track the parent record types that represent a field
// declaration. This is necessary to determine the layout of a class.
static bool fieldDeclAbstraction() { return false; }
Expand Down
19 changes: 16 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1255,9 +1255,22 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_nondeterministic_value:
llvm_unreachable("BI__builtin_nondeterministic_value NYI");

case Builtin::BI__builtin_elementwise_abs:
llvm_unreachable("BI__builtin_elementwise_abs NYI");

case Builtin::BI__builtin_elementwise_abs: {
mlir::Type cirTy = ConvertType(E->getArg(0)->getType());
bool isIntTy = cir::isIntOrIntVectorTy(cirTy);
if (!isIntTy) {
if (cir::isAnyFloatingPointType(cirTy)) {
return emitUnaryFPBuiltin<cir::FAbsOp>(*this, *E);
}
assert(!MissingFeatures::fpUnaryOPsSupportVectorType());
llvm_unreachable("unsupported type for BI__builtin_elementwise_abs");
}
mlir::Value arg = emitScalarExpr(E->getArg(0));
auto call = getBuilder().create<cir::AbsOp>(getLoc(E->getExprLoc()),
arg.getType(), arg, false);
mlir::Value result = call->getResult(0);
return RValue::get(result);
}
case Builtin::BI__builtin_elementwise_acos:
llvm_unreachable("BI__builtin_elementwise_acos NYI");
case Builtin::BI__builtin_elementwise_asin:
Expand Down
14 changes: 13 additions & 1 deletion clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ bool cir::isAnyFloatingPointType(mlir::Type t) {
}

//===----------------------------------------------------------------------===//
// Floating-point and Float-point Vecotr type helpers
// Floating-point and Float-point Vector type helpers
//===----------------------------------------------------------------------===//

bool cir::isFPOrFPVectorTy(mlir::Type t) {
Expand All @@ -841,6 +841,18 @@ bool cir::isFPOrFPVectorTy(mlir::Type t) {
return isAnyFloatingPointType(t);
}

//===----------------------------------------------------------------------===//
// CIR Integer and Integer Vector type helpers
//===----------------------------------------------------------------------===//

bool cir::isIntOrIntVectorTy(mlir::Type t) {

if (isa<cir::VectorType>(t)) {
return isa<cir::IntType>(mlir::dyn_cast<cir::VectorType>(t).getEltType());
}
return isa<cir::IntType>(t);
}

//===----------------------------------------------------------------------===//
// ComplexType Definitions
//===----------------------------------------------------------------------===//
Expand Down
27 changes: 27 additions & 0 deletions clang/test/CIR/CodeGen/builtins-elementwise.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// RUN: %clang_cc1 -triple aarch64-none-linux-android24 -emit-cir %s -o %t.cir
// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
// RUN: %clang_cc1 -triple aarch64-none-linux-android24 -fclangir \
// RUN: -emit-llvm %s -o %t.ll
// RUN: FileCheck --check-prefix=LLVM --input-file=%t.ll %s

typedef int vint4 __attribute__((ext_vector_type(4)));

void test_builtin_elementwise_abs(vint4 vi4, int i, float f, double d) {
// CIR-LABEL: test_builtin_elementwise_abs
// LLVM-LABEL: test_builtin_elementwise_abs
// CIR: {{%.*}} = cir.fabs {{%.*}} : !cir.float
// LLVM: {{%.*}} = call float @llvm.fabs.f32(float {{%.*}})
f = __builtin_elementwise_abs(f);

// CIR: {{%.*}} = cir.fabs {{%.*}} : !cir.double
// LLVM: {{%.*}} = call double @llvm.fabs.f64(double {{%.*}})
d = __builtin_elementwise_abs(d);

// CIR: {{%.*}} = cir.abs {{%.*}} : !cir.vector<!s32i x 4>
// LLVM: {{%.*}} = call <4 x i32> @llvm.abs.v4i32(<4 x i32> {{%.*}}, i1 false)
vi4 = __builtin_elementwise_abs(vi4);

// CIR: {{%.*}} = cir.abs {{%.*}} : !s32
// LLVM: {{%.*}} = call i32 @llvm.abs.i32(i32 {{%.*}}, i1 false)
i = __builtin_elementwise_abs(i);
}