Skip to content

Commit 3122e76

Browse files
committed
[CIR][CUDA] Register __device__ global variables
1 parent e4003cd commit 3122e76

File tree

3 files changed

+108
-4
lines changed

3 files changed

+108
-4
lines changed

clang/include/clang/CIR/Dialect/IR/CIRDataLayout.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,10 @@ class CIRDataLayout {
129129
mlir::Type getCharType(mlir::MLIRContext *ctx) const {
130130
return typeSizeInfo.getCharType(ctx);
131131
}
132+
133+
/// Returns the size type reported by `typeSizeInfo` for the given MLIR
/// context. This simply forwards the query, mirroring getCharType().
mlir::Type getSizeType(mlir::MLIRContext *ctx) const {
  mlir::Type sizeTy = typeSizeInfo.getSizeType(ctx);
  return sizeTy;
}
132136
};
133137

134138
/// Used to lazily calculate structure layout information for a target machine,

clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp

Lines changed: 89 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,17 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
127127

128128
// Maps CUDA kernel name to device stub function.
129129
llvm::StringMap<FuncOp> cudaKernelMap;
130+
// Maps CUDA device-side variable name to host-side (shadow) GlobalOp.
131+
llvm::StringMap<GlobalOp> cudaVarMap;
130132

131133
void buildCUDAModuleCtor();
132134
std::optional<FuncOp> buildCUDAModuleDtor();
133135
std::optional<FuncOp> buildCUDARegisterGlobals();
134136

135137
void buildCUDARegisterGlobalFunctions(cir::CIRBaseBuilderTy &builder,
136138
FuncOp regGlobalFunc);
139+
void buildCUDARegisterVars(cir::CIRBaseBuilderTy &builder,
140+
FuncOp regGlobalFunc);
137141

138142
///
139143
/// AST related
@@ -1198,8 +1202,7 @@ std::optional<FuncOp> LoweringPreparePass::buildCUDARegisterGlobals() {
11981202
builder.setInsertionPointToStart(regGlobalFunc.addEntryBlock());
11991203

12001204
buildCUDARegisterGlobalFunctions(builder, regGlobalFunc);
1201-
1202-
// TODO(cir): registration for global variables.
1205+
buildCUDARegisterVars(builder, regGlobalFunc);
12031206

12041207
ReturnOp::create(builder, loc);
12051208
return regGlobalFunc;
@@ -1273,6 +1276,83 @@ void LoweringPreparePass::buildCUDARegisterGlobalFunctions(
12731276
}
12741277
}
12751278

