Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,9 +535,16 @@ static CIRGenCallee emitDirectCallee(CIRGenModule &CGM, GlobalDecl GD) {

mlir::Operation *CalleePtr = emitFunctionDeclPointer(CGM, GD);

if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
FD->hasAttr<CUDAGlobalAttr>())
CalleePtr = CGM.getCUDARuntime().getKernelStub(CalleePtr);
if ((CGM.getLangOpts().HIP || CGM.getLangOpts().CUDA) &&
!CGM.getLangOpts().CUDAIsDevice && FD->hasAttr<CUDAGlobalAttr>()) {

// Ensure the handle is created and use it as the lookup key.
auto *Handle = CGM.getCUDARuntime().getKernelHandle(
llvm::cast<cir::FuncOp>(CalleePtr), GD);

// Now look up the stub via the handle
CalleePtr = CGM.getCUDARuntime().getKernelStub(Handle);
}

return CIRGenCallee::forDirect(CalleePtr, GD);
}
Expand Down Expand Up @@ -1560,8 +1567,6 @@ RValue CIRGenFunction::emitCall(clang::QualType CalleeType,
Callee.setFunctionPointer(Fn);
}

assert(!CGM.getLangOpts().HIP && "HIP NYI");

assert(!MustTailCall && "Must tail NYI");
cir::CIRCallOpInterface callOP;
RValue Call = emitCall(FnInfo, Callee, ReturnValue, Args, &callOP,
Expand Down
13 changes: 1 addition & 12 deletions clang/lib/CIR/CodeGen/CIRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2548,20 +2548,9 @@ cir::FuncOp CIRGenModule::GetAddrOfFunction(clang::GlobalDecl gd, mlir::Type ty,
auto f = GetOrCreateCIRFunction(mangledName, ty, gd, forVTable, dontDefer,
/*IsThunk=*/false, isForDefinition);

// As __global__ functions (kernels) always reside on device,
// when we access them from host, we must refer to the kernel handle.
// For HIP, we should never directly access the host device addr, but
// instead the Global Variable of that stub. For CUDA, it's just the device
// stub. For HIP, it's something different.
if ((langOpts.HIP || langOpts.CUDA) && !langOpts.CUDAIsDevice &&
cast<FunctionDecl>(gd.getDecl())->hasAttr<CUDAGlobalAttr>()) {
cast<FunctionDecl>(gd.getDecl())->hasAttr<CUDAGlobalAttr>())
(void)getCUDARuntime().getKernelHandle(f, gd);
if (isForDefinition)
return f;

if (langOpts.HIP)
llvm_unreachable("NYI");
}

return f;
}
Expand Down
18 changes: 18 additions & 0 deletions clang/test/CIR/CodeGen/HIP/simple.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,21 @@ __global__ void global_fn(int a) {}
// The stub has the mangled name of the function
// CIR-HOST: cir.get_global @_Z9global_fni
// CIR-HOST: cir.call @hipLaunchKernel

int main() {
global_fn<<<1, 1>>>(1);
}
// CIR-DEVICE-NOT: cir.func dso_local @main()

// CIR-HOST: cir.func dso_local @main()
// CIR-HOST: cir.call @_ZN4dim3C1Ejjj
// CIR-HOST: cir.call @_ZN4dim3C1Ejjj
// CIR-HOST: [[Push:%[0-9]+]] = cir.call @__hipPushCallConfiguration
// CIR-HOST: [[ConfigOK:%[0-9]+]] = cir.cast int_to_bool [[Push]]
// CIR-HOST: cir.if [[ConfigOK]] {
// CIR-HOST: } else {
// CIR-HOST: [[Arg:%[0-9]+]] = cir.const #cir.int<1>
// CIR-HOST: cir.call @_Z24__device_stub__global_fni([[Arg]])
// CIR-HOST: }


Loading