Skip to content

Commit 71e7fac

Browse files
committed
[SYCL] Support connection with multiple plugins
This commit enables including multiple devices of the same device_type and changed the logic of device selection to just prefer a SYCL_BE device if present. If someone uses SYCL_BE but appropriate device is not present, we will simply use any other device. Signed-off-by: Artur Gainullin <artur.gainullin@intel.com>
1 parent 8445ee8 commit 71e7fac

File tree

11 files changed

+231
-105
lines changed

11 files changed

+231
-105
lines changed

sycl/include/CL/sycl/detail/pi.hpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,19 @@ enum class PiApiKind {
4343
class plugin;
4444
namespace pi {
4545

46+
// The SYCL_PI_TRACE sets what we will trace.
47+
// This is a bit-mask of various things we'd want to trace.
48+
enum TraceLevel {
49+
PI_TRACE_BASIC = 0x1,
50+
PI_TRACE_CALLS = 0x2,
51+
PI_TRACE_ALL = -1
52+
};
53+
54+
// Return true if we want to trace PI related activities.
55+
bool trace(TraceLevel level);
56+
57+
const char *traceLabel();
58+
4659
#ifdef SYCL_RT_OS_WINDOWS
4760
#define OPENCL_PLUGIN_NAME "pi_opencl.dll"
4861
#define CUDA_PLUGIN_NAME "pi_cuda.dll"
@@ -115,8 +128,8 @@ void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName);
115128
// environment variable.
116129
enum Backend { SYCL_BE_PI_OPENCL, SYCL_BE_PI_CUDA, SYCL_BE_PI_OTHER };
117130

118-
// Check for manually selected BE at run-time.
119-
bool useBackend(Backend Backend);
131+
// Get the preferred BE (selected with SYCL_BE).
132+
Backend getPreferredBE();
120133

121134
// Get a string representing a _pi_platform_info enum
122135
std::string platformInfoToString(pi_platform_info info);

sycl/source/detail/pi.cpp

Lines changed: 107 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
#include <cstring>
2323
#include <iostream>
2424
#include <map>
25+
#include <sstream>
2526
#include <stddef.h>
2627
#include <string>
27-
#include <sstream>
2828

2929
#ifdef XPTI_ENABLE_INSTRUMENTATION
3030
// Include the headers necessary for emitting
@@ -141,39 +141,80 @@ std::string memFlagsToString(pi_mem_flags Flags) {
141141
return Sstream.str();
142142
}
143143

144-
// Check for manually selected BE at run-time.
145-
static Backend getBackend() {
146-
static const char *GetEnv = std::getenv("SYCL_BE");
147-
// Current default backend as SYCL_BE_PI_OPENCL
148-
// Valid values of GetEnv are "PI_OPENCL", "PI_CUDA" and "PI_OTHER"
149-
std::string StringGetEnv = (GetEnv ? GetEnv : "PI_OPENCL");
150-
static const Backend Use =
151-
std::map<std::string, Backend>{
152-
{ "PI_OPENCL", SYCL_BE_PI_OPENCL },
153-
{ "PI_CUDA", SYCL_BE_PI_CUDA },
154-
{ "PI_OTHER", SYCL_BE_PI_OTHER }
155-
}[ GetEnv ? StringGetEnv : "PI_OPENCL"];
156-
return Use;
144+
// A singleton class to aid that PI configuration parameters
145+
// are processed only once, like reading a string from environment
146+
// and converting it into a typed object.
147+
//
148+
template <typename T, const char *E> class Config {
149+
static Config *m_Instance;
150+
T m_Data;
151+
Config();
152+
153+
public:
154+
static T get() {
155+
if (!m_Instance) {
156+
m_Instance = new Config();
157+
}
158+
return m_Instance->m_Data;
159+
}
160+
};
161+
162+
template <typename T, const char *E>
163+
Config<T, E> *Config<T, E>::m_Instance = nullptr;
164+
165+
// Lists valid configuration environment variables.
166+
static constexpr char SYCL_BE[] = "SYCL_BE";
167+
static constexpr char SYCL_INTEROP_BE[] = "SYCL_INTEROP_BE";
168+
static constexpr char SYCL_PI_TRACE[] = "SYCL_PI_TRACE";
169+
170+
// SYCL_PI_TRACE gives the mask of enabled tracing components (0 default)
171+
template <> Config<int, SYCL_PI_TRACE>::Config() {
172+
const char *Env = std::getenv(SYCL_PI_TRACE);
173+
m_Data = (Env ? std::atoi(Env) : 0);
174+
}
175+
176+
static Backend getBE(const char *EnvVar) {
177+
const char *BE = std::getenv(EnvVar);
178+
const std::map<std::string, Backend> SyclBeMap{
179+
{"PI_OTHER", SYCL_BE_PI_OTHER},
180+
{"PI_CUDA", SYCL_BE_PI_CUDA},
181+
{"PI_OPENCL", SYCL_BE_PI_OPENCL}};
182+
if (BE) {
183+
auto It = SyclBeMap.find(BE);
184+
if (It == SyclBeMap.end())
185+
pi::die("Invalid backend. "
186+
"Valid values are PI_OPENCL/PI_CUDA");
187+
return It->second;
188+
}
189+
// Default backend
190+
return SYCL_BE_PI_OPENCL;
157191
}
158192

159-
// Check for manually selected BE at run-time.
160-
bool useBackend(Backend TheBackend) {
161-
return TheBackend == getBackend();
193+
template <> Config<Backend, SYCL_BE>::Config() { m_Data = getBE(SYCL_BE); }
194+
195+
// SYCL_INTEROP_BE is a way to specify the interoperability plugin.
196+
template <> Config<Backend, SYCL_INTEROP_BE>::Config() {
197+
m_Data = getBE(SYCL_INTEROP_BE);
162198
}
163199

200+
// Helper interface to not expose "pi::Config" outside of pi.cpp
201+
Backend getPreferredBE() { return Config<Backend, SYCL_BE>::get(); }
202+
164203
// GlobalPlugin is a global Plugin used with Interoperability constructors that
165204
// use OpenCL objects to construct SYCL class objects.
166205
std::shared_ptr<plugin> GlobalPlugin;
167206

168207
// Find the plugin at the appropriate location and return the location.
169-
// TODO: Change the function appropriately when there are multiple plugins.
170-
bool findPlugins(vector_class<std::string> &PluginNames) {
208+
bool findPlugins(vector_class<std::pair<std::string, Backend>> &PluginNames) {
171209
// TODO: Based on final design discussions, change the location where the
172210
// plugin must be searched; how to identify the plugins etc. Currently the
173211
// search is done for libpi_opencl.so/pi_opencl.dll file in LD_LIBRARY_PATH
174212
// env only.
175-
PluginNames.push_back(OPENCL_PLUGIN_NAME);
176-
PluginNames.push_back(CUDA_PLUGIN_NAME);
213+
//
214+
PluginNames.push_back(std::make_pair<std::string, Backend>(
215+
OPENCL_PLUGIN_NAME, SYCL_BE_PI_OPENCL));
216+
PluginNames.push_back(
217+
std::make_pair<std::string, Backend>(CUDA_PLUGIN_NAME, SYCL_BE_PI_CUDA));
177218
return true;
178219
}
179220

@@ -207,52 +248,66 @@ bool bindPlugin(void *Library, PiPlugin *PluginInformation) {
207248
return true;
208249
}
209250

210-
// Load the plugin based on SYCL_BE.
211-
// TODO: Currently only accepting OpenCL and CUDA plugins. Edit it to identify
212-
// and load other kinds of plugins, do the required changes in the
213-
// findPlugins, loadPlugin and bindPlugin functions.
214-
vector_class<plugin> initialize() {
215-
vector_class<plugin> Plugins;
251+
bool trace(TraceLevel Level) {
252+
auto TraceLevelMask = Config<int, SYCL_PI_TRACE>::get();
253+
return (TraceLevelMask & Level) == Level;
254+
}
216255

217-
if (!useBackend(SYCL_BE_PI_OPENCL) && !useBackend(SYCL_BE_PI_CUDA)) {
218-
die("Unknown SYCL_BE");
256+
const char *traceLabel() {
257+
auto TraceLevelMask = Config<int, SYCL_PI_TRACE>::get();
258+
switch (TraceLevelMask) {
259+
case PI_TRACE_BASIC:
260+
return "SYCL_PI_TRACE[PI_TRACE_BASIC]: ";
261+
case PI_TRACE_CALLS:
262+
return "SYCL_PI_TRACE[PI_TRACE_CALLS]: ";
263+
case PI_TRACE_ALL:
264+
return "SYCL_PI_TRACE[PI_TRACE_ALL]: ";
265+
default:
266+
assert("Unsupported trace level");
219267
}
268+
return nullptr;
269+
}
220270

221-
bool EnableTrace = (std::getenv("SYCL_PI_TRACE") != nullptr);
222-
223-
vector_class<std::string> PluginNames;
271+
// Initializes all available Plugins.
272+
vector_class<plugin> initialize() {
273+
vector_class<plugin> Plugins;
274+
vector_class<std::pair<std::string, Backend>> PluginNames;
224275
findPlugins(PluginNames);
225276

226-
if (PluginNames.empty() && EnableTrace)
227-
std::cerr << "No Plugins Found." << std::endl;
277+
if (PluginNames.empty() && trace(PI_TRACE_ALL))
278+
std::cerr << traceLabel() << "No Plugins Found." << std::endl;
228279

229-
PiPlugin PluginInformation; // TODO: include.
280+
PiPlugin PluginInformation;
230281
for (unsigned int I = 0; I < PluginNames.size(); I++) {
231-
void *Library = loadPlugin(PluginNames[I]);
282+
void *Library = loadPlugin(PluginNames[I].first);
232283

233284
if (!Library) {
234-
if (EnableTrace) {
235-
std::cerr << "Check if plugin is present. Failed to load plugin: "
236-
<< PluginNames[I] << std::endl;
285+
if (trace(PI_TRACE_ALL)) {
286+
std::cerr << traceLabel() << "Check if plugin is present. "
287+
<< "Failed to load plugin: " << PluginNames[I].first
288+
<< std::endl;
237289
}
238290
continue;
239291
}
240292

241-
if (!bindPlugin(Library, &PluginInformation) && EnableTrace) {
242-
std::cerr << "Failed to bind PI APIs to the plugin: " << PluginNames[I]
243-
<< std::endl;
244-
}
245-
if (useBackend(SYCL_BE_PI_OPENCL) &&
246-
PluginNames[I].find("opencl") != std::string::npos) {
247-
// Use the OpenCL plugin as the GlobalPlugin
248-
GlobalPlugin = std::make_shared<plugin>(PluginInformation);
293+
if (!bindPlugin(Library, &PluginInformation)) {
294+
if (trace(PI_TRACE_ALL)) {
295+
std::cerr << traceLabel() << "Failed to bind PI APIs to the plugin: "
296+
<< PluginNames[I].first << std::endl;
297+
}
298+
continue;
249299
}
250-
if (useBackend(SYCL_BE_PI_CUDA) &&
251-
PluginNames[I].find("cuda") != std::string::npos) {
252-
// Use the CUDA plugin as the GlobalPlugin
253-
GlobalPlugin = std::make_shared<plugin>(PluginInformation);
300+
// Set the Global Plugin based on SYCL_INTEROP_BE.
301+
// Rework this when it will be explicit in the code which BE is used in the
302+
// interoperability methods.
303+
if (Config<Backend, SYCL_INTEROP_BE>::get() == PluginNames[I].second) {
304+
GlobalPlugin =
305+
std::make_shared<plugin>(PluginInformation, PluginNames[I].second);
254306
}
255-
Plugins.push_back(plugin(PluginInformation));
307+
Plugins.emplace_back(plugin(PluginInformation, PluginNames[I].second));
308+
if (trace(TraceLevel::PI_TRACE_BASIC))
309+
std::cerr << traceLabel() << "Plugin found and successfully loaded: "
310+
<< PluginNames[I].first << std::endl;
256311
}
257312

258313
#ifdef XPTI_ENABLE_INSTRUMENTATION

sycl/source/detail/plugin.hpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@ class plugin {
2323
public:
2424
plugin() = delete;
2525

26-
plugin(RT::PiPlugin Plugin) : MPlugin(Plugin) {
27-
MPiEnableTrace = (std::getenv("SYCL_PI_TRACE") != nullptr);
28-
}
26+
plugin(RT::PiPlugin Plugin, RT::Backend UseBackend)
27+
: MPlugin(Plugin), MBackend(UseBackend) {}
2928

3029
~plugin() = default;
3130

@@ -52,13 +51,13 @@ class plugin {
5251
template <PiApiKind PiApiOffset, typename... ArgsT>
5352
RT::PiResult call_nocheck(ArgsT... Args) const {
5453
RT::PiFuncInfo<PiApiOffset> PiCallInfo;
55-
if (MPiEnableTrace) {
54+
if (pi::trace(pi::TraceLevel::PI_TRACE_CALLS)) {
5655
std::string FnName = PiCallInfo.getFuncName();
5756
std::cout << "---> " << FnName << "(" << std::endl;
5857
RT::printArgs(Args...);
5958
}
6059
RT::PiResult R = PiCallInfo.getFuncPtr(MPlugin)(Args...);
61-
if (MPiEnableTrace) {
60+
if (pi::trace(pi::TraceLevel::PI_TRACE_CALLS)) {
6261
std::cout << ") ---> ";
6362
RT::printArgs(R);
6463
}
@@ -74,10 +73,11 @@ class plugin {
7473
checkPiResult(Err);
7574
}
7675

76+
RT::Backend getBackend(void) const { return MBackend; }
77+
7778
private:
7879
RT::PiPlugin MPlugin;
79-
bool MPiEnableTrace;
80-
80+
const RT::Backend MBackend;
8181
}; // class plugin
8282
} // namespace detail
8383
} // namespace sycl

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,8 @@ static bool isDeviceBinaryTypeSupported(const context &C,
270270
}
271271

272272
// OpenCL 2.1 and greater require clCreateProgramWithIL
273-
if (pi::useBackend(pi::SYCL_BE_PI_OPENCL) &&
273+
pi::Backend CBackend = (detail::getSyclObjImpl(C)->getPlugin()).getBackend();
274+
if ((CBackend == pi::SYCL_BE_PI_OPENCL) &&
274275
C.get_platform().get_info<info::platform::version>() >= "2.1")
275276
return true;
276277

sycl/source/detail/scheduler/commands.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1672,7 +1672,7 @@ cl_int ExecCGCommand::enqueueImp() {
16721672
Requirement *Req = (Requirement *)(Arg.MPtr);
16731673
AllocaCommandBase *AllocaCmd = getAllocaForReq(Req);
16741674
RT::PiMem MemArg = (RT::PiMem)AllocaCmd->getMemAllocation();
1675-
if (RT::useBackend(pi::Backend::SYCL_BE_PI_OPENCL)) {
1675+
if (Plugin.getBackend() == (pi::Backend::SYCL_BE_PI_OPENCL)) {
16761676
Plugin.call<PiApiKind::piKernelSetArg>(Kernel, Arg.MIndex,
16771677
sizeof(RT::PiMem), &MemArg);
16781678
} else {

0 commit comments

Comments
 (0)