@@ -127,13 +127,17 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
127127
128128 // Maps CUDA kernel name to device stub function.
129129 llvm::StringMap<FuncOp> cudaKernelMap;
130+ // Maps CUDA device-side variable name to host-side (shadow) GlobalOp.
131+ llvm::StringMap<GlobalOp> cudaVarMap;
130132
131133 void buildCUDAModuleCtor ();
132134 std::optional<FuncOp> buildCUDAModuleDtor ();
133135 std::optional<FuncOp> buildCUDARegisterGlobals ();
134136
135137 void buildCUDARegisterGlobalFunctions (cir::CIRBaseBuilderTy &builder,
136138 FuncOp regGlobalFunc);
139+ void buildCUDARegisterVars (cir::CIRBaseBuilderTy &builder,
140+ FuncOp regGlobalFunc);
137141
138142 // /
139143 // / AST related
@@ -1198,8 +1202,7 @@ std::optional<FuncOp> LoweringPreparePass::buildCUDARegisterGlobals() {
11981202 builder.setInsertionPointToStart (regGlobalFunc.addEntryBlock ());
11991203
12001204 buildCUDARegisterGlobalFunctions (builder, regGlobalFunc);
1201-
1202- // TODO(cir): registration for global variables.
1205+ buildCUDARegisterVars (builder, regGlobalFunc);
12031206
12041207 ReturnOp::create (builder, loc);
12051208 return regGlobalFunc;
@@ -1273,6 +1276,83 @@ void LoweringPreparePass::buildCUDARegisterGlobalFunctions(
12731276 }
12741277}
12751278
1279+ void LoweringPreparePass::buildCUDARegisterVars (cir::CIRBaseBuilderTy &builder,
1280+ FuncOp regGlobalFunc) {
1281+ auto loc = theModule.getLoc ();
1282+ auto cudaPrefix = getCUDAPrefix (astCtx);
1283+
1284+ auto voidTy = VoidType::get (&getContext ());
1285+ auto voidPtrTy = PointerType::get (voidTy);
1286+ auto voidPtrPtrTy = PointerType::get (voidPtrTy);
1287+ auto intTy = datalayout->getIntType (&getContext ());
1288+ auto charTy = datalayout->getCharType (&getContext ());
1289+ auto sizeTy = datalayout->getSizeType (&getContext ());
1290+
1291+ // Extract the GPU binary handle argument.
1292+ mlir::Value fatbinHandle = *regGlobalFunc.args_begin ();
1293+
1294+ cir::CIRBaseBuilderTy globalBuilder (getContext ());
1295+ globalBuilder.setInsertionPointToStart (theModule.getBody ());
1296+
1297+ // Declare CUDA internal function:
1298+ // void __cudaRegisterVar(
1299+ // void **fatbinHandle,
1300+ // char *hostVarName,
1301+ // char *deviceVarName,
1302+ // const char *deviceVarName,
1303+ // int isExtern, size_t varSize,
1304+ // int isConstant, int zero
1305+ // );
1306+ // Similar to the registration of global functions, OG does not care about
1307+ // pointer types. They will generate the same IR anyway.
1308+
1309+ FuncOp cudaRegisterVar = buildRuntimeFunction (
1310+ globalBuilder, addUnderscoredPrefix (cudaPrefix, " RegisterVar" ), loc,
1311+ FuncType::get ({voidPtrPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, intTy,
1312+ sizeTy, intTy, intTy},
1313+ voidTy));
1314+
1315+ unsigned int count = 0 ;
1316+ auto makeConstantString = [&](llvm::StringRef str) -> GlobalOp {
1317+ auto strType = ArrayType::get (&getContext (), charTy, 1 + str.size ());
1318+
1319+ auto tmpString = GlobalOp::create (
1320+ globalBuilder, loc, (" .str" + str + std::to_string (count++)).str (),
1321+ strType, /* isConstant=*/ true ,
1322+ /* linkage=*/ cir::GlobalLinkageKind::PrivateLinkage);
1323+
1324+ // We must make the string zero-terminated.
1325+ tmpString.setInitialValueAttr (ConstArrayAttr::get (
1326+ strType, StringAttr::get (&getContext (), str + " \0 " )));
1327+ tmpString.setPrivate ();
1328+ return tmpString;
1329+ };
1330+
1331+ for (auto &[deviceSideName, global] : cudaVarMap) {
1332+ GlobalOp deviceNameStr = makeConstantString (deviceSideName);
1333+ mlir::Value deviceNameValue = builder.createBitcast (
1334+ builder.createGetGlobal (deviceNameStr), voidPtrTy);
1335+
1336+ GlobalOp hostNameStr = makeConstantString (global.getName ());
1337+ mlir::Value hostNameValue =
1338+ builder.createBitcast (builder.createGetGlobal (hostNameStr), voidPtrTy);
1339+
1340+ // Every device variable that has a shadow on host will not be extern.
1341+ // See CIRGenModule::emitGlobalVarDefinition.
1342+ auto isExtern = ConstantOp::create (builder, loc, IntAttr::get (intTy, 0 ));
1343+ llvm::TypeSize size = datalayout->getTypeSizeInBits (global.getSymType ());
1344+ auto varSize = ConstantOp::create (
1345+ builder, loc, IntAttr::get (sizeTy, size.getFixedValue () / 8 ));
1346+ auto isConstant = ConstantOp::create (
1347+ builder, loc, IntAttr::get (intTy, global.getConstant ()));
1348+ auto zero = ConstantOp::create (builder, loc, IntAttr::get (intTy, 0 ));
1349+ builder.createCallOp (loc, cudaRegisterVar,
1350+ {fatbinHandle, hostNameValue, deviceNameValue,
1351+ deviceNameValue, isExtern, varSize, isConstant,
1352+ zero});
1353+ }
1354+ }
1355+
12761356std::optional<FuncOp> LoweringPreparePass::buildCUDAModuleDtor () {
12771357 if (!theModule->getAttr (CIRDialect::getCUDABinaryHandleAttrName ()))
12781358 return {};
@@ -1585,8 +1665,13 @@ void LoweringPreparePass::runOnOp(Operation *op) {
15851665 lowerVAArgOp (vaArgOp);
15861666 } else if (auto deleteArrayOp = dyn_cast<DeleteArrayOp>(op)) {
15871667 lowerDeleteArrayOp (deleteArrayOp);
1588- } else if (auto getGlobal = dyn_cast<GlobalOp>(op)) {
1589- lowerGlobalOp (getGlobal);
1668+ } else if (auto global = dyn_cast<GlobalOp>(op)) {
1669+ lowerGlobalOp (global);
1670+ if (auto attr = op->getAttr (cir::CUDAShadowNameAttr::getMnemonic ())) {
1671+ auto shadowNameAttr = dyn_cast<CUDAShadowNameAttr>(attr);
1672+ std::string deviceSideName = shadowNameAttr.getDeviceSideName ();
1673+ cudaVarMap[deviceSideName] = global;
1674+ }
15901675 } else if (auto dynamicCast = dyn_cast<DynamicCastOp>(op)) {
15911676 lowerDynamicCastOp (dynamicCast);
15921677 } else if (auto stdFind = dyn_cast<StdFindOp>(op)) {
0 commit comments