Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 48 additions & 15 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6666,6 +6666,34 @@ class FreeFunctionPrinter {
FD->getTemplateSpecializationArgs());
}

/// Emits free function kernel info specialization for shimN.
/// \param ShimCounter The counter for the shim function.
/// \param KParamsSize The number of kernel free function arguments.
/// \param KName The name of the kernel free function.
void printFreeFunctionKernelInfo(const unsigned ShimCounter,
const size_t KParamsSize,
std::string_view KName) {
O << "\n";
O << "namespace sycl {\n";
O << "inline namespace _V1 {\n";
O << "namespace detail {\n";
O << "//Free Function Kernel info specialization for shim" << ShimCounter
<< "\n";
O << "template <> struct FreeFunctionInfoData<__sycl_shim" << ShimCounter
<< "()> {\n";
O << " __SYCL_DLL_LOCAL\n";
O << " static constexpr unsigned getNumParams() { return " << KParamsSize
<< "; }\n";
O << " __SYCL_DLL_LOCAL\n";
O << " static constexpr const char *getFunctionName() { return ";
O << "\"" << KName << "\"; }\n";
O << "};\n";
O << "} // namespace detail\n"
<< "} // namespace _V1\n"
<< "} // namespace sycl\n";
O << "\n";
}

private:
/// Helper method to get string with template types
/// \param TAL The template argument list.
Expand Down Expand Up @@ -6915,6 +6943,11 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
O << " \"\",\n";
O << "};\n\n";

O << "static constexpr unsigned kernel_args_sizes[] = {";
for (unsigned I = 0; I < KernelDescs.size(); I++) {
O << KernelDescs[I].Params.size() << ", ";
}
O << "};\n\n";
O << "// array representing signatures of all kernels defined in the\n";
O << "// corresponding source\n";
O << "static constexpr\n";
Expand Down Expand Up @@ -7127,6 +7160,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
FFPrinter.printFreeFunctionShim(K.SyclKernel, ShimCounter, ParmList);
O << ";\n";
O << "}\n";
FFPrinter.printFreeFunctionKernelInfo(ShimCounter, K.Params.size(), K.Name);
Policy.SuppressDefaultTemplateArgs = true;
Policy.EnforceDefaultTemplateArgs = false;

Expand Down Expand Up @@ -7156,22 +7190,21 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {

if (FreeFunctionCount > 0) {
O << "\n#include <sycl/kernel_bundle.hpp>\n";
}
ShimCounter = 1;
for (const KernelDesc &K : KernelDescs) {
if (!S.isFreeFunction(K.SyclKernel))
continue;

O << "\n// Definition of kernel_id of " << K.Name << "\n";
O << "#include <sycl/detail/kernel_global_info.hpp>\n";
O << "namespace sycl {\n";
O << "template <>\n";
O << "inline kernel_id ext::oneapi::experimental::get_kernel_id<__sycl_shim"
<< ShimCounter << "()>() {\n";
O << " return sycl::detail::get_kernel_id_impl(std::string_view{\""
<< K.Name << "\"});\n";
O << "}\n";
O << "}\n";
++ShimCounter;
O << "inline namespace _V1 {\n";
O << "namespace detail {\n";
O << "struct GlobalMapUpdater {\n";
O << " GlobalMapUpdater() {\n";
O << " sycl::detail::free_function_info_map::add("
<< "sycl::detail::kernel_names, sycl::detail::kernel_args_sizes, "
<< KernelDescs.size() << ");\n";
O << " }\n";
O << "};\n";
O << "static GlobalMapUpdater updater;\n";
O << "} // namespace detail\n";
O << "} // namespace _V1\n";
O << "} // namespace sycl\n";
}
}

Expand Down
Loading
Loading