Skip to content

Commit

Permalink
[CIR] Add initial support for bit-precise integer types (llvm#538)
Browse files Browse the repository at this point in the history
This PR adds initial support for the bit-precise integer type
`_BitInt(N)`. This type goes into the C23 standard, and has already been
supported by clang since 2020, previously known as `_ExtInt(N)`.

This PR is quite simple and straight-forward. Basically it leverages the
existing `cir.int` type to represent such types. Previously `cir.int`
verifies that its width must be a multiple of 8, and this verification
has been removed in this PR.
  • Loading branch information
Lancern authored and lanza committed Oct 12, 2024
1 parent ee583ff commit 0fa6185
Show file tree
Hide file tree
Showing 11 changed files with 208 additions and 103 deletions.
41 changes: 21 additions & 20 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def ObjSizeOp : CIR_Op<"objsize", [Pure]> {

let arguments = (ins CIR_PointerType:$ptr, SizeInfoType:$kind,
UnitAttr:$dynamic);
let results = (outs CIR_IntType:$result);
let results = (outs PrimitiveInt:$result);

let assemblyFormat = [{
`(`
Expand Down Expand Up @@ -180,7 +180,7 @@ def PtrDiffOp : CIR_Op<"ptr_diff", [Pure, SameTypeOperands]> {
```
}];

let results = (outs CIR_IntType:$result);
let results = (outs PrimitiveInt:$result);
let arguments = (ins CIR_PointerType:$lhs, CIR_PointerType:$rhs);

let assemblyFormat = [{
Expand Down Expand Up @@ -208,7 +208,7 @@ def PtrStrideOp : CIR_Op<"ptr_stride",
```
}];

let arguments = (ins CIR_PointerType:$base, CIR_IntType:$stride);
let arguments = (ins CIR_PointerType:$base, PrimitiveInt:$stride);
let results = (outs CIR_PointerType:$result);

let assemblyFormat = [{
Expand Down Expand Up @@ -337,7 +337,7 @@ def AllocaOp : CIR_Op<"alloca", [
}];

let arguments = (ins
Optional<CIR_IntType>:$dynAllocSize,
Optional<PrimitiveInt>:$dynAllocSize,
TypeAttr:$allocaType,
StrAttr:$name,
UnitAttr:$init,
Expand Down Expand Up @@ -1034,7 +1034,7 @@ class CIR_BitOp<string mnemonic, TypeConstraint inputTy>
}];
}

def BitClrsbOp : CIR_BitOp<"bit.clrsb", SIntOfWidths<[32, 64]>> {
def BitClrsbOp : CIR_BitOp<"bit.clrsb", AnyTypeOf<[SInt32, SInt64]>> {
let summary = "Get the number of leading redundant sign bits in the input";
let description = [{
Compute the number of leading redundant sign bits in the input integer.
Expand Down Expand Up @@ -1065,7 +1065,7 @@ def BitClrsbOp : CIR_BitOp<"bit.clrsb", SIntOfWidths<[32, 64]>> {
}];
}

def BitClzOp : CIR_BitOp<"bit.clz", UIntOfWidths<[16, 32, 64]>> {
def BitClzOp : CIR_BitOp<"bit.clz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
let summary = "Get the number of leading 0-bits in the input";
let description = [{
Compute the number of leading 0-bits in the input.
Expand All @@ -1090,7 +1090,7 @@ def BitClzOp : CIR_BitOp<"bit.clz", UIntOfWidths<[16, 32, 64]>> {
}];
}

def BitCtzOp : CIR_BitOp<"bit.ctz", UIntOfWidths<[16, 32, 64]>> {
def BitCtzOp : CIR_BitOp<"bit.ctz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
let summary = "Get the number of trailing 0-bits in the input";
let description = [{
Compute the number of trailing 0-bits in the input.
Expand All @@ -1115,7 +1115,7 @@ def BitCtzOp : CIR_BitOp<"bit.ctz", UIntOfWidths<[16, 32, 64]>> {
}];
}

def BitFfsOp : CIR_BitOp<"bit.ffs", SIntOfWidths<[32, 64]>> {
def BitFfsOp : CIR_BitOp<"bit.ffs", AnyTypeOf<[SInt32, SInt64]>> {
let summary = "Get the position of the least significant 1-bit of input";
let description = [{
Compute the position of the least significant 1-bit of the input.
Expand All @@ -1138,7 +1138,7 @@ def BitFfsOp : CIR_BitOp<"bit.ffs", SIntOfWidths<[32, 64]>> {
}];
}

def BitParityOp : CIR_BitOp<"bit.parity", UIntOfWidths<[32, 64]>> {
def BitParityOp : CIR_BitOp<"bit.parity", AnyTypeOf<[UInt32, UInt64]>> {
let summary = "Get the parity of input";
let description = [{
Compute the parity of the input. The parity of an integer is the number of
Expand All @@ -1160,7 +1160,8 @@ def BitParityOp : CIR_BitOp<"bit.parity", UIntOfWidths<[32, 64]>> {
}];
}

def BitPopcountOp : CIR_BitOp<"bit.popcount", UIntOfWidths<[16, 32, 64]>> {
def BitPopcountOp
: CIR_BitOp<"bit.popcount", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
let summary = "Get the number of 1-bits in input";
let description = [{
Compute the number of 1-bits in the input.
Expand Down Expand Up @@ -1208,7 +1209,7 @@ def ByteswapOp : CIR_Op<"bswap", [Pure, SameOperandsAndResultType]> {
}];

let results = (outs CIR_IntType:$result);
let arguments = (ins UIntOfWidths<[16, 32, 64]>:$input);
let arguments = (ins AnyTypeOf<[UInt16, UInt32, UInt64]>:$input);

let assemblyFormat = [{
`(` $input `:` type($input) `)` `:` type($result) attr-dict
Expand Down Expand Up @@ -1252,7 +1253,7 @@ def CmpThreeWayOp : CIR_Op<"cmp3way", [Pure, SameTypeOperands]> {
```
}];

let results = (outs CIR_IntType:$result);
let results = (outs PrimitiveSInt:$result);
let arguments = (ins CIR_AnyType:$lhs, CIR_AnyType:$rhs,
CmpThreeWayInfoAttr:$info);

Expand All @@ -1261,7 +1262,7 @@ def CmpThreeWayOp : CIR_Op<"cmp3way", [Pure, SameTypeOperands]> {
`:` type($result) attr-dict
}];

let hasVerifier = 1;
let hasVerifier = 0;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2122,7 +2123,7 @@ def VecInsertOp : CIR_Op<"vec.insert", [Pure,
element is returned.
}];

let arguments = (ins CIR_VectorType:$vec, AnyType:$value, CIR_IntType:$index);
let arguments = (ins CIR_VectorType:$vec, AnyType:$value, PrimitiveInt:$index);
let results = (outs CIR_VectorType:$result);

let assemblyFormat = [{
Expand All @@ -2147,7 +2148,7 @@ def VecExtractOp : CIR_Op<"vec.extract", [Pure,
from a vector object.
}];

let arguments = (ins CIR_VectorType:$vec, CIR_IntType:$index);
let arguments = (ins CIR_VectorType:$vec, PrimitiveInt:$index);
let results = (outs CIR_AnyType:$result);

let assemblyFormat = [{
Expand Down Expand Up @@ -2935,7 +2936,7 @@ def CopyOp : CIR_Op<"copy", [SameTypeOperands]> {
def MemCpyOp : CIR_Op<"libc.memcpy"> {
let arguments = (ins Arg<CIR_PointerType, "", [MemWrite]>:$dst,
Arg<CIR_PointerType, "", [MemRead]>:$src,
CIR_IntType:$len);
PrimitiveInt:$len);
let summary = "Equivalent to libc's `memcpy`";
let description = [{
Given two CIR pointers, `src` and `dst`, `cir.libc.memcpy` will copy `len`
Expand Down Expand Up @@ -3115,10 +3116,10 @@ def ExpectOp : CIR_Op<"expect",
where probability = $prob.
}];

let arguments = (ins CIR_IntType:$val,
CIR_IntType:$expected,
let arguments = (ins PrimitiveInt:$val,
PrimitiveInt:$expected,
OptionalAttr<F64Attr>:$prob);
let results = (outs CIR_IntType:$result);
let results = (outs PrimitiveInt:$result);
let assemblyFormat = [{
`(` $val`,` $expected (`,` $prob^)? `)` `:` type($val) attr-dict
}];
Expand Down Expand Up @@ -3524,7 +3525,7 @@ def AtomicFetch : CIR_Op<"atomic.fetch",
of the computation (`__atomic_binop_fetch`).
}];
let results = (outs CIR_AnyIntOrFloat:$result);
let arguments = (ins IntOrFPPtr:$ptr, CIR_AnyIntOrFloat:$val,
let arguments = (ins PrimitiveIntOrFPPtr:$ptr, CIR_AnyIntOrFloat:$val,
AtomicFetchKind:$binop,
Arg<MemOrder, "memory order">:$mem_order,
UnitAttr:$is_volatile,
Expand Down
61 changes: 24 additions & 37 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,21 @@ def CIR_IntType : CIR_Type<"Int", "int",
std::string getAlias() const {
return (isSigned() ? 's' : 'u') + std::to_string(getWidth()) + 'i';
};
/// Return true if this is a primitive integer type (i.e. signed or unsigned
/// integer types whose bit width is 8, 16, 32, or 64).
bool isPrimitive() const {
return isValidPrimitiveIntBitwidth(getWidth());
}

/// Returns a minimum bitwidth of cir::IntType
static unsigned minBitwidth() { return 8; }
static unsigned minBitwidth() { return 1; }
/// Returns a maximum bitwidth of cir::IntType
static unsigned maxBitwidth() { return 64; }

/// Returns true if cir::IntType can be constructed from the provided bitwidth
static bool isValidBitwidth(unsigned width) {
return width >= minBitwidth()
&& width <= maxBitwidth()
&& llvm::isPowerOf2_32(width);
/// Returns true if cir::IntType that represents a primitive integer type
/// can be constructed from the provided bitwidth.
static bool isValidPrimitiveIntBitwidth(unsigned width) {
return width == 8 || width == 16 || width == 32 || width == 64;
}
}];
let genVerifyDecl = 1;
Expand Down Expand Up @@ -109,35 +113,15 @@ def SInt16 : SInt<16>;
def SInt32 : SInt<32>;
def SInt64 : SInt<64>;

// A type constraint that allows unsigned integer type whose width is among the
// specified list of possible widths.
class UIntOfWidths<list<int> widths>
: Type<And<[
CPred<"$_self.isa<::mlir::cir::IntType>()">,
CPred<"$_self.cast<::mlir::cir::IntType>().isUnsigned()">,
Or<!foreach(
w, widths,
CPred<"$_self.cast<::mlir::cir::IntType>().getWidth() == " # w>
)>
]>,
!interleave(!foreach(w, widths, w # "-bit"), " or ") # " uint",
"::mlir::cir::IntType"
> {}

// A type constraint that allows unsigned integer type whose width is among the
// specified list of possible widths.
class SIntOfWidths<list<int> widths>
: Type<And<[
CPred<"$_self.isa<::mlir::cir::IntType>()">,
CPred<"$_self.cast<::mlir::cir::IntType>().isSigned()">,
Or<!foreach(
w, widths,
CPred<"$_self.cast<::mlir::cir::IntType>().getWidth() == " # w>
)>
]>,
!interleave(!foreach(w, widths, w # "-bit"), " or ") # " sint",
"::mlir::cir::IntType"
> {}
def PrimitiveUInt
: AnyTypeOf<[UInt8, UInt16, UInt32, UInt64], "primitive unsigned int",
"::mlir::cir::IntType">;
def PrimitiveSInt
: AnyTypeOf<[SInt8, SInt16, SInt32, SInt64], "primitive signed int",
"::mlir::cir::IntType">;
def PrimitiveInt
: AnyTypeOf<[UInt8, UInt16, UInt32, UInt64, SInt8, SInt16, SInt32, SInt64],
"primitive int", "::mlir::cir::IntType">;

//===----------------------------------------------------------------------===//
// FloatType
Expand Down Expand Up @@ -374,8 +358,8 @@ def VoidPtr : Type<
"mlir::cir::VoidType::get($_builder.getContext()))"> {
}

// Pointer to int, float or double
def IntOrFPPtr : Type<
// Pointer to a primitive int, float or double
def PrimitiveIntOrFPPtr : Type<
And<[
CPred<"$_self.isa<::mlir::cir::PointerType>()">,
CPred<"$_self.cast<::mlir::cir::PointerType>()"
Expand Down Expand Up @@ -429,6 +413,9 @@ def IntegerVector : Type<
CPred<"$_self.isa<::mlir::cir::VectorType>()">,
CPred<"$_self.cast<::mlir::cir::VectorType>()"
".getEltType().isa<::mlir::cir::IntType>()">,
CPred<"$_self.cast<::mlir::cir::VectorType>()"
".getEltType().cast<::mlir::cir::IntType>()"
".isPrimitive()">
]>, "!cir.vector of !cir.int"> {
}

Expand Down
4 changes: 2 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
case 64:
return getUInt64Ty();
default:
llvm_unreachable("Unknown bit-width");
return mlir::cir::IntType::get(getContext(), N, false);
}
}

Expand All @@ -343,7 +343,7 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
case 64:
return getSInt64Ty();
default:
llvm_unreachable("Unknown bit-width");
return mlir::cir::IntType::get(getContext(), N, true);
}
}

Expand Down
10 changes: 8 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,11 @@ mlir::Type CIRGenTypes::convertTypeForMem(clang::QualType qualType,
mlir::Type convertedType = ConvertType(qualType);

assert(!forBitField && "Bit fields NYI");
assert(!qualType->isBitIntType() && "BitIntType NYI");

// If this is a bit-precise integer type in a bitfield representation, map
// this integer to the target-specified size.
if (forBitField && qualType->isBitIntType())
assert(!qualType->isBitIntType() && "Bit field with type _BitInt NYI");

return convertedType;
}
Expand Down Expand Up @@ -728,7 +732,9 @@ mlir::Type CIRGenTypes::ConvertType(QualType T) {
break;
}
case Type::BitInt: {
assert(0 && "not implemented");
const auto *bitIntTy = cast<BitIntType>(Ty);
ResultType = mlir::cir::IntType::get(
Builder.getContext(), bitIntTy->getNumBits(), bitIntTy->isSigned());
break;
}
}
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CIR/CodeGen/CIRRecordLayoutBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ struct CIRRecordLowering final {
// structures support.
mlir::Type getBitfieldStorageType(unsigned numBits) {
unsigned alignedBits = llvm::alignTo(numBits, astContext.getCharWidth());
if (mlir::cir::IntType::isValidBitwidth(alignedBits)) {
if (mlir::cir::IntType::isValidPrimitiveIntBitwidth(alignedBits)) {
return builder.getUIntNTy(alignedBits);
} else {
mlir::Type type = getCharType();
Expand Down
18 changes: 4 additions & 14 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ struct CIROpAsmDialectInterface : public OpAsmDialectInterface {
return AliasResult::OverridableAlias;
}
if (auto intType = type.dyn_cast<IntType>()) {
// We only provide alias for standard integer types (i.e. integer types
// whose width is divisible by 8).
if (intType.getWidth() % 8 != 0)
return AliasResult::NoAlias;
os << intType.getAlias();
return AliasResult::OverridableAlias;
}
Expand Down Expand Up @@ -940,20 +944,6 @@ Block *BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
return nullptr;
}

//===----------------------------------------------------------------------===//
// CmpThreeWayOp
//===----------------------------------------------------------------------===//

LogicalResult CmpThreeWayOp::verify() {
// Type of the result must be a signed integer type.
if (!getType().isSigned()) {
emitOpError() << "result type of cir.cmp3way must be a signed integer type";
return failure();
}

return success();
}

//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//
Expand Down
12 changes: 2 additions & 10 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -600,12 +600,8 @@ Type IntType::parse(mlir::AsmParser &parser) {
// Fetch integer size.
if (parser.parseInteger(width))
return {};
if (width % 8 != 0) {
parser.emitError(loc, "expected integer width to be a multiple of 8");
return {};
}
if (width < 8 || width > 64) {
parser.emitError(loc, "expected integer width to be from 8 up to 64");
if (width < 1 || width > 64) {
parser.emitError(loc, "expected integer width to be from 1 up to 64");
return {};
}

Expand Down Expand Up @@ -646,10 +642,6 @@ IntType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
<< IntType::minBitwidth() << "up to " << IntType::maxBitwidth();
return mlir::failure();
}
if (width % 8 != 0) {
emitError() << "IntType width is not a multiple of 8";
return mlir::failure();
}

return mlir::success();
}
Expand Down
22 changes: 22 additions & 0 deletions clang/test/CIR/CodeGen/bitint.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s

void VLATest(_BitInt(3) A, _BitInt(42) B, _BitInt(17) C) {
int AR1[A];
int AR2[B];
int AR3[C];
}

// CHECK: cir.func @VLATest
// CHECK: %[[#A:]] = cir.load %{{.+}} : cir.ptr <!cir.int<s, 3>>, !cir.int<s, 3>
// CHECK-NEXT: %[[#A_PROMOTED:]] = cir.cast(integral, %[[#A]] : !cir.int<s, 3>), !u64i
// CHECK-NEXT: %[[#SP:]] = cir.stack_save : !cir.ptr<!u8i>
// CHECK-NEXT: cir.store %[[#SP]], %{{.+}} : !cir.ptr<!u8i>, cir.ptr <!cir.ptr<!u8i>>
// CHECK-NEXT: %{{.+}} = cir.alloca !s32i, cir.ptr <!s32i>, %[[#A_PROMOTED]] : !u64i
// CHECK-NEXT: %[[#B:]] = cir.load %1 : cir.ptr <!cir.int<s, 42>>, !cir.int<s, 42>
// CHECK-NEXT: %[[#B_PROMOTED:]] = cir.cast(integral, %[[#B]] : !cir.int<s, 42>), !u64i
// CHECK-NEXT: %{{.+}} = cir.alloca !s32i, cir.ptr <!s32i>, %[[#B_PROMOTED]] : !u64i
// CHECK-NEXT: %[[#C:]] = cir.load %2 : cir.ptr <!cir.int<s, 17>>, !cir.int<s, 17>
// CHECK-NEXT: %[[#C_PROMOTED:]] = cir.cast(integral, %[[#C]] : !cir.int<s, 17>), !u64i
// CHECK-NEXT: %{{.+}} = cir.alloca !s32i, cir.ptr <!s32i>, %[[#C_PROMOTED]] : !u64i
// CHECK: }
Loading

0 comments on commit 0fa6185

Please sign in to comment.