Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions sycl/source/detail/pi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ bool bindPlugin(void *Library, PiPlugin *PluginInformation) {
}

// Load the plugin based on SYCL_BE.
// TODO: Currently only accepting OpenCL and CUDA plugins. Edit it to identify and load
// other kinds of plugins, do the required changes in the findPlugins,
// loadPlugin and bindPlugin functions.
// TODO: Currently only accepting OpenCL and CUDA plugins. Edit it to identify
// and load other kinds of plugins, do the required changes in the
// findPlugins, loadPlugin and bindPlugin functions.
vector_class<plugin> initialize() {
vector_class<plugin> Plugins;

Expand Down Expand Up @@ -196,11 +196,18 @@ vector_class<plugin> initialize() {
std::cerr << "Failed to bind PI APIs to the plugin: " << PluginNames[I]
<< std::endl;
}
if (useBackend(SYCL_BE_PI_OPENCL) &&
PluginNames[I].find("opencl") != std::string::npos) {
// Use the OpenCL plugin as the GlobalPlugin
GlobalPlugin = std::make_shared<plugin>(PluginInformation);
}
if (useBackend(SYCL_BE_PI_CUDA) &&
PluginNames[I].find("cuda") != std::string::npos) {
// Use the CUDA plugin as the GlobalPlugin
GlobalPlugin = std::make_shared<plugin>(PluginInformation);
}
Plugins.push_back(plugin(PluginInformation));
}
// TODO: Correct the logic to store the appropriate plugin into GlobalPlugin
// variable. Currently it saves the last plugin found.
GlobalPlugin = std::make_shared<plugin>(PluginInformation);
return Plugins;
}

Expand Down
19 changes: 19 additions & 0 deletions sycl/source/device_selector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,25 @@ device device_selector::select_device() const {
}

int default_selector::operator()(const device &dev) const {

// Take note of the SYCL_BE environment variable when doing default selection
const char *SYCL_BE = std::getenv("SYCL_BE");
std::string backend = (SYCL_BE ? SYCL_BE : "");
if (backend != "") {
Copy link

@bjoernknafla bjoernknafla Feb 17, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If would be more efficient (no string comparison) to check SYCL_BE for being NULL - though probably does not matter much:

Suggested change
if (backend != "") {
if (SYCL_BE != nullptr) {

// Taking the version information from the platform gives us more useful
// information than the driver_version of the device.
const platform Platform = dev.get_info<info::device::platform>();
const std::string PlatformVersion = Platform.get_info<info::platform::version>();
// If using PI_CUDA, don't accept a non-CUDA device
if (PlatformVersion.find("CUDA") == std::string::npos && backend == "PI_CUDA") {
return -1;
}
// If using PI_OPENCL, don't accept a non-OpenCL device
if (PlatformVersion.find("OpenCL") == std::string::npos && backend == "PI_OPENCL") {
return -1;
}
}

if (dev.is_gpu())
return 500;

Expand Down