Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BACKEND] Added sycl backend support #234

Merged
merged 45 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
736db4a
Added sycl backend support
Sarbojit2019 Jan 11, 2024
9c2e602
Added coding-style scan result
Sarbojit2019 Jan 11, 2024
f7afacd
pre-commit scan result
Sarbojit2019 Jan 11, 2024
4443562
pre-commit fix
Sarbojit2019 Jan 11, 2024
32b428b
Moving to latest IPEX
Sarbojit2019 Jan 11, 2024
6f22c40
Fix for test failure
Sarbojit2019 Jan 11, 2024
819d63d
Working on test failure
Sarbojit2019 Jan 11, 2024
38871ef
Merge branch 'llvm-target' into sycl-backend
Sarbojit2019 Jan 11, 2024
156b894
Fixed last failing test
Sarbojit2019 Jan 12, 2024
0c6269f
Merge branch 'llvm-target' into sycl-backend
Sarbojit2019 Jan 12, 2024
b9a782d
reverting the script change
Sarbojit2019 Jan 12, 2024
312d631
Merge to latest
Sarbojit2019 Jan 18, 2024
d57248e
Merge to latest
Sarbojit2019 Jan 18, 2024
39d3394
Merge branch 'llvm-target' into sycl-backend
Sarbojit2019 Jan 18, 2024
5cd228c
test fix
Sarbojit2019 Jan 18, 2024
acd5bcd
Merge remote-tracking branch 'origin/sycl-backend' into sycl-backend
Sarbojit2019 Jan 18, 2024
8ad992f
Addressing review comments
Sarbojit2019 Jan 19, 2024
007a42f
Merge to latest
Sarbojit2019 Jan 19, 2024
26e661b
Merge branch 'llvm-target' into sycl-backend
Sarbojit2019 Jan 22, 2024
eb79f82
Merge branch 'llvm-target' into sycl-backend
Sarbojit2019 Jan 23, 2024
1a93565
Merge to latest
Sarbojit2019 Jan 23, 2024
acfa6d1
Fixed test
Sarbojit2019 Jan 23, 2024
4fbb967
test fix
Sarbojit2019 Jan 23, 2024
71ad1c1
Merge branch 'llvm-target' into sycl-backend
etiotto Jan 23, 2024
be7a9f8
Merge latest
Sarbojit2019 Jan 24, 2024
f2aeb85
Address review comments
Sarbojit2019 Jan 25, 2024
20ccda1
Merge branch 'llvm-target' into sycl-backend
Sarbojit2019 Jan 25, 2024
3b5bc24
Fixed test failure
Sarbojit2019 Jan 25, 2024
4877716
Fixed the test
Sarbojit2019 Jan 25, 2024
bf7a60f
clearing resources
Sarbojit2019 Jan 25, 2024
df4c365
with clang format scan
Sarbojit2019 Jan 25, 2024
11176c8
Merge remote-tracking branch 'origin/llvm-target' into sycl-backend
whitneywhtsang Jan 25, 2024
26e46b3
Address review comment
whitneywhtsang Jan 25, 2024
e9ecf3c
Temporarily revert some changes to make ci pass
whitneywhtsang Jan 26, 2024
b7edfca
Merge branch 'llvm-target' into sycl-backend
whitneywhtsang Jan 26, 2024
7f7dfd7
[NFC] Remove commented out code
whitneywhtsang Jan 26, 2024
0faf4a8
Fixed test failure
Sarbojit2019 Jan 26, 2024
7247851
Fixed pre-commit check
Sarbojit2019 Jan 26, 2024
a5553a3
Address review comment
Sarbojit2019 Jan 26, 2024
07c43fb
Address another review comment
Sarbojit2019 Jan 26, 2024
15905d7
Address another review comments
Sarbojit2019 Jan 26, 2024
ee9fe60
More clean-up
Sarbojit2019 Jan 26, 2024
50bc594
Merge branch 'llvm-target' into sycl-backend
Sarbojit2019 Jan 26, 2024
ddd4cf8
Added XPU check
Sarbojit2019 Jan 26, 2024
055f356
Address review comment
whitneywhtsang Jan 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from ..runtime.autotuner import OutOfResources
from ..runtime.cache import get_cache_manager
from ..runtime.driver import driver
from ..runtime.jit import (get_dev_ctxt_queue_objs, get_event_pool, get_imm_cmd_list)
from ..runtime.jit import (get_event_pool)
# TODO: this shouldn't be here
from ..backends.xpu.compiler import InfoFromBackendForTensorMap

