Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR] Add initial support for bit-precise integer types #538

Merged
merged 4 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def ObjSizeOp : CIR_Op<"objsize", [Pure]> {

let arguments = (ins CIR_PointerType:$ptr, SizeInfoType:$kind,
UnitAttr:$dynamic);
let results = (outs CIR_IntType:$result);
let results = (outs PrimitiveInt:$result);

let assemblyFormat = [{
`(`
Expand Down Expand Up @@ -180,7 +180,7 @@ def PtrDiffOp : CIR_Op<"ptr_diff", [Pure, SameTypeOperands]> {
```
}];

let results = (outs CIR_IntType:$result);
let results = (outs PrimitiveInt:$result);
let arguments = (ins CIR_PointerType:$lhs, CIR_PointerType:$rhs);

let assemblyFormat = [{
Expand Down Expand Up @@ -208,7 +208,7 @@ def PtrStrideOp : CIR_Op<"ptr_stride",
```
}];

let arguments = (ins CIR_PointerType:$base, CIR_IntType:$stride);
let arguments = (ins CIR_PointerType:$base, PrimitiveInt:$stride);
let results = (outs CIR_PointerType:$result);

let assemblyFormat = [{
Expand Down Expand Up @@ -337,7 +337,7 @@ def AllocaOp : CIR_Op<"alloca", [
}];

let arguments = (ins
Optional<CIR_IntType>:$dynAllocSize,
Optional<PrimitiveInt>:$dynAllocSize,
TypeAttr:$allocaType,
StrAttr:$name,
UnitAttr:$init,
Expand Down Expand Up @@ -1034,7 +1034,7 @@ class CIR_BitOp<string mnemonic, TypeConstraint inputTy>
}];
}

def BitClrsbOp : CIR_BitOp<"bit.clrsb", SIntOfWidths<[32, 64]>> {
def BitClrsbOp : CIR_BitOp<"bit.clrsb", AnyTypeOf<[SInt32, SInt64]>> {
let summary = "Get the number of leading redundant sign bits in the input";
let description = [{
Compute the number of leading redundant sign bits in the input integer.
Expand Down Expand Up @@ -1065,7 +1065,7 @@ def BitClrsbOp : CIR_BitOp<"bit.clrsb", SIntOfWidths<[32, 64]>> {
}];
}

def BitClzOp : CIR_BitOp<"bit.clz", UIntOfWidths<[16, 32, 64]>> {
def BitClzOp : CIR_BitOp<"bit.clz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
let summary = "Get the number of leading 0-bits in the input";
let description = [{
Compute the number of leading 0-bits in the input.
Expand All @@ -1090,7 +1090,7 @@ def BitClzOp : CIR_BitOp<"bit.clz", UIntOfWidths<[16, 32, 64]>> {
}];
}

def BitCtzOp : CIR_BitOp<"bit.ctz", UIntOfWidths<[16, 32, 64]>> {
def BitCtzOp : CIR_BitOp<"bit.ctz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
let summary = "Get the number of trailing 0-bits in the input";
let description = [{
Compute the number of trailing 0-bits in the input.
Expand All @@ -1115,7 +1115,7 @@ def BitCtzOp : CIR_BitOp<"bit.ctz", UIntOfWidths<[16, 32, 64]>> {
}];
}

def BitFfsOp : CIR_BitOp<"bit.ffs", SIntOfWidths<[32, 64]>> {
def BitFfsOp : CIR_BitOp<"bit.ffs", AnyTypeOf<[SInt32, SInt64]>> {
let summary = "Get the position of the least significant 1-bit of input";
let description = [{
Compute the position of the least significant 1-bit of the input.
Expand All @@ -1138,7 +1138,7 @@ def BitFfsOp : CIR_BitOp<"bit.ffs", SIntOfWidths<[32, 64]>> {
}];
}

def BitParityOp : CIR_BitOp<"bit.parity", UIntOfWidths<[32, 64]>> {
def BitParityOp : CIR_BitOp<"bit.parity", AnyTypeOf<[UInt32, UInt64]>> {
let summary = "Get the parity of input";
let description = [{
Compute the parity of the input. The parity of an integer is the number of
Expand All @@ -1160,7 +1160,8 @@ def BitParityOp : CIR_BitOp<"bit.parity", UIntOfWidths<[32, 64]>> {
}];
}

def BitPopcountOp : CIR_BitOp<"bit.popcount", UIntOfWidths<[16, 32, 64]>> {
def BitPopcountOp
: CIR_BitOp<"bit.popcount", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
let summary = "Get the number of 1-bits in input";
let description = [{
Compute the number of 1-bits in the input.
Expand Down Expand Up @@ -1208,7 +1209,7 @@ def ByteswapOp : CIR_Op<"bswap", [Pure, SameOperandsAndResultType]> {
}];

let results = (outs CIR_IntType:$result);
let arguments = (ins UIntOfWidths<[16, 32, 64]>:$input);
let arguments = (ins AnyTypeOf<[UInt16, UInt32, UInt64]>:$input);

let assemblyFormat = [{
`(` $input `:` type($input) `)` `:` type($result) attr-dict
Expand Down Expand Up @@ -1252,7 +1253,7 @@ def CmpThreeWayOp : CIR_Op<"cmp3way", [Pure, SameTypeOperands]> {
```
}];

let results = (outs CIR_IntType:$result);
let results = (outs PrimitiveSInt:$result);
let arguments = (ins CIR_AnyType:$lhs, CIR_AnyType:$rhs,
CmpThreeWayInfoAttr:$info);

Expand All @@ -1261,7 +1262,7 @@ def CmpThreeWayOp : CIR_Op<"cmp3way", [Pure, SameTypeOperands]> {
`:` type($result) attr-dict
}];

let hasVerifier = 1;
let hasVerifier = 0;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2122,7 +2123,7 @@ def VecInsertOp : CIR_Op<"vec.insert", [Pure,
element is returned.
}];

let arguments = (ins CIR_VectorType:$vec, AnyType:$value, CIR_IntType:$index);
let arguments = (ins CIR_VectorType:$vec, AnyType:$value, PrimitiveInt:$index);
let results = (outs CIR_VectorType:$result);

let assemblyFormat = [{
Expand All @@ -2147,7 +2148,7 @@ def VecExtractOp : CIR_Op<"vec.extract", [Pure,
from a vector object.
}];

let arguments = (ins CIR_VectorType:$vec, CIR_IntType:$index);
let arguments = (ins CIR_VectorType:$vec, PrimitiveInt:$index);
let results = (outs CIR_AnyType:$result);

let assemblyFormat = [{
Expand Down Expand Up @@ -2935,7 +2936,7 @@ def CopyOp : CIR_Op<"copy", [SameTypeOperands]> {
def MemCpyOp : CIR_Op<"libc.memcpy"> {
let arguments = (ins Arg<CIR_PointerType, "", [MemWrite]>:$dst,
Arg<CIR_PointerType, "", [MemRead]>:$src,
CIR_IntType:$len);
PrimitiveInt:$len);
let summary = "Equivalent to libc's `memcpy`";
let description = [{
Given two CIR pointers, `src` and `dst`, `cir.libc.memcpy` will copy `len`
Expand Down Expand Up @@ -3115,10 +3116,10 @@ def ExpectOp : CIR_Op<"expect",
where probability = $prob.
}];

let arguments = (ins CIR_IntType:$val,
CIR_IntType:$expected,
let arguments = (ins PrimitiveInt:$val,
PrimitiveInt:$expected,
OptionalAttr<F64Attr>:$prob);
let results = (outs CIR_IntType:$result);
let results = (outs PrimitiveInt:$result);
let assemblyFormat = [{
`(` $val`,` $expected (`,` $prob^)? `)` `:` type($val) attr-dict
}];
Expand Down Expand Up @@ -3524,7 +3525,7 @@ def AtomicFetch : CIR_Op<"atomic.fetch",
of the computation (`__atomic_binop_fetch`).
}];
let results = (outs CIR_AnyIntOrFloat:$result);
let arguments = (ins IntOrFPPtr:$ptr, CIR_AnyIntOrFloat:$val,
let arguments = (ins PrimitiveIntOrFPPtr:$ptr, CIR_AnyIntOrFloat:$val,
AtomicFetchKind:$binop,
Arg<MemOrder, "memory order">:$mem_order,
UnitAttr:$is_volatile,
Expand Down
61 changes: 24 additions & 37 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,21 @@ def CIR_IntType : CIR_Type<"Int", "int",
std::string getAlias() const {
return (isSigned() ? 's' : 'u') + std::to_string(getWidth()) + 'i';
};
/// Return true if this is a primitive integer type (i.e. signed or unsigned
/// integer types whose bit width is 8, 16, 32, or 64).
bool isPrimitive() const {
return isValidPrimitiveIntBitwidth(getWidth());
}

/// Returns a minimum bitwidth of cir::IntType
static unsigned minBitwidth() { return 8; }
static unsigned minBitwidth() { return 1; }
/// Returns a maximum bitwidth of cir::IntType
static unsigned maxBitwidth() { return 64; }

/// Returns true if cir::IntType can be constructed from the provided bitwidth
static bool isValidBitwidth(unsigned width) {
return width >= minBitwidth()
&& width <= maxBitwidth()
&& llvm::isPowerOf2_32(width);
/// Returns true if cir::IntType that represents a primitive integer type
/// can be constructed from the provided bitwidth.
static bool isValidPrimitiveIntBitwidth(unsigned width) {
return width == 8 || width == 16 || width == 32 || width == 64;
}
}];
let genVerifyDecl = 1;
Expand Down Expand Up @@ -109,35 +113,15 @@ def SInt16 : SInt<16>;
def SInt32 : SInt<32>;
def SInt64 : SInt<64>;

// A type constraint that allows unsigned integer type whose width is among the
// specified list of possible widths.
class UIntOfWidths<list<int> widths>
: Type<And<[
CPred<"$_self.isa<::mlir::cir::IntType>()">,
CPred<"$_self.cast<::mlir::cir::IntType>().isUnsigned()">,
Or<!foreach(
w, widths,
CPred<"$_self.cast<::mlir::cir::IntType>().getWidth() == " # w>
)>
]>,
!interleave(!foreach(w, widths, w # "-bit"), " or ") # " uint",
"::mlir::cir::IntType"
> {}

// A type constraint that allows unsigned integer type whose width is among the
// specified list of possible widths.
class SIntOfWidths<list<int> widths>
: Type<And<[
CPred<"$_self.isa<::mlir::cir::IntType>()">,
CPred<"$_self.cast<::mlir::cir::IntType>().isSigned()">,
Or<!foreach(
w, widths,
CPred<"$_self.cast<::mlir::cir::IntType>().getWidth() == " # w>
)>
]>,
!interleave(!foreach(w, widths, w # "-bit"), " or ") # " sint",
"::mlir::cir::IntType"
> {}
def PrimitiveUInt
: AnyTypeOf<[UInt8, UInt16, UInt32, UInt64], "primitive unsigned int",
"::mlir::cir::IntType">;
def PrimitiveSInt
: AnyTypeOf<[SInt8, SInt16, SInt32, SInt64], "primitive signed int",
"::mlir::cir::IntType">;
def PrimitiveInt
: AnyTypeOf<[UInt8, UInt16, UInt32, UInt64, SInt8, SInt16, SInt32, SInt64],
"primitive int", "::mlir::cir::IntType">;

//===----------------------------------------------------------------------===//
// FloatType
Expand Down Expand Up @@ -374,8 +358,8 @@ def VoidPtr : Type<
"mlir::cir::VoidType::get($_builder.getContext()))"> {
}

// Pointer to int, float or double
def IntOrFPPtr : Type<
// Pointer to a primitive int, float or double
def PrimitiveIntOrFPPtr : Type<
And<[
CPred<"$_self.isa<::mlir::cir::PointerType>()">,
CPred<"$_self.cast<::mlir::cir::PointerType>()"
Expand Down Expand Up @@ -429,6 +413,9 @@ def IntegerVector : Type<
CPred<"$_self.isa<::mlir::cir::VectorType>()">,
CPred<"$_self.cast<::mlir::cir::VectorType>()"
".getEltType().isa<::mlir::cir::IntType>()">,
CPred<"$_self.cast<::mlir::cir::VectorType>()"
".getEltType().cast<::mlir::cir::IntType>()"
".isPrimitive()">
]>, "!cir.vector of !cir.int"> {
}

Expand Down
4 changes: 2 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
case 64:
return getUInt64Ty();
default:
llvm_unreachable("Unknown bit-width");
return mlir::cir::IntType::get(getContext(), N, false);
}
}

Expand All @@ -343,7 +343,7 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
case 64:
return getSInt64Ty();
default:
llvm_unreachable("Unknown bit-width");
return mlir::cir::IntType::get(getContext(), N, true);
}
}

Expand Down
10 changes: 8 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,11 @@ mlir::Type CIRGenTypes::convertTypeForMem(clang::QualType qualType,
mlir::Type convertedType = ConvertType(qualType);

assert(!forBitField && "Bit fields NYI");
assert(!qualType->isBitIntType() && "BitIntType NYI");

// If this is a bit-precise integer type in a bitfield representation, map
// this integer to the target-specified size.
if (forBitField && qualType->isBitIntType())
assert(!qualType->isBitIntType() && "Bit field with type _BitInt NYI");

return convertedType;
}
Expand Down Expand Up @@ -725,7 +729,9 @@ mlir::Type CIRGenTypes::ConvertType(QualType T) {
break;
}
case Type::BitInt: {
assert(0 && "not implemented");
const auto *bitIntTy = cast<BitIntType>(Ty);
ResultType = mlir::cir::IntType::get(
Builder.getContext(), bitIntTy->getNumBits(), bitIntTy->isSigned());
break;
}
}
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CIR/CodeGen/CIRRecordLayoutBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ struct CIRRecordLowering final {
// structures support.
mlir::Type getBitfieldStorageType(unsigned numBits) {
unsigned alignedBits = llvm::alignTo(numBits, astContext.getCharWidth());
if (mlir::cir::IntType::isValidBitwidth(alignedBits)) {
if (mlir::cir::IntType::isValidPrimitiveIntBitwidth(alignedBits)) {
return builder.getUIntNTy(alignedBits);
} else {
mlir::Type type = getCharType();
Expand Down
18 changes: 4 additions & 14 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ struct CIROpAsmDialectInterface : public OpAsmDialectInterface {
return AliasResult::OverridableAlias;
}
if (auto intType = type.dyn_cast<IntType>()) {
// We only provide alias for standard integer types (i.e. integer types
// whose width is divisible by 8).
if (intType.getWidth() % 8 != 0)
return AliasResult::NoAlias;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a blocker for this PR, but why not?

os << intType.getAlias();
return AliasResult::OverridableAlias;
}
Expand Down Expand Up @@ -940,20 +944,6 @@ Block *BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
return nullptr;
}

//===----------------------------------------------------------------------===//
// CmpThreeWayOp
//===----------------------------------------------------------------------===//

LogicalResult CmpThreeWayOp::verify() {
// Type of the result must be a signed integer type.
if (!getType().isSigned()) {
emitOpError() << "result type of cir.cmp3way must be a signed integer type";
return failure();
}

return success();
}

//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//
Expand Down
12 changes: 2 additions & 10 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -601,12 +601,8 @@ Type IntType::parse(mlir::AsmParser &parser) {
// Fetch integer size.
if (parser.parseInteger(width))
return {};
if (width % 8 != 0) {
parser.emitError(loc, "expected integer width to be a multiple of 8");
return {};
}
if (width < 8 || width > 64) {
parser.emitError(loc, "expected integer width to be from 8 up to 64");
if (width < 1 || width > 64) {
parser.emitError(loc, "expected integer width to be from 1 up to 64");
return {};
}

Expand Down Expand Up @@ -647,10 +643,6 @@ IntType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
<< IntType::minBitwidth() << "up to " << IntType::maxBitwidth();
return mlir::failure();
}
if (width % 8 != 0) {
emitError() << "IntType width is not a multiple of 8";
return mlir::failure();
}

return mlir::success();
}
Expand Down
22 changes: 22 additions & 0 deletions clang/test/CIR/CodeGen/bitint.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir-enable -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s

void VLATest(_BitInt(3) A, _BitInt(42) B, _BitInt(17) C) {
int AR1[A];
int AR2[B];
int AR3[C];
}

// CHECK: cir.func @VLATest
// CHECK: %[[#A:]] = cir.load %{{.+}} : cir.ptr <!cir.int<s, 3>>, !cir.int<s, 3>
// CHECK-NEXT: %[[#A_PROMOTED:]] = cir.cast(integral, %[[#A]] : !cir.int<s, 3>), !u64i
// CHECK-NEXT: %[[#SP:]] = cir.stack_save : !cir.ptr<!u8i>
// CHECK-NEXT: cir.store %[[#SP]], %{{.+}} : !cir.ptr<!u8i>, cir.ptr <!cir.ptr<!u8i>>
// CHECK-NEXT: %{{.+}} = cir.alloca !s32i, cir.ptr <!s32i>, %[[#A_PROMOTED]] : !u64i
// CHECK-NEXT: %[[#B:]] = cir.load %1 : cir.ptr <!cir.int<s, 42>>, !cir.int<s, 42>
// CHECK-NEXT: %[[#B_PROMOTED:]] = cir.cast(integral, %[[#B]] : !cir.int<s, 42>), !u64i
// CHECK-NEXT: %{{.+}} = cir.alloca !s32i, cir.ptr <!s32i>, %[[#B_PROMOTED]] : !u64i
// CHECK-NEXT: %[[#C:]] = cir.load %2 : cir.ptr <!cir.int<s, 17>>, !cir.int<s, 17>
// CHECK-NEXT: %[[#C_PROMOTED:]] = cir.cast(integral, %[[#C]] : !cir.int<s, 17>), !u64i
// CHECK-NEXT: %{{.+}} = cir.alloca !s32i, cir.ptr <!s32i>, %[[#C_PROMOTED]] : !u64i
// CHECK: }
Loading
Loading