Skip to content

[clang-repl] : Fix clang-repl crash with --cuda flag #136404

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions clang/include/clang/Interpreter/Interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class CXXRecordDecl;
class Decl;
class IncrementalExecutor;
class IncrementalParser;
class IncrementalCUDADeviceParser;

/// Create a pre-configured \c CompilerInstance for incremental processing.
class IncrementalCompilerBuilder {
Expand Down Expand Up @@ -93,7 +94,10 @@ class Interpreter {
std::unique_ptr<IncrementalExecutor> IncrExecutor;

// An optional parser for CUDA offloading
std::unique_ptr<IncrementalParser> DeviceParser;
std::unique_ptr<IncrementalCUDADeviceParser> DeviceParser;

// An optional action for CUDA offloading
std::unique_ptr<IncrementalAction> DeviceAct;

/// List containing information about each incrementally parsed piece of code.
std::list<PartialTranslationUnit> PTUs;
Expand Down Expand Up @@ -175,10 +179,11 @@ class Interpreter {
llvm::Expected<Expr *> ExtractValueFromExpr(Expr *E);
llvm::Expected<llvm::orc::ExecutorAddr> CompileDtorCall(CXXRecordDecl *CXXRD);

CodeGenerator *getCodeGen() const;
std::unique_ptr<llvm::Module> GenModule();
CodeGenerator *getCodeGen(IncrementalAction *Action = nullptr) const;
std::unique_ptr<llvm::Module> GenModule(IncrementalAction *Action = nullptr);
PartialTranslationUnit &RegisterPTU(TranslationUnitDecl *TU,
std::unique_ptr<llvm::Module> M = {});
std::unique_ptr<llvm::Module> M = {},
IncrementalAction *Action = nullptr);

// A cache for the compiled destructors used to for de-allocation of managed
// clang::Values.
Expand Down
45 changes: 15 additions & 30 deletions clang/lib/Interpreter/DeviceOffload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,45 +31,17 @@ IncrementalCUDADeviceParser::IncrementalCUDADeviceParser(
llvm::Error &Err, const std::list<PartialTranslationUnit> &PTUs)
: IncrementalParser(*DeviceInstance, Err), PTUs(PTUs), VFS(FS),
CodeGenOpts(HostInstance.getCodeGenOpts()),
TargetOpts(HostInstance.getTargetOpts()) {
TargetOpts(DeviceInstance->getTargetOpts()) {
if (Err)
return;
DeviceCI = std::move(DeviceInstance);
StringRef Arch = TargetOpts.CPU;
if (!Arch.starts_with("sm_") || Arch.substr(3).getAsInteger(10, SMVersion)) {
Err = llvm::joinErrors(std::move(Err), llvm::make_error<llvm::StringError>(
"Invalid CUDA architecture",
llvm::inconvertibleErrorCode()));
return;
}
}

llvm::Expected<TranslationUnitDecl *>
IncrementalCUDADeviceParser::Parse(llvm::StringRef Input) {
auto PTU = IncrementalParser::Parse(Input);
if (!PTU)
return PTU.takeError();

auto PTX = GeneratePTX();
if (!PTX)
return PTX.takeError();

auto Err = GenerateFatbinary();
if (Err)
return std::move(Err);

std::string FatbinFileName =
"/incr_module_" + std::to_string(PTUs.size()) + ".fatbin";
VFS->addFile(FatbinFileName, 0,
llvm::MemoryBuffer::getMemBuffer(
llvm::StringRef(FatbinContent.data(), FatbinContent.size()),
"", false));

CodeGenOpts.CudaGpuBinaryFileName = FatbinFileName;

FatbinContent.clear();

return PTU;
DeviceCI = std::move(DeviceInstance);
}

llvm::Expected<llvm::StringRef> IncrementalCUDADeviceParser::GeneratePTX() {
Expand Down Expand Up @@ -172,6 +144,19 @@ llvm::Error IncrementalCUDADeviceParser::GenerateFatbinary() {

FatbinContent.append(PTXCode.begin(), PTXCode.end());

const PartialTranslationUnit &PTU = PTUs.back();

std::string FatbinFileName = "/" + PTU.TheModule->getName().str() + ".fatbin";

VFS->addFile(FatbinFileName, 0,
llvm::MemoryBuffer::getMemBuffer(
llvm::StringRef(FatbinContent.data(), FatbinContent.size()),
"", false));

CodeGenOpts.CudaGpuBinaryFileName = FatbinFileName;

FatbinContent.clear();

return llvm::Error::success();
}

Expand Down
2 changes: 0 additions & 2 deletions clang/lib/Interpreter/DeviceOffload.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ class IncrementalCUDADeviceParser : public IncrementalParser {
llvm::IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> VFS,
llvm::Error &Err, const std::list<PartialTranslationUnit> &PTUs);

llvm::Expected<TranslationUnitDecl *> Parse(llvm::StringRef Input) override;

// Generate PTX for the last PTU.
llvm::Expected<llvm::StringRef> GeneratePTX();

Expand Down
59 changes: 44 additions & 15 deletions clang/lib/Interpreter/Interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,20 +481,34 @@ Interpreter::createWithCUDA(std::unique_ptr<CompilerInstance> CI,
OverlayVFS->pushOverlay(IMVFS);
CI->createFileManager(OverlayVFS);

auto Interp = Interpreter::create(std::move(CI));
if (auto E = Interp.takeError())
return std::move(E);
llvm::Expected<std::unique_ptr<Interpreter>> InterpOrErr =
Interpreter::create(std::move(CI));
if (!InterpOrErr)
return InterpOrErr;

std::unique_ptr<Interpreter> Interp = std::move(*InterpOrErr);

llvm::Error Err = llvm::Error::success();
auto DeviceParser = std::make_unique<IncrementalCUDADeviceParser>(
std::move(DCI), *(*Interp)->getCompilerInstance(), IMVFS, Err,
(*Interp)->PTUs);
llvm::LLVMContext &LLVMCtx = *Interp->TSCtx->getContext();

auto DeviceAct =
std::make_unique<IncrementalAction>(*DCI, LLVMCtx, Err, *Interp);

if (Err)
return std::move(Err);

(*Interp)->DeviceParser = std::move(DeviceParser);
Interp->DeviceAct = std::move(DeviceAct);

DCI->ExecuteAction(*Interp->DeviceAct);

auto DeviceParser = std::make_unique<IncrementalCUDADeviceParser>(
std::move(DCI), *Interp->getCompilerInstance(), IMVFS, Err, Interp->PTUs);

if (Err)
return std::move(Err);

return Interp;
Interp->DeviceParser = std::move(DeviceParser);
return std::move(Interp);
}

const CompilerInstance *Interpreter::getCompilerInstance() const {
Expand Down Expand Up @@ -532,15 +546,17 @@ size_t Interpreter::getEffectivePTUSize() const {

PartialTranslationUnit &
Interpreter::RegisterPTU(TranslationUnitDecl *TU,
std::unique_ptr<llvm::Module> M /*={}*/) {
std::unique_ptr<llvm::Module> M /*={}*/,
IncrementalAction *Action) {
PTUs.emplace_back(PartialTranslationUnit());
PartialTranslationUnit &LastPTU = PTUs.back();
LastPTU.TUPart = TU;

if (!M)
M = GenModule();
M = GenModule(Action);

assert((!getCodeGen() || M) && "Must have a llvm::Module at this point");
assert((!getCodeGen(Action) || M) &&
"Must have a llvm::Module at this point");

LastPTU.TheModule = std::move(M);
LLVM_DEBUG(llvm::dbgs() << "compile-ptu " << PTUs.size() - 1
Expand All @@ -560,6 +576,16 @@ Interpreter::Parse(llvm::StringRef Code) {
llvm::Expected<TranslationUnitDecl *> DeviceTU = DeviceParser->Parse(Code);
if (auto E = DeviceTU.takeError())
return std::move(E);

RegisterPTU(*DeviceTU, nullptr, DeviceAct.get());

llvm::Expected<llvm::StringRef> PTX = DeviceParser->GeneratePTX();
if (!PTX)
return PTX.takeError();

llvm::Error Err = DeviceParser->GenerateFatbinary();
if (Err)
return std::move(Err);
}

// Tell the interpreter sliently ignore unused expressions since value
Expand Down Expand Up @@ -736,9 +762,10 @@ llvm::Error Interpreter::LoadDynamicLibrary(const char *name) {
return llvm::Error::success();
}

std::unique_ptr<llvm::Module> Interpreter::GenModule() {
std::unique_ptr<llvm::Module>
Interpreter::GenModule(IncrementalAction *Action) {
static unsigned ID = 0;
if (CodeGenerator *CG = getCodeGen()) {
if (CodeGenerator *CG = getCodeGen(Action)) {
// Clang's CodeGen is designed to work with a single llvm::Module. In many
// cases for convenience various CodeGen parts have a reference to the
// llvm::Module (TheModule or Module) which does not change when a new
Expand All @@ -760,8 +787,10 @@ std::unique_ptr<llvm::Module> Interpreter::GenModule() {
return nullptr;
}

CodeGenerator *Interpreter::getCodeGen() const {
FrontendAction *WrappedAct = Act->getWrapped();
CodeGenerator *Interpreter::getCodeGen(IncrementalAction *Action) const {
if (!Action)
Action = Act.get();
FrontendAction *WrappedAct = Action->getWrapped();
if (!WrappedAct->hasIRSupport())
return nullptr;
return static_cast<CodeGenAction *>(WrappedAct)->getCodeGenerator();
Expand Down
Loading