1279+
/// Emits one runtime "RegisterVar" call (`__cudaRegisterVar` for the CUDA
/// prefix) into \p regGlobalFunc for every entry recorded in `cudaVarMap`.
///
/// \param builder Builder whose insertion point is inside \p regGlobalFunc;
///        the per-variable operands and calls are created there.
/// \param regGlobalFunc Registration function; its first argument is used as
///        the GPU binary (fatbin) handle passed to every call.
void LoweringPreparePass::buildCUDARegisterVars(cir::CIRBaseBuilderTy &builder,
                                                FuncOp regGlobalFunc) {
  auto loc = theModule.getLoc();
  auto cudaPrefix = getCUDAPrefix(astCtx);

  auto voidTy = VoidType::get(&getContext());
  auto voidPtrTy = PointerType::get(voidTy);
  auto voidPtrPtrTy = PointerType::get(voidPtrTy);
  auto intTy = datalayout->getIntType(&getContext());
  auto charTy = datalayout->getCharType(&getContext());
  auto sizeTy = datalayout->getSizeType(&getContext());

  // Extract the GPU binary handle argument.
  mlir::Value fatbinHandle = *regGlobalFunc.args_begin();

  // Module-level declarations (the runtime function and the string globals)
  // go through a separate builder anchored at the start of the module, so the
  // caller's insertion point inside the registration function is untouched.
  cir::CIRBaseBuilderTy globalBuilder(getContext());
  globalBuilder.setInsertionPointToStart(theModule.getBody());

  // Declare CUDA internal function:
  // void __cudaRegisterVar(
  //   void **fatbinHandle,
  //   char *hostVar,
  //   char *deviceAddress,
  //   const char *deviceName,
  //   int isExtern, size_t varSize,
  //   int isConstant, int zero
  // );
  // Both `deviceAddress` and `deviceName` receive the device-side name below.
  // Similar to the registration of global functions, OG does not care about
  // pointer types. They will generate the same IR anyway.
  FuncOp cudaRegisterVar = buildRuntimeFunction(
      globalBuilder, addUnderscoredPrefix(cudaPrefix, "RegisterVar"), loc,
      FuncType::get({voidPtrPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, intTy,
                     sizeTy, intTy, intTy},
                    voidTy));

  // Running counter appended to every generated string symbol so that two
  // strings with the same contents still get distinct global names.
  unsigned int count = 0;
  auto makeConstantString = [&](llvm::StringRef str) -> GlobalOp {
    // One extra element is reserved for the terminating NUL.
    auto strType = ArrayType::get(&getContext(), charTy, 1 + str.size());

    auto tmpString = GlobalOp::create(
        globalBuilder, loc, (".str" + str + std::to_string(count++)).str(),
        strType, /*isConstant=*/true,
        /*linkage=*/cir::GlobalLinkageKind::PrivateLinkage);

    // The initializer carries exactly the characters of `str`. Appending a
    // "\0" C literal here would add nothing (its C-string length is zero);
    // the zero terminator is the extra array element reserved in `strType`.
    tmpString.setInitialValueAttr(
        ConstArrayAttr::get(strType, StringAttr::get(&getContext(), str)));
    tmpString.setPrivate();
    return tmpString;
  };

  for (auto &[deviceSideName, global] : cudaVarMap) {
    GlobalOp deviceNameStr = makeConstantString(deviceSideName);
    mlir::Value deviceNameValue = builder.createBitcast(
        builder.createGetGlobal(deviceNameStr), voidPtrTy);

    // NOTE(review): this passes a string holding the host global's *name* as
    // the `hostVar` argument. Upstream Clang's CUDA codegen appears to pass
    // the address of the host shadow variable itself - confirm which one the
    // runtime expects before relying on this registration.
    GlobalOp hostNameStr = makeConstantString(global.getName());
    mlir::Value hostNameValue =
        builder.createBitcast(builder.createGetGlobal(hostNameStr), voidPtrTy);

    // Every device variable that has a shadow on host will not be extern.
    // See CIRGenModule::emitGlobalVarDefinition.
    auto isExtern = ConstantOp::create(builder, loc, IntAttr::get(intTy, 0));
    // The data layout reports the size in bits; the runtime call takes bytes.
    llvm::TypeSize size = datalayout->getTypeSizeInBits(global.getSymType());
    auto varSize = ConstantOp::create(
        builder, loc, IntAttr::get(sizeTy, size.getFixedValue() / 8));
    auto isConstant = ConstantOp::create(
        builder, loc, IntAttr::get(intTy, global.getConstant()));
    auto zero = ConstantOp::create(builder, loc, IntAttr::get(intTy, 0));
    builder.createCallOp(loc, cudaRegisterVar,
                         {fatbinHandle, hostNameValue, deviceNameValue,
                          deviceNameValue, isExtern, varSize, isConstant,
                          zero});
  }
}
1355+
12761356
std::optional<FuncOp> LoweringPreparePass::buildCUDAModuleDtor() {
12771357
if (!theModule->getAttr(CIRDialect::getCUDABinaryHandleAttrName()))
12781358
return {};
@@ -1585,8 +1665,13 @@ void LoweringPreparePass::runOnOp(Operation *op) {
15851665
lowerVAArgOp(vaArgOp);
15861666
} else if (auto deleteArrayOp = dyn_cast<DeleteArrayOp>(op)) {
15871667
lowerDeleteArrayOp(deleteArrayOp);
1588-
} else if (auto getGlobal = dyn_cast<GlobalOp>(op)) {
1589-
lowerGlobalOp(getGlobal);
1668+
} else if (auto global = dyn_cast<GlobalOp>(op)) {
1669+
lowerGlobalOp(global);
1670+
if (auto attr = op->getAttr(cir::CUDAShadowNameAttr::getMnemonic())) {
1671+
auto shadowNameAttr = dyn_cast<CUDAShadowNameAttr>(attr);
1672+
std::string deviceSideName = shadowNameAttr.getDeviceSideName();
1673+
cudaVarMap[deviceSideName] = global;
1674+
}
15901675
} else if (auto dynamicCast = dyn_cast<DynamicCastOp>(op)) {
15911676
lowerDynamicCastOp(dynamicCast);
15921677
} else if (auto stdFind = dyn_cast<StdFindOp>(op)) {

clang/test/CIR/CodeGen/CUDA/registration.cu

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050

5151
__global__ void fn() {}
5252

53+
__device__ int a;
54+
5355
// CIR-HOST: cir.func internal private @__cuda_register_globals(%[[FatbinHandle:[a-zA-Z0-9]+]]{{.*}}) {
5456
// CIR-HOST: %[[#NULL:]] = cir.const #cir.ptr<null>
5557
// CIR-HOST: %[[#T1:]] = cir.get_global @".str_Z2fnv"
@@ -64,6 +66,16 @@ __global__ void fn() {}
6466
// CIR-HOST-SAME: %[[#DeviceFn]],
6567
// CIR-HOST-SAME: %[[#MinusOne]],
6668
// CIR-HOST-SAME: %[[#NULL]], %[[#NULL]], %[[#NULL]], %[[#NULL]], %[[#NULL]])
69+
// CIR-HOST: %[[#T3:]] = cir.get_global @".stra0"
70+
// CIR-HOST: %[[#Device:]] = cir.cast bitcast %[[#T3]]
71+
// CIR-HOST: %[[#T4:]] = cir.get_global @".stra1"
72+
// CIR-HOST: %[[#Host:]] = cir.cast bitcast %[[#T4]]
73+
// CIR-HOST: %[[#Ext:]] = cir.const #cir.int<0>
74+
// CIR-HOST: %[[#Sz:]] = cir.const #cir.int<4>
75+
// CIR-HOST: %[[#Const:]] = cir.const #cir.int<0>
76+
// CIR-HOST: %[[#Zero:]] = cir.const #cir.int<0>
77+
// CIR-HOST: cir.call @__cudaRegisterVar(%[[FatbinHandle]], %[[#Host]], %[[#Device]], %[[#Device]],
78+
// CIR-HOST-SAME: %[[#Ext]], %[[#Sz]], %[[#Const]], %[[#Zero]])
6779
// CIR-HOST: }
6880

6981
// LLVM-HOST: define internal void @__cuda_register_globals(ptr %[[#LLVMFatbin:]]) {
@@ -74,6 +86,9 @@ __global__ void fn() {}
7486
// LLVM-HOST-SAME: ptr @.str_Z2fnv,
7587
// LLVM-HOST-SAME: i32 -1,
7688
// LLVM-HOST-SAME: ptr null, ptr null, ptr null, ptr null, ptr null)
89+
// LLVM-HOST: call void @__cudaRegisterVar(
90+
// LLVM-HOST-SAME: ptr %[[#LLVMFatbin]], ptr @.stra1, ptr @.stra0, ptr @.stra0,
91+
// LLVM-HOST-SAME: i32 0, i64 4, i32 0, i32 0)
7792
// LLVM-HOST: }
7893

7994
// The content in const array should be the same as echoed above,

0 commit comments

Comments
 (0)