Skip to content

Commit 4c301eb

Browse files
authored
Improve global metadata handling wrt llvm.used (rust-lang#939)
* Improve global metadata handling wrt llvm.used * Move alloc to presllvm
1 parent 9013de3 commit 4c301eb

File tree

3 files changed

+249
-216
lines changed

3 files changed

+249
-216
lines changed

enzyme/Enzyme/Clang/EnzymeClang.cpp

+29-12
Original file line numberDiff line numberDiff line change
@@ -58,23 +58,40 @@ class EnzymePlugin final : public clang::ASTConsumer {
5858

5959
// Forcibly require emission of all libdevice
6060
for (it = dg.begin(); it != dg.end(); ++it) {
61-
auto FD = dyn_cast<FunctionDecl>(*it);
62-
if (!FD)
63-
continue;
61+
if (auto FD = dyn_cast<FunctionDecl>(*it)) {
62+
if (!FD->hasAttr<clang::CUDADeviceAttr>())
63+
continue;
6464

65-
if (!FD->hasAttr<clang::CUDADeviceAttr>())
66-
continue;
65+
if (!FD->getIdentifier())
66+
continue;
67+
if (!StringRef(FD->getLocation().printToString(CI.getSourceManager()))
68+
.contains("/__clang_cuda_math.h"))
69+
continue;
6770

68-
if (!FD->getIdentifier())
69-
continue;
70-
if (!StringRef(FD->getLocation().printToString(CI.getSourceManager()))
71-
.contains("/__clang_cuda_math.h"))
72-
continue;
73-
74-
FD->addAttr(UsedAttr::CreateImplicit(CI.getASTContext()));
71+
FD->addAttr(UsedAttr::CreateImplicit(CI.getASTContext()));
72+
}
73+
if (auto FD = dyn_cast<VarDecl>(*it)) {
74+
HandleCXXStaticMemberVarInstantiation(FD);
75+
}
7576
}
7677
return true;
7778
}
79+
void HandleCXXStaticMemberVarInstantiation(clang::VarDecl *V) override {
80+
if (!V->getIdentifier())
81+
return;
82+
auto name = V->getName();
83+
if (!(name.contains("__enzyme_inactive_global") ||
84+
name.contains("__enzyme_inactivefn") ||
85+
name.contains("__enzyme_function_like") ||
86+
name.contains("__enzyme_allocation_like") ||
87+
name.contains("__enzyme_register_gradient") ||
88+
name.contains("__enzyme_register_derivative") ||
89+
name.contains("__enzyme_register_splitderivative")))
90+
return;
91+
92+
V->addAttr(clang::UsedAttr::CreateImplicit(CI.getASTContext()));
93+
return;
94+
}
7895
};
7996

8097
// register the PluginASTAction in the registry.

enzyme/Enzyme/Enzyme.cpp

-203
Original file line numberDiff line numberDiff line change
@@ -313,203 +313,6 @@ handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g,
313313
globalsToErase.push_back(&g);
314314
}
315315

316-
static void
317-
handleInactiveFunction(llvm::Module &M, llvm::GlobalVariable &g,
318-
SmallVectorImpl<GlobalVariable *> &globalsToErase) {
319-
if (g.hasInitializer()) {
320-
Value *V = g.getInitializer();
321-
while (auto CE = dyn_cast<ConstantExpr>(V)) {
322-
V = CE->getOperand(0);
323-
}
324-
if (auto CA = dyn_cast<ConstantAggregate>(V))
325-
V = CA->getOperand(0);
326-
while (auto CE = dyn_cast<ConstantExpr>(V)) {
327-
V = CE->getOperand(0);
328-
}
329-
if (auto F = dyn_cast<Function>(V)) {
330-
F->addAttribute(AttributeList::FunctionIndex,
331-
Attribute::get(g.getContext(), "enzyme_inactive"));
332-
} else {
333-
llvm::errs() << M << "\n";
334-
llvm::errs() << "Param of __enzyme_inactivefn must be a "
335-
"function"
336-
<< g << "\n"
337-
<< *V << "\n";
338-
llvm_unreachable("__enzyme_inactivefn");
339-
}
340-
} else {
341-
llvm::errs() << M << "\n";
342-
llvm::errs() << "Use of __enzyme_inactivefn must be a "
343-
"constant function "
344-
<< g << "\n";
345-
llvm_unreachable("__enzyme_register_gradient");
346-
}
347-
globalsToErase.push_back(&g);
348-
}
349-
350-
static void
351-
handleFunctionLike(llvm::Module &M, llvm::GlobalVariable &g,
352-
SmallVectorImpl<GlobalVariable *> &globalsToErase) {
353-
if (g.hasInitializer()) {
354-
if (auto CA = dyn_cast<ConstantAggregate>(g.getInitializer())) {
355-
if (CA->getNumOperands() < 2) {
356-
llvm::errs() << M << "\n";
357-
llvm::errs() << "Use of "
358-
<< "enzyme_function_like"
359-
<< " must be a "
360-
"constant of size at least "
361-
<< 2 << " " << g << "\n";
362-
llvm_unreachable("enzyme_function_like");
363-
}
364-
Value *V = CA->getOperand(0);
365-
Value *name = CA->getOperand(1);
366-
while (auto CE = dyn_cast<ConstantExpr>(V)) {
367-
V = CE->getOperand(0);
368-
}
369-
while (auto CE = dyn_cast<ConstantExpr>(name)) {
370-
name = CE->getOperand(0);
371-
}
372-
StringRef nameVal;
373-
if (auto GV = dyn_cast<GlobalVariable>(name))
374-
if (GV->isConstant())
375-
if (auto C = GV->getInitializer())
376-
if (auto CA = dyn_cast<ConstantDataArray>(C))
377-
if (CA->getType()->getElementType()->isIntegerTy(8) &&
378-
CA->isCString())
379-
nameVal = CA->getAsCString();
380-
381-
if (nameVal == "") {
382-
llvm::errs() << *name << "\n";
383-
llvm::errs() << "Use of "
384-
<< "enzyme_function_like"
385-
<< "requires a non-empty function name"
386-
<< "\n";
387-
llvm_unreachable("enzyme_function_like");
388-
}
389-
if (auto F = dyn_cast<Function>(V)) {
390-
F->addAttribute(AttributeList::FunctionIndex,
391-
Attribute::get(g.getContext(), "enzyme_math", nameVal));
392-
} else {
393-
llvm::errs() << M << "\n";
394-
llvm::errs() << "Param of __enzyme_function_like must be a "
395-
"function"
396-
<< g << "\n"
397-
<< *V << "\n";
398-
llvm_unreachable("__enzyme_inactivefn");
399-
}
400-
} else {
401-
llvm::errs() << M << "\n";
402-
llvm::errs() << "Use of __enzyme_function_like must be a "
403-
"constant function "
404-
<< g << "\n";
405-
llvm_unreachable("__enzyme_register_gradient");
406-
}
407-
globalsToErase.push_back(&g);
408-
}
409-
}
410-
411-
static void
412-
handleAllocationLike(llvm::Module &M, llvm::GlobalVariable &g,
413-
SmallVectorImpl<GlobalVariable *> &globalsToErase) {
414-
if (g.hasInitializer()) {
415-
if (auto CA = dyn_cast<ConstantAggregate>(g.getInitializer())) {
416-
if (CA->getNumOperands() != 4) {
417-
llvm::errs() << M << "\n";
418-
llvm::errs() << "Use of "
419-
<< "enzyme_allocation_like"
420-
<< " must be a "
421-
"constant of size at least "
422-
<< 4 << " " << g << "\n";
423-
llvm_unreachable("enzyme_allocation_like");
424-
}
425-
Value *V = CA->getOperand(0);
426-
Value *name = CA->getOperand(1);
427-
while (auto CE = dyn_cast<ConstantExpr>(V)) {
428-
V = CE->getOperand(0);
429-
}
430-
while (auto CE = dyn_cast<ConstantExpr>(name)) {
431-
name = CE->getOperand(0);
432-
}
433-
Value *deallocind = CA->getOperand(2);
434-
while (auto CE = dyn_cast<ConstantExpr>(deallocind)) {
435-
deallocind = CE->getOperand(0);
436-
}
437-
Value *deallocfn = CA->getOperand(3);
438-
while (auto CE = dyn_cast<ConstantExpr>(deallocfn)) {
439-
deallocfn = CE->getOperand(0);
440-
}
441-
size_t index = 0;
442-
if (auto CI = dyn_cast<ConstantInt>(name)) {
443-
index = CI->getZExtValue();
444-
} else {
445-
llvm::errs() << *name << "\n";
446-
llvm::errs() << "Use of "
447-
<< "enzyme_allocation_like"
448-
<< "requires an integer index"
449-
<< "\n";
450-
llvm_unreachable("enzyme_allocation_like");
451-
}
452-
453-
StringRef deallocIndStr;
454-
bool foundInd = false;
455-
if (auto GV = dyn_cast<GlobalVariable>(deallocind))
456-
if (GV->isConstant())
457-
if (auto C = GV->getInitializer())
458-
if (auto CA = dyn_cast<ConstantDataArray>(C))
459-
if (CA->getType()->getElementType()->isIntegerTy(8) &&
460-
CA->isCString()) {
461-
deallocIndStr = CA->getAsCString();
462-
foundInd = true;
463-
}
464-
465-
if (!foundInd) {
466-
llvm::errs() << *deallocind << "\n";
467-
llvm::errs() << "Use of "
468-
<< "enzyme_allocation_like"
469-
<< "requires a deallocation index string"
470-
<< "\n";
471-
llvm_unreachable("enzyme_allocation_like");
472-
}
473-
if (auto F = dyn_cast<Function>(V)) {
474-
F->addAttribute(AttributeList::FunctionIndex,
475-
Attribute::get(g.getContext(), "enzyme_allocator",
476-
std::to_string(index)));
477-
} else {
478-
llvm::errs() << M << "\n";
479-
llvm::errs() << "Param of __enzyme_allocation_like must be a "
480-
"function"
481-
<< g << "\n"
482-
<< *V << "\n";
483-
llvm_unreachable("__enzyme_allocation_like");
484-
}
485-
cast<Function>(V)->addAttribute(
486-
AttributeList::FunctionIndex,
487-
Attribute::get(g.getContext(), "enzyme_deallocator", deallocIndStr));
488-
489-
if (auto F = dyn_cast<Function>(deallocfn)) {
490-
cast<Function>(V)->setMetadata(
491-
"enzyme_deallocator_fn",
492-
llvm::MDTuple::get(F->getContext(),
493-
{llvm::ValueAsMetadata::get(F)}));
494-
} else {
495-
llvm::errs() << M << "\n";
496-
llvm::errs() << "Free fn of __enzyme_allocation_like must be a "
497-
"function"
498-
<< g << "\n"
499-
<< *deallocfn << "\n";
500-
llvm_unreachable("__enzyme_allocation_like");
501-
}
502-
} else {
503-
llvm::errs() << M << "\n";
504-
llvm::errs() << "Use of __enzyme_allocation_like must be a "
505-
"constant function "
506-
<< g << "\n";
507-
llvm_unreachable("__enzyme_allocation_like");
508-
}
509-
globalsToErase.push_back(&g);
510-
}
511-
}
512-
513316
static void handleKnownFunctions(llvm::Function &F) {
514317
if (F.getName() == "memcmp") {
515318
F.addFnAttr(Attribute::ReadOnly);
@@ -2550,12 +2353,6 @@ class Enzyme final : public ModulePass {
25502353
handleCustomDerivative<splitderivative_handler_name,
25512354
DerivativeMode::ForwardModeSplit, 3>(
25522355
M, g, globalsToErase);
2553-
} else if (g.getName().contains("__enzyme_inactivefn")) {
2554-
handleInactiveFunction(M, g, globalsToErase);
2555-
} else if (g.getName().contains("__enzyme_function_like")) {
2556-
handleFunctionLike(M, g, globalsToErase);
2557-
} else if (g.getName().contains("__enzyme_allocation_like")) {
2558-
handleAllocationLike(M, g, globalsToErase);
25592356
}
25602357
}
25612358
for (auto g : globalsToErase) {

0 commit comments

Comments
 (0)