Sarbojit2019 marked this conversation as resolved.
Show resolved Hide resolved
from dataclasses import dataclass
from .code_generator import ast_to_ttir
from pathlib import Path
Expand Down Expand Up @@ -284,8 +285,10 @@ def _init_handles(self):
if self.shared > max_shared:
raise OutOfResources(self.shared, max_shared, "shared memory")
# TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
self.module, self.function, self.n_regs, self.n_spills = driver.utils.load_binary(
self.name, self.kernel, self.shared, device)
import torch
self.module, self.function, self.n_regs, self.n_spills = driver.utils.load_sycl_binary(
self.name, self.kernel, self.shared,
torch.xpu.device(device).sycl_device)
Sarbojit2019 marked this conversation as resolved.
Show resolved Hide resolved

def __getattribute__(self, name):
if name == 'run':
Expand All @@ -299,13 +302,11 @@ def runner(*args, stream=None):
args_expand = driver.assemble_tensormap_to_arg(self.tensormaps_info, args)
if self.is_spirv:
use_icl = 1
if stream is None:
dev_obj, ctxt_obj, q_obj = get_dev_ctxt_queue_objs(self.is_spirv)
if q_obj == 0:
stream = get_imm_cmd_list()
else:
stream = 0
use_icl = 0
import torch
stream = torch.xpu.current_stream().sycl_queue
kurapov-peter marked this conversation as resolved.
Show resolved Hide resolved
dev_obj = 0
ctxt_obj = 0
q_obj = 0
event_pool = get_event_pool(self.is_spirv)
self.run(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.cluster_dims[0],
self.cluster_dims[1], self.cluster_dims[2], self.shared, use_icl, stream, q_obj, dev_obj,
Expand Down
10 changes: 4 additions & 6 deletions python/triton/runtime/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,12 +446,10 @@ def run(self, *args, grid, warmup, **kwargs):
if self.is_spirv:
event_pool = get_event_pool(self.is_spirv)
use_icl = 1
dev_obj, ctxt_obj, q_obj = get_dev_ctxt_queue_objs(self.is_spirv)
if q_obj == 0:
stream = get_imm_cmd_list()
else:
stream = 0
use_icl = 0
stream = torch.xpu.current_stream().sycl_queue
dev_obj = 0
ctxt_obj = 0
q_obj = 0
kernel.run(grid_0, grid_1, grid_2, kernel.num_warps,
kernel.num_ctas, # number of warps/ctas per instance
kernel.cluster_dims[0], kernel.cluster_dims[1], kernel.cluster_dims[2], # cluster
Expand Down
173 changes: 173 additions & 0 deletions third_party/xpu/backend/driver.c
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,177 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
mem_bus_width);
}

/*Sycl code Start*/
/// Returns true when the environment variable `env` is set to a truthy
/// value ("on", "true", or "1", compared case-insensitively).
/// Returns false for any other value and when the variable is unset.
bool getBoolEnv(const std::string &env) {
  const char *raw = std::getenv(env.c_str());
  if (raw == nullptr)
    return false;
  std::string lowered(raw);
  for (char &c : lowered)
    c = static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
  return lowered == "on" || lowered == "true" || lowered == "1";
}
// Checks that an expression (typically a Level Zero API call) evaluates to
// the expected value. `value2` is evaluated exactly once; on mismatch the
// actual result is formatted in hex and wrapped in a std::runtime_error so
// driver failures surface as C++ exceptions instead of being ignored.
#define EXPECT_EQ(value1, value2)                                              \
  {                                                                            \
    auto result = (value2);                                                    \
    if ((value1) != (result)) {                                                \
      std::string err_log("L0 API error code: ");                              \
      std::stringstream ss;                                                    \
      ss << std::hex << result << std::endl;                                   \
      throw std::runtime_error(err_log + ss.str());                            \
    }                                                                          \
  }
