-
Notifications
You must be signed in to change notification settings - Fork 167
[CIR] Add GEP flags to ptr stride op #1863
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f9c26ce
6b1cf57
6bbbd84
92b58e8
da7c2af
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -386,6 +386,24 @@ def CIR_PtrDiffOp : CIR_Op<"ptr_diff", [Pure, SameTypeOperands]> { | |||||||||||||
//===----------------------------------------------------------------------===// | ||||||||||||||
// PtrStrideOp | ||||||||||||||
//===----------------------------------------------------------------------===// | ||||||||||||||
def CIR_GEPNone : I32BitEnumCaseNone<"none">; | ||||||||||||||
def CIR_GEPInboundsFlag : I32BitEnumCaseBit<"inboundsFlag", 0, "inbounds_flag">; | ||||||||||||||
def CIR_GEPNusw : I32BitEnumCaseBit<"nusw", 1>; | ||||||||||||||
def CIR_GEPNuw : I32BitEnumCaseBit<"nuw", 2>; | ||||||||||||||
def CIR_GEPInbounds | ||||||||||||||
: BitEnumCaseGroup<"inbounds", [CIR_GEPInboundsFlag, CIR_GEPNusw]>; | ||||||||||||||
|
||||||||||||||
def CIR_GEPNoWrapFlags | ||||||||||||||
: CIR_I32BitEnum<"CIR_GEPNoWrapFlags", "::cir::CIR_GEPNoWrapFlags", | ||||||||||||||
[CIR_GEPNone, CIR_GEPInboundsFlag, CIR_GEPNusw, CIR_GEPNuw, | ||||||||||||||
CIR_GEPInbounds]> { | ||||||||||||||
let cppNamespace = "::cir"; | ||||||||||||||
let printBitEnumPrimaryGroups = 1; | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
def CIR_GEPNoWrapFlagsProp : EnumProp<CIR_GEPNoWrapFlags> { | ||||||||||||||
let defaultValue = interfaceType#"::none"; | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
def CIR_PtrStrideOp : CIR_Op<"ptr_stride",[ | ||||||||||||||
Pure, AllTypesMatch<["base", "result"]> | ||||||||||||||
|
@@ -397,19 +415,23 @@ def CIR_PtrStrideOp : CIR_Op<"ptr_stride",[ | |||||||||||||
|
||||||||||||||
```mlir | ||||||||||||||
%3 = cir.const 0 : i32 | ||||||||||||||
|
||||||||||||||
%4 = cir.ptr_stride(%2 : !cir.ptr<i32>, %3 : i32), !cir.ptr<i32> | ||||||||||||||
|
||||||||||||||
%4 = cir.ptr_stride(%2 : !cir.ptr<i32>, %3 : i32, inbounds), !cir.ptr<i32> | ||||||||||||||
|
||||||||||||||
%4 = cir.ptr_stride(%2 : !cir.ptr<i32>, %3 : i32, inbounds|nuw), !cir.ptr<i32> | ||||||||||||||
Comment on lines
+421
to
+423
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
|
||||||||||||||
``` | ||||||||||||||
}]; | ||||||||||||||
|
||||||||||||||
let arguments = (ins | ||||||||||||||
CIR_PointerType:$base, | ||||||||||||||
CIR_AnyFundamentalIntType:$stride | ||||||||||||||
); | ||||||||||||||
let arguments = (ins CIR_PointerType:$base, CIR_AnyFundamentalIntType:$stride, | ||||||||||||||
badumbatish marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
CIR_GEPNoWrapFlagsProp:$noWrapFlags); | ||||||||||||||
|
||||||||||||||
let results = (outs CIR_PointerType:$result); | ||||||||||||||
|
||||||||||||||
let assemblyFormat = [{ | ||||||||||||||
`(` $base `:` qualified(type($base)) `,` $stride `:` qualified(type($stride)) `)` | ||||||||||||||
`(` $base `:` qualified(type($base)) `,` $stride `:` qualified(type($stride))(`,` $noWrapFlags^)?`)` | ||||||||||||||
`,` qualified(type($result)) attr-dict | ||||||||||||||
Comment on lines
+434
to
435
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should have been Then I would suggest There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah i can file an issue for that |
||||||||||||||
}]; | ||||||||||||||
|
||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -92,6 +92,25 @@ void walkRegionSkipping(mlir::Region ®ion, | |
}); | ||
} | ||
|
||
/// Convert from a CIR PtrStrideOp kind to an LLVM IR equivalent of GEP. | ||
mlir::LLVM::GEPNoWrapFlags | ||
convertPtrStrideKindToGEPFlags(cir::CIR_GEPNoWrapFlags flags) { | ||
static std::unordered_map<cir::CIR_GEPNoWrapFlags, mlir::LLVM::GEPNoWrapFlags> | ||
mp = { | ||
{cir::CIR_GEPNoWrapFlags::none, mlir::LLVM::GEPNoWrapFlags::none}, | ||
{cir::CIR_GEPNoWrapFlags::inbounds, | ||
mlir::LLVM::GEPNoWrapFlags::inbounds}, | ||
{cir::CIR_GEPNoWrapFlags::inboundsFlag, | ||
mlir::LLVM::GEPNoWrapFlags::inboundsFlag}, | ||
{cir::CIR_GEPNoWrapFlags::nusw, mlir::LLVM::GEPNoWrapFlags::nusw}, | ||
{cir::CIR_GEPNoWrapFlags::nuw, mlir::LLVM::GEPNoWrapFlags::nuw}, | ||
}; | ||
mlir::LLVM::GEPNoWrapFlags x = mlir::LLVM::GEPNoWrapFlags::none; | ||
for (auto [key, _] : mp) | ||
x = x | mp.at(flags & key); | ||
return x; | ||
Comment on lines
+98
to
+111
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't use a This should just be a series of |
||
} | ||
|
||
/// Convert from a CIR comparison kind to an LLVM IR integral comparison kind. | ||
mlir::LLVM::ICmpPredicate convertCmpKindToICmpPredicate(cir::CmpOpKind kind, | ||
bool isSigned) { | ||
|
@@ -1023,9 +1042,9 @@ mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite( | |
isUnsigned = strideTy.isUnsigned(); | ||
index = promoteIndex(rewriter, index, *layoutWidth, isUnsigned); | ||
} | ||
|
||
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>( | ||
ptrStrideOp, resultTy, elementTy, adaptor.getBase(), index); | ||
ptrStrideOp, resultTy, elementTy, adaptor.getBase(), index, | ||
convertPtrStrideKindToGEPFlags(adaptor.getNoWrapFlags())); | ||
return mlir::success(); | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,49 +1,64 @@ | ||
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir | ||
// RUN: FileCheck --input-file=%t.cir %s | ||
// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR | ||
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o - 2>&1 | FileCheck %s --check-prefix=LLVM | ||
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o - 2>&1 | FileCheck %s --check-prefix=OGCG | ||
|
||
// Should generate basic pointer arithmetics. | ||
void foo(int *iptr, char *cptr, unsigned ustride) { | ||
*(iptr + 2) = 1; | ||
// CHECK: %[[#STRIDE:]] = cir.const #cir.int<2> : !s32i | ||
// CHECK: cir.ptr_stride(%{{.+}} : !cir.ptr<!s32i>, %[[#STRIDE]] : !s32i), !cir.ptr<!s32i> | ||
// CIR: %[[#STRIDE:]] = cir.const #cir.int<2> : !s32i | ||
// CIR: cir.ptr_stride(%{{.+}} : !cir.ptr<!s32i>, %[[#STRIDE]] : !s32i, inbounds), !cir.ptr<!s32i> | ||
// LLVM: getelementptr inbounds | ||
// OGCG: getelementptr inbounds | ||
*(cptr + 3) = 1; | ||
// CHECK: %[[#STRIDE:]] = cir.const #cir.int<3> : !s32i | ||
// CHECK: cir.ptr_stride(%{{.+}} : !cir.ptr<!s8i>, %[[#STRIDE]] : !s32i), !cir.ptr<!s8i> | ||
// CIR: %[[#STRIDE:]] = cir.const #cir.int<3> : !s32i | ||
// CIR: cir.ptr_stride(%{{.+}} : !cir.ptr<!s8i>, %[[#STRIDE]] : !s32i, inbounds), !cir.ptr<!s8i> | ||
// LLVM: getelementptr inbounds | ||
// OGCG: getelementptr inbounds | ||
*(iptr - 2) = 1; | ||
// CHECK: %[[#STRIDE:]] = cir.const #cir.int<2> : !s32i | ||
// CHECK: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#STRIDE]]) : !s32i, !s32i | ||
// CHECK: cir.ptr_stride(%{{.+}} : !cir.ptr<!s32i>, %[[#NEGSTRIDE]] : !s32i), !cir.ptr<!s32i> | ||
// CIR: %[[#STRIDE:]] = cir.const #cir.int<2> : !s32i | ||
// CIR: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#STRIDE]]) : !s32i, !s32i | ||
// CIR: cir.ptr_stride(%{{.+}} : !cir.ptr<!s32i>, %[[#NEGSTRIDE]] : !s32i, inbounds), !cir.ptr<!s32i> | ||
// LLVM: getelementptr inbounds | ||
// OGCG: getelementptr inbounds | ||
*(cptr - 3) = 1; | ||
// CHECK: %[[#STRIDE:]] = cir.const #cir.int<3> : !s32i | ||
// CHECK: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#STRIDE]]) : !s32i, !s32i | ||
// CHECK: cir.ptr_stride(%{{.+}} : !cir.ptr<!s8i>, %[[#NEGSTRIDE]] : !s32i), !cir.ptr<!s8i> | ||
// CIR: %[[#STRIDE:]] = cir.const #cir.int<3> : !s32i | ||
// CIR: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#STRIDE]]) : !s32i, !s32i | ||
// CIR: cir.ptr_stride(%{{.+}} : !cir.ptr<!s8i>, %[[#NEGSTRIDE]] : !s32i, inbounds), !cir.ptr<!s8i> | ||
// LLVM: getelementptr inbounds | ||
// OGCG: getelementptr inbounds | ||
*(iptr + ustride) = 1; | ||
// CHECK: %[[#STRIDE:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!u32i>, !u32i | ||
// CHECK: cir.ptr_stride(%{{.+}} : !cir.ptr<!s32i>, %[[#STRIDE]] : !u32i), !cir.ptr<!s32i> | ||
// CIR: %[[#STRIDE:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!u32i>, !u32i | ||
// CIR: cir.ptr_stride(%{{.+}} : !cir.ptr<!s32i>, %[[#STRIDE]] : !u32i, inbounds|nuw), !cir.ptr<!s32i> | ||
|
||
// LLVM: getelementptr inbounds nuw | ||
// OGCG: getelementptr inbounds nuw | ||
|
||
// Must convert unsigned stride to a signed one. | ||
*(iptr - ustride) = 1; | ||
// CHECK: %[[#STRIDE:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!u32i>, !u32i | ||
// CHECK: %[[#SIGNSTRIDE:]] = cir.cast(integral, %[[#STRIDE]] : !u32i), !s32i | ||
// CHECK: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#SIGNSTRIDE]]) : !s32i, !s32i | ||
// CHECK: cir.ptr_stride(%{{.+}} : !cir.ptr<!s32i>, %[[#NEGSTRIDE]] : !s32i), !cir.ptr<!s32i> | ||
// CIR: %[[#STRIDE:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!u32i>, !u32i | ||
// CIR: %[[#SIGNSTRIDE:]] = cir.cast(integral, %[[#STRIDE]] : !u32i), !s32i | ||
// CIR: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#SIGNSTRIDE]]) : !s32i, !s32i | ||
// CIR: cir.ptr_stride(%{{.+}} : !cir.ptr<!s32i>, %[[#NEGSTRIDE]] : !s32i, inbounds), !cir.ptr<!s32i> | ||
// LLVM: getelementptr inbounds | ||
// OGCG: getelementptr inbounds | ||
} | ||
|
||
void testPointerSubscriptAccess(int *ptr) { | ||
// CHECK: testPointerSubscriptAccess | ||
// CIR: testPointerSubscriptAccess | ||
ptr[1] = 2; | ||
// CHECK: %[[#V1:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i> | ||
// CHECK: %[[#V2:]] = cir.const #cir.int<1> : !s32i | ||
// CHECK: cir.ptr_stride(%[[#V1]] : !cir.ptr<!s32i>, %[[#V2]] : !s32i), !cir.ptr<!s32i> | ||
// CIR: %[[#V1:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i> | ||
// CIR: %[[#V2:]] = cir.const #cir.int<1> : !s32i | ||
// CIR: cir.ptr_stride(%[[#V1]] : !cir.ptr<!s32i>, %[[#V2]] : !s32i), !cir.ptr<!s32i> | ||
} | ||
|
||
void testPointerMultiDimSubscriptAccess(int **ptr) { | ||
// CHECK: testPointerMultiDimSubscriptAccess | ||
// CIR: testPointerMultiDimSubscriptAccess | ||
ptr[1][2] = 3; | ||
// CHECK: %[[#V1:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!cir.ptr<!s32i>>>, !cir.ptr<!cir.ptr<!s32i>> | ||
// CHECK: %[[#V2:]] = cir.const #cir.int<1> : !s32i | ||
// CHECK: %[[#V3:]] = cir.ptr_stride(%[[#V1]] : !cir.ptr<!cir.ptr<!s32i>>, %[[#V2]] : !s32i), !cir.ptr<!cir.ptr<!s32i>> | ||
// CHECK: %[[#V4:]] = cir.load{{.*}} %[[#V3]] : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i> | ||
// CHECK: %[[#V5:]] = cir.const #cir.int<2> : !s32i | ||
// CHECK: cir.ptr_stride(%[[#V4]] : !cir.ptr<!s32i>, %[[#V5]] : !s32i), !cir.ptr<!s32i> | ||
// CIR: %[[#V1:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!cir.ptr<!s32i>>>, !cir.ptr<!cir.ptr<!s32i>> | ||
// CIR: %[[#V2:]] = cir.const #cir.int<1> : !s32i | ||
// CIR: %[[#V3:]] = cir.ptr_stride(%[[#V1]] : !cir.ptr<!cir.ptr<!s32i>>, %[[#V2]] : !s32i), !cir.ptr<!cir.ptr<!s32i>> | ||
// CIR: %[[#V4:]] = cir.load{{.*}} %[[#V3]] : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i> | ||
// CIR: %[[#V5:]] = cir.const #cir.int<2> : !s32i | ||
// CIR: cir.ptr_stride(%[[#V4]] : !cir.ptr<!s32i>, %[[#V5]] : !s32i), !cir.ptr<!s32i> | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at other dialects, I think this should be: