Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6929,6 +6929,38 @@ class FreeFunctionPrinter {
/// \param ParmList The parameter list of the function.
void printFreeFunctionShim(const FunctionDecl *FD, const unsigned ShimCounter,
const std::string &ParmList) {
// If a free function kernel is a template, where a template parameter is a
// class template instantiation, then we need to forward declare this type
// before __sycl_shim.
if (FD->getPrimaryTemplate()) {
auto TAL = FD->getTemplateSpecializationArgs();
for (unsigned i = 0; i < TAL->size(); ++i) {
auto arg = TAL->get(i);
if (arg.getKind() != clang::TemplateArgument::Type)
continue;
clang::QualType QT = arg.getAsType().getCanonicalType();
auto RT = QT->getAs<clang::RecordType>();
if (!RT)
continue;
const clang::RecordDecl *RD = RT->getDecl();
if (RD && RD->isStruct()) {
const clang::TagType *TT = QT->getAs<clang::TagType>();
if (!TT) {
return;
};
const clang::TagDecl *TD = TT->getDecl()->getFirstDecl();
if (!TD) {
return;
}
if (auto *CTD = llvm::dyn_cast<clang::ClassTemplateSpecializationDecl>(TD))
if (auto *CTpl = CTD->getSpecializedTemplate())
CTpl->getTemplateParameters()->print(O, Context, false);

O << TD->getKindName() << " " << TD->getName() << ";\n";
}
}
}

// Generate a shim function that returns the address of the free function.
O << "static constexpr auto __sycl_shim" << ShimCounter << "() {\n";
O << " return (void (*)(" << ParmList << "))";
Expand Down
Loading