// Convenience wrapper: asserts that a boolean expression is true.
#define EXPECT_TRUE(value1) EXPECT_EQ(true, value1)
// Builds a Level Zero module from a SPIR-V binary.
//
// `binary_ptr`  — pointer to the SPIR-V words.
// `binary_size` — number of uint32_t words (NOT bytes) in `binary_ptr`.
//
// On build failure the driver's build log is printed to stderr and a
// std::runtime_error is thrown (via EXPECT_EQ). Returns the module handle
// on success; the caller owns it.
ze_module_handle_t create_module(ze_context_handle_t context,
                                 ze_device_handle_t device,
                                 uint32_t *binary_ptr, size_t binary_size) {
  const char *build_flags = "";
  ze_module_desc_t module_description = {};
  module_description.stype = ZE_STRUCTURE_TYPE_MODULE_DESC;
  module_description.format = ZE_MODULE_FORMAT_IL_SPIRV;
  module_description.inputSize =
      static_cast<uint32_t>(binary_size * sizeof(uint32_t));
  module_description.pInputModule = (uint8_t *)binary_ptr;
  module_description.pBuildFlags = build_flags;
  ze_module_build_log_handle_t buildlog;
  ze_module_handle_t module;
  ze_result_t error_no =
      zeModuleCreate(context, device, &module_description, &module, &buildlog);
  if (error_no != ZE_RESULT_SUCCESS) {
    // Fetch and report the driver's build log before failing. A vector is
    // used instead of malloc/free so the buffer is released even if
    // EXPECT_EQ throws between the two zeModuleBuildLogGetString calls.
    size_t szLog = 0;
    EXPECT_EQ(ZE_RESULT_SUCCESS,
              zeModuleBuildLogGetString(buildlog, &szLog, nullptr));
    std::vector<char> strLog(szLog);
    EXPECT_EQ(ZE_RESULT_SUCCESS,
              zeModuleBuildLogGetString(buildlog, &szLog, strLog.data()));
    std::cerr << "L0 build module failed. Log: " << strLog.data() << std::endl;
    EXPECT_EQ(ZE_RESULT_SUCCESS, zeModuleBuildLogDestroy(buildlog));
  }
  EXPECT_EQ(ZE_RESULT_SUCCESS, error_no);
  return module;
}
// Debug helper: queries the names of all kernels contained in `hModule`
// and prints them to stdout when MLIR_ENABLE_DUMP is enabled.
// Uses EXPECT_EQ (like the rest of this file) instead of assert so the
// result codes are still checked in NDEBUG builds and the return value is
// never left unused.
void printModuleKernelName(ze_module_handle_t hModule) {
  uint32_t count = 0;
  EXPECT_EQ(ZE_RESULT_SUCCESS,
            zeModuleGetKernelNames(hModule, &count, nullptr));
  std::unique_ptr<const char *[]> names(new const char *[count]);
  EXPECT_EQ(ZE_RESULT_SUCCESS,
            zeModuleGetKernelNames(hModule, &count, names.get()));
  if (getBoolEnv("MLIR_ENABLE_DUMP")) {
    for (uint32_t i = 0; i < count; ++i) {
      std::cout << std::string(names[i]) << std::endl;
    }
  }
}
// Creates a Level Zero kernel handle for `func_name` inside `module`,
// using the given kernel creation flags. Throws std::runtime_error (via
// EXPECT_EQ) when the kernel cannot be created, e.g. when no kernel with
// that name exists in the module. The caller owns the returned handle.
ze_kernel_handle_t create_function(ze_module_handle_t module,
                                   ze_kernel_flags_t flag,
                                   std::string func_name) {
  ze_kernel_desc_t kernel_description = {};
  kernel_description.stype = ZE_STRUCTURE_TYPE_KERNEL_DESC;
  kernel_description.pNext = nullptr;
  kernel_description.flags = flag;
  kernel_description.pKernelName = func_name.c_str();
  if (getBoolEnv("MLIR_ENABLE_DUMP")) {
    std::cout << "create kernel:" << func_name << std::endl;
  }
  ze_kernel_handle_t kernel;
  EXPECT_EQ(ZE_RESULT_SUCCESS,
            zeKernelCreate(module, &kernel_description, &kernel));
  return kernel;
}
// Convenience overload: creates the kernel with default (zero) flags.
ze_kernel_handle_t create_function(ze_module_handle_t module,
                                   std::string func_name) {
  return create_function(module, 0, func_name);
}
std::vector<std::unique_ptr<sycl::kernel>> compiled_kernel;
Sarbojit2019 marked this conversation as resolved.
Show resolved Hide resolved
static PyObject* loadSyclBinary(PyObject* self, PyObject* args) {
//std::cout<<"Inside loadSyclBinary"<<std::endl;
const char* name;
int shared;
PyObject *py_bytes;
PyObject *py_dev;
if(!PyArg_ParseTuple(args, "sSiO", &name, &py_bytes, &shared, &py_dev)) {
std::cout << "loadBloadSyclBinaryinary arg parse failed" << std::endl;
return NULL;
}
//std::cout<<"------------------"<<std::endl;
//std::cout<<PyCapsule_GetName(py_dev)<<std::endl;
//std::cout<<"------------------"<<std::endl;
int32_t n_regs = 0;
int32_t n_spills = 0;
void * pdevID = PyCapsule_GetPointer(py_dev, "torch.xpu.device.sycl_device");
//error;
if(pdevID == nullptr) return NULL;
Sarbojit2019 marked this conversation as resolved.
Show resolved Hide resolved
//module_desc.inputSize = PyBytes_Size(py_bytes);
//module_desc.pInputModule = (uint8_t*) PyBytes_AsString(py_bytes);
sycl::device device = *(static_cast<sycl::device*>(pdevID));
std::string kernel_name = name;
size_t binary_size = PyBytes_Size(py_bytes);
binary_size = binary_size/ sizeof(uint32_t);
//std::string binary; binary.reserve(binary_size);
uint32_t *binary_ptr = (uint32_t *)PyBytes_AsString(py_bytes);;
//std::cout<<"Kernle name "<<kernel_name<<std::endl;
auto ctx = device.get_platform().ext_oneapi_get_default_context();
auto l0_device =
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(device);
auto l0_context = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(ctx);
auto l0_module =
create_module(l0_context, l0_device, binary_ptr, binary_size);
//printModuleKernelName(l0_module);
auto l0_kernel = create_function(l0_module, kernel_name);
ze_kernel_properties_t props;
props.stype = ZE_STRUCTURE_TYPE_KERNEL_PROPERTIES;
props.pNext = nullptr;
EXPECT_EQ(ZE_RESULT_SUCCESS, zeKernelGetProperties(l0_kernel, &props));
n_spills = props.spillMemSize;
auto mod = sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,
sycl::bundle_state::executable>(
{l0_module, sycl::ext::oneapi::level_zero::ownership::transfer}, ctx);
auto fun = sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
{mod, l0_kernel, sycl::ext::oneapi::level_zero::ownership::transfer},
ctx);
if (getBoolEnv("MLIR_ENABLE_DUMP")) {
// auto kernel_ids = mod.get_kernel_ids();
// std::cout << "num_kernels:" << kernel_ids.size() << std::endl;
// for (auto& kernel_id : kernel_ids) {
// std::cout << "fun name: " << kernel_id.get_name() << std::endl;
// }
}
compiled_kernel.push_back(std::make_unique<sycl::kernel>(fun));
sycl::kernel *ptr = compiled_kernel[compiled_kernel.size() - 1].get();
if (getBoolEnv("MLIR_ENABLE_DUMP")) {
std::cout << "compiled kernel ptr: " << ptr << std::endl;
std::cout << "total kernels:" << compiled_kernel.size() << std::endl;
for (auto &k : compiled_kernel) {
std::cout << " kernel:"
<< k->get_info<sycl::info::kernel::function_name>() << " @"
<< k.get() << std::endl;
}
}
sycl::kernel *k = new sycl::kernel(*ptr);
Sarbojit2019 marked this conversation as resolved.
Show resolved Hide resolved
/*py::capsule kernel_capsulle(k, [](void *f) {
auto kk = static_cast<sycl::kernel *>(f);
delete kk;
});*/
sycl::kernel_bundle<sycl::bundle_state::executable> *kb =
new sycl::kernel_bundle<sycl::bundle_state::executable>(mod);
/*py::capsule module_capsulle(kb, [](void *f) {
auto kk =
static_cast<sycl::kernel_bundle<sycl::bundle_state::executable> *>(f);
delete kk;
});*/
/*py::tuple tup =
py::make_tuple(module_capsulle, kernel_capsulle, n_regs, n_spills);
return tup;*/
return Py_BuildValue("(KKii)", (uint64_t)kb, (uint64_t)k, n_regs, n_spills);
}
/*Sycl code end*/

static PyObject *loadBinary(PyObject *self, PyObject *args) {
const char *name;
int shared;
Expand Down Expand Up @@ -300,6 +471,8 @@ static PyObject *isUsingICL(PyObject *self, PyObject *args) {
static PyMethodDef ModuleMethods[] = {
{"load_binary", loadBinary, METH_VARARGS,
"Load provided SPV into ZE driver"},
{"load_sycl_binary", loadSyclBinary, METH_VARARGS,
"Load provided SPV into ZE driver"},
{"get_device_properties", getDeviceProperties, METH_VARARGS,
"Get the properties for a given device"},
{"init_context", initContext, METH_VARARGS,
Expand Down
Loading