Summary of this patch (leading text; ignored by patch tools):
  * emitDirectCallee: on the host side, for both CUDA and HIP, obtain the
    kernel handle via getKernelHandle() and pass that handle to
    getKernelStub() instead of passing the raw callee operation.
  * emitCall: drop the "HIP NYI" assertion.
  * GetAddrOfFunction: keep the getKernelHandle() call for kernels but remove
    the early return and the HIP llvm_unreachable("NYI").
  * HIP/simple.cpp: add a main() that launches global_fn and check the host
    CIR emitted for the launch.

diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
index 62cd9d48d399..3e66ccee4ebe 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
@@ -535,9 +535,16 @@ static CIRGenCallee emitDirectCallee(CIRGenModule &CGM, GlobalDecl GD) {
 
   mlir::Operation *CalleePtr = emitFunctionDeclPointer(CGM, GD);
 
-  if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
-      FD->hasAttr<CUDAGlobalAttr>())
-    CalleePtr = CGM.getCUDARuntime().getKernelStub(CalleePtr);
+  if ((CGM.getLangOpts().HIP || CGM.getLangOpts().CUDA) &&
+      !CGM.getLangOpts().CUDAIsDevice && FD->hasAttr<CUDAGlobalAttr>()) {
+
+    // Ensure the handle is created and use it as the lookup key.
+    auto *Handle = CGM.getCUDARuntime().getKernelHandle(
+        llvm::cast<cir::FuncOp>(CalleePtr), GD);
+
+    // Now look up the stub via the handle.
+    CalleePtr = CGM.getCUDARuntime().getKernelStub(Handle);
+  }
 
   return CIRGenCallee::forDirect(CalleePtr, GD);
 }
@@ -1560,8 +1567,6 @@ RValue CIRGenFunction::emitCall(clang::QualType CalleeType,
     Callee.setFunctionPointer(Fn);
   }
 
-  assert(!CGM.getLangOpts().HIP && "HIP NYI");
-
   assert(!MustTailCall && "Must tail NYI");
   cir::CIRCallOpInterface callOP;
   RValue Call = emitCall(FnInfo, Callee, ReturnValue, Args, &callOP,
diff --git a/clang/lib/CIR/CodeGen/CIRGenModule.cpp b/clang/lib/CIR/CodeGen/CIRGenModule.cpp
index d80f8b70964e..362ec7828265 100644
--- a/clang/lib/CIR/CodeGen/CIRGenModule.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenModule.cpp
@@ -2548,20 +2548,9 @@ cir::FuncOp CIRGenModule::GetAddrOfFunction(clang::GlobalDecl gd, mlir::Type ty,
   auto f = GetOrCreateCIRFunction(mangledName, ty, gd, forVTable, dontDefer,
                                   /*IsThunk=*/false, isForDefinition);
 
-  // As __global__ functions (kernels) always reside on device,
-  // when we access them from host, we must refer to the kernel handle.
-  // For HIP, we should never directly access the host device addr, but
-  // instead the Global Variable of that stub. For CUDA, it's just the device
-  // stub. For HIP, it's something different.
   if ((langOpts.HIP || langOpts.CUDA) && !langOpts.CUDAIsDevice &&
-      cast<FunctionDecl>(gd.getDecl())->hasAttr<CUDAGlobalAttr>()) {
+      cast<FunctionDecl>(gd.getDecl())->hasAttr<CUDAGlobalAttr>())
     (void)getCUDARuntime().getKernelHandle(f, gd);
-    if (isForDefinition)
-      return f;
-
-    if (langOpts.HIP)
-      llvm_unreachable("NYI");
-  }
   return f;
 }
 
diff --git a/clang/test/CIR/CodeGen/HIP/simple.cpp b/clang/test/CIR/CodeGen/HIP/simple.cpp
index 9c00b149e1c7..27e6e5a4cd67 100644
--- a/clang/test/CIR/CodeGen/HIP/simple.cpp
+++ b/clang/test/CIR/CodeGen/HIP/simple.cpp
@@ -32,3 +32,21 @@ __global__ void global_fn(int a) {}
 // The stub has the mangled name of the function
 // CIR-HOST: cir.get_global @_Z9global_fni
 // CIR-HOST: cir.call @hipLaunchKernel
+
+int main() {
+  global_fn<<<1, 1>>>(1);
+}
+// CIR-DEVICE-NOT: cir.func dso_local @main()
+
+// CIR-HOST: cir.func dso_local @main()
+// CIR-HOST: cir.call @_ZN4dim3C1Ejjj
+// CIR-HOST: cir.call @_ZN4dim3C1Ejjj
+// CIR-HOST: [[Push:%[0-9]+]] = cir.call @__hipPushCallConfiguration
+// CIR-HOST: [[ConfigOK:%[0-9]+]] = cir.cast int_to_bool [[Push]]
+// CIR-HOST: cir.if [[ConfigOK]] {
+// CIR-HOST: } else {
+// CIR-HOST: [[Arg:%[0-9]+]] = cir.const #cir.int<1>
+// CIR-HOST: cir.call @_Z24__device_stub__global_fni([[Arg]])
+// CIR-HOST: }
+
+