Skip to content

Commit

Permalink
Support __builtin_elementwise_abs and extend AbsOp to take vector input
Browse files Browse the repository at this point in the history
  • Loading branch information
ghehg committed Nov 11, 2024
1 parent e6b2808 commit ad787bd
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 6 deletions.
5 changes: 3 additions & 2 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4331,8 +4331,8 @@ def SqrtOp : UnaryFPToFPBuiltinOp<"sqrt", "SqrtOp">;
def TruncOp : UnaryFPToFPBuiltinOp<"trunc", "FTruncOp">;

def AbsOp : CIR_Op<"abs", [Pure, SameOperandsAndResultType]> {
let arguments = (ins PrimitiveSInt:$src, UnitAttr:$poison);
let results = (outs PrimitiveSInt:$result);
let arguments = (ins CIR_AnySignedIntOrVecOfSignedInt:$src, UnitAttr:$poison);
let results = (outs CIR_AnySignedIntOrVecOfSignedInt:$result);
let summary = [{
libc builtin equivalent abs, labs, llabs

Expand All @@ -4345,6 +4345,7 @@ def AbsOp : CIR_Op<"abs", [Pure, SameOperandsAndResultType]> {
```mlir
%0 = cir.const #cir.int<-42> : s32i
%1 = cir.abs %0 poison : s32i
%2 = cir.abs %3 : !cir.vector<!s32i x 4>
```
}];
let assemblyFormat = "$src ( `poison` $poison^ )? `:` type($src) attr-dict";
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ class StructType

bool isAnyFloatingPointType(mlir::Type t);
bool isFPOrFPVectorTy(mlir::Type);
bool isCIRIntOrIntVectorTy(mlir::Type);
} // namespace cir

mlir::ParseResult parseAddrSpaceAttribute(mlir::AsmParser &p,
Expand Down
17 changes: 17 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def CIR_IntType : CIR_Type<"Int", "int",
bool isPrimitive() const {
return isValidPrimitiveIntBitwidth(getWidth());
}
bool isSignedPrimitive() const {
return isPrimitive() && isSigned();
}

/// Returns a minimum bitwidth of cir::IntType
static unsigned minBitwidth() { return 1; }
Expand Down Expand Up @@ -538,8 +541,22 @@ def IntegerVector : Type<
]>, "!cir.vector of !cir.int"> {
}

// Vector of signed integral type
def SignedIntegerVector : Type<
And<[
CPred<"::mlir::isa<::cir::VectorType>($_self)">,
CPred<"::mlir::isa<::cir::IntType>("
"::mlir::cast<::cir::VectorType>($_self).getEltType())">,
CPred<"::mlir::cast<::cir::IntType>("
"::mlir::cast<::cir::VectorType>($_self).getEltType())"
".isSignedPrimitive()">
]>, "!cir.vector of !cir.int"> {
}

// Constraints
def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_IntType, IntegerVector]>;
def CIR_AnySignedIntOrVecOfSignedInt: AnyTypeOf<
[PrimitiveSInt, SignedIntegerVector]>;

// Pointer to Arrays
def ArrayPtr : Type<
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/CIR/MissingFeatures.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,9 @@ struct MissingFeatures {

//-- Other missing features

// We need to extend fpUnaryOPs to support vector types.
static bool fpUnaryOPsSupportVectorType() { return false; }

// We need to track the parent record types that represent a field
// declaration. This is necessary to determine the layout of a class.
static bool fieldDeclAbstraction() { return false; }
Expand Down
19 changes: 16 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1255,9 +1255,22 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_nondeterministic_value:
llvm_unreachable("BI__builtin_nondeterministic_value NYI");

case Builtin::BI__builtin_elementwise_abs:
llvm_unreachable("BI__builtin_elementwise_abs NYI");

case Builtin::BI__builtin_elementwise_abs: {
mlir::Type cirTy = ConvertType(E->getArg(0)->getType());
bool isIntTy = cir::isCIRIntOrIntVectorTy(cirTy);
if (!isIntTy) {
if (cir::isAnyFloatingPointType(cirTy)) {
return emitUnaryFPBuiltin<cir::FAbsOp>(*this, *E);
}
assert(!MissingFeatures::fpUnaryOPsSupportVectorType());
llvm_unreachable("unsupported type for BI__builtin_elementwise_abs");
}
mlir::Value arg = emitScalarExpr(E->getArg(0));
auto call = getBuilder().create<cir::AbsOp>(getLoc(E->getExprLoc()),
arg.getType(), arg, false);
mlir::Value result = call->getResult(0);
return RValue::get(result);
}
case Builtin::BI__builtin_elementwise_acos:
llvm_unreachable("BI__builtin_elementwise_acos NYI");
case Builtin::BI__builtin_elementwise_asin:
Expand Down
14 changes: 13 additions & 1 deletion clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ bool cir::isAnyFloatingPointType(mlir::Type t) {
}

//===----------------------------------------------------------------------===//
// Floating-point and Float-point Vecotr type helpers
// Floating-point and Float-point Vector type helpers
//===----------------------------------------------------------------------===//

bool cir::isFPOrFPVectorTy(mlir::Type t) {
Expand All @@ -841,6 +841,18 @@ bool cir::isFPOrFPVectorTy(mlir::Type t) {
return isAnyFloatingPointType(t);
}

//===----------------------------------------------------------------------===//
// CIR Integer and Integer Vector type helpers
//===----------------------------------------------------------------------===//

bool cir::isCIRIntOrIntVectorTy(mlir::Type t) {

if (isa<cir::VectorType>(t)) {
return isa<cir::IntType>(mlir::dyn_cast<cir::VectorType>(t).getEltType());
}
return isa<cir::IntType>(t);
}

//===----------------------------------------------------------------------===//
// ComplexType Definitions
//===----------------------------------------------------------------------===//
Expand Down
27 changes: 27 additions & 0 deletions clang/test/CIR/CodeGen/builtins-elementwise.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// RUN: %clang_cc1 -triple aarch64-none-linux-android24 -emit-cir %s -o %t.cir
// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
// RUN: %clang_cc1 -triple aarch64-none-linux-android24 -fclangir \
// RUN: -emit-llvm %s -o %t.ll
// RUN: FileCheck --check-prefix=LLVM --input-file=%t.ll %s

typedef int vint4 __attribute__((ext_vector_type(4)));

void test_builtin_elementwise_abs(vint4 vi4, int i, float f, double d) {
// CIR-LABEL: test_builtin_elementwise_abs
// LLVM-LABEL: test_builtin_elementwise_abs
// CIR: {{%.*}} = cir.fabs {{%.*}} : !cir.float
// LLVM: {{%.*}} = call float @llvm.fabs.f32(float {{%.*}})
f = __builtin_elementwise_abs(f);

// CIR: {{%.*}} = cir.fabs {{%.*}} : !cir.double
// LLVM: {{%.*}} = call double @llvm.fabs.f64(double {{%.*}})
d = __builtin_elementwise_abs(d);

// CIR: {{%.*}} = cir.abs {{%.*}} : !cir.vector<!s32i x 4>
// LLVM: {{%.*}} = call <4 x i32> @llvm.abs.v4i32(<4 x i32> {{%.*}}, i1 false)
vi4 = __builtin_elementwise_abs(vi4);

// CIR: {{%.*}} = cir.abs {{%.*}} : !s32
// LLVM: {{%.*}} = call i32 @llvm.abs.i32(i32 {{%.*}}, i1 false)
i = __builtin_elementwise_abs(i);
}

0 comments on commit ad787bd

Please sign in to comment.