Skip to content

Commit 94bfb6a

Browse files
committed
Enable device caching for kernels
1 parent 84035d3 commit 94bfb6a

File tree

8 files changed

+284
-48
lines changed

8 files changed

+284
-48
lines changed

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "_queuestruct.h"
2525
#include "_usmarraystruct.h"
2626

27+
#include "experimental/kernel_caching.h"
2728
#include "experimental/nrt_reserve_meminfo.h"
2829
#include "numba/core/runtime/nrt_external.h"
2930

@@ -1493,6 +1494,7 @@ static PyObject *build_c_helpers_dict(void)
14931494
_declpointer("DPEXRT_sycl_event_init", &DPEXRT_sycl_event_init);
14941495
_declpointer("DPEXRT_nrt_acquire_meminfo_and_schedule_release",
14951496
&DPEXRT_nrt_acquire_meminfo_and_schedule_release);
1497+
_declpointer("DPEXRT_build_or_get_kernel", &DPEXRT_build_or_get_kernel);
14961498

14971499
#undef _declpointer
14981500
return dct;
@@ -1563,6 +1565,9 @@ MOD_INIT(_dpexrt_python)
15631565
PyModule_AddObject(
15641566
m, "DPEXRT_nrt_acquire_meminfo_and_schedule_release",
15651567
PyLong_FromVoidPtr(&DPEXRT_nrt_acquire_meminfo_and_schedule_release));
1568+
PyModule_AddObject(m, "DPEXRT_build_or_get_kernel",
1569+
PyLong_FromVoidPtr(&DPEXRT_build_or_get_kernel));
1570+
15661571
PyModule_AddObject(m, "c_helpers", build_c_helpers_dict());
15671572
return MOD_SUCCESS_VAL(m);
15681573
}

numba_dpex/core/runtime/context.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,3 +471,39 @@ def acquire_meminfo_and_schedule_release(
471471
ret = builder.call(fn, args)
472472

473473
return ret
474+
475+
def build_or_get_kernel(self, builder: llvmir.IRBuilder, args):
    """Emits LLVM IR that calls the ``DPEXRT_build_or_get_kernel`` C helper.

    The helper is looked up in (or declared into) ``builder.module`` with an
    LLVM signature that mirrors the following C prototype. Every pointer-like
    parameter and the return value are lowered to ``cgutils.voidptr_t``; the
    two ``size_t`` parameters are lowered to 64-bit integers.

        DPCTLSyclKernelRef
        DPEXRT_build_or_get_kernel(
            const DPCTLSyclContextRef ctx,
            const DPCTLSyclDeviceRef dev,
            size_t il_hash,
            const char *il,
            size_t il_length,
            const char *compile_opts,
            const char *kernel_name
        );

    Args:
        builder: The IR builder used to insert the call; its module receives
            the function declaration.
        args: The LLVM values forwarded to the call, in the order of the
            prototype above.

    Returns:
        The value produced by ``builder.call`` for the inserted call.
    """
    void_ptr = cgutils.voidptr_t
    int64 = llvmir.IntType(64)

    # (ctx, dev, il_hash, il, il_length, compile_opts, kernel_name)
    param_types = [
        void_ptr,
        void_ptr,
        int64,
        void_ptr,
        int64,
        void_ptr,
        void_ptr,
    ]
    signature = llvmir.FunctionType(void_ptr, param_types)

    callee = cgutils.get_or_insert_function(
        builder.module, signature, "DPEXRT_build_or_get_kernel"
    )
    return builder.call(callee, args)
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
// SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
//===----------------------------------------------------------------------===//
6+
///
7+
/// \file
8+
/// A Python module that provides constructors to create a Numba MemInfo
9+
/// PyObject using a sycl USM allocator as the external memory allocator.
10+
/// The Module also provides the Numba box and unbox implementations for a
11+
/// dpnp.ndarray object.
12+
///
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "kernel_caching.h"
16+
#include <unordered_map>
17+
18+
extern "C"
19+
{
20+
#include "dpctl_capi.h"
21+
#include "dpctl_sycl_interface.h"
22+
23+
#include "_dbg_printer.h"
24+
25+
#include "numba/core/runtime/nrt_external.h"
26+
}
27+
28+
#include "syclinterface/dpctl_sycl_type_casters.hpp"
29+
#include "tools/dpctl.hpp"
30+
#include "tools/hash_tuple.hpp"
31+
32+
using CacheKey = std::tuple<DPCTLSyclContextRef, DPCTLSyclDeviceRef, size_t>;
33+
34+
class CacheKeysAreEqual
35+
{
36+
public:
37+
bool operator()(CacheKey const &lhs, CacheKey const &rhs) const
38+
{
39+
// TODO: implement full comparison
40+
return DPCTLDevice_AreEq(std::get<DPCTLSyclDeviceRef>(lhs),
41+
std::get<DPCTLSyclDeviceRef>(rhs)) &&
42+
DPCTLContext_AreEq(std::get<DPCTLSyclContextRef>(lhs),
43+
std::get<DPCTLSyclContextRef>(rhs)) &&
44+
std::get<size_t>(lhs) == std::get<size_t>(rhs);
45+
}
46+
};
47+
48+
// TODO: add cache cleaning
49+
std::unordered_map<CacheKey,
50+
DPCTLSyclKernelRef,
51+
std::hash<CacheKey>,
52+
CacheKeysAreEqual>
53+
sycl_kernel_cache = std::unordered_map<CacheKey,
54+
DPCTLSyclKernelRef,
55+
std::hash<CacheKey>,
56+
CacheKeysAreEqual>();
57+
58+
template <class M, class Key, class F>
59+
typename M::mapped_type &get_else_compute(M &m, Key const &k, F f)
60+
{
61+
typedef typename M::mapped_type V;
62+
std::pair<typename M::iterator, bool> r =
63+
m.insert(typename M::value_type(k, V()));
64+
V &v = r.first->second;
65+
if (r.second) {
66+
DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: building kernel.\n"););
67+
f(v);
68+
}
69+
else {
70+
DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: using cached kernel.\n"););
71+
}
72+
return v;
73+
}
74+
75+
extern "C"
76+
{
77+
DPCTLSyclKernelRef DPEXRT_build_or_get_kernel(const DPCTLSyclContextRef ctx,
78+
const DPCTLSyclDeviceRef dev,
79+
size_t il_hash,
80+
const char *il,
81+
size_t il_length,
82+
const char *compile_opts,
83+
const char *kernel_name)
84+
{
85+
DPEXRT_DEBUG(
86+
drt_debug_print("DPEXRT-DEBUG: in build or get kernel.\n"););
87+
88+
CacheKey key = std::make_tuple(ctx, dev, il_hash);
89+
90+
DPEXRT_DEBUG(auto ctx_hash = std::hash<DPCTLSyclContextRef>{}(ctx);
91+
auto dev_hash = std::hash<DPCTLSyclDeviceRef>{}(dev);
92+
drt_debug_print("DPEXRT-DEBUG: key hashes: %d %d %d.\n",
93+
ctx_hash, dev_hash, il_hash););
94+
95+
auto k_ref = get_else_compute(
96+
sycl_kernel_cache, key,
97+
[ctx, dev, il, il_length, compile_opts,
98+
kernel_name](DPCTLSyclKernelRef &k_ref) {
99+
auto kb_ref = DPCTLKernelBundle_CreateFromSpirv(
100+
ctx, dev, il, il_length, compile_opts);
101+
k_ref = DPCTLKernelBundle_GetKernel(kb_ref, kernel_name);
102+
DPCTLKernelBundle_Delete(kb_ref);
103+
});
104+
return DPCTLKernel_Copy(k_ref);
105+
}
106+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
//===----------------------------------------------------------------------===//
6+
///
7+
/// \file
8+
/// A Python module that provides constructors to create a Numba MemInfo
9+
/// PyObject using a sycl USM allocator as the external memory allocator.
10+
/// The Module also provides the Numba box and unbox implementations for a
11+
/// dpnp.ndarray object.
12+
///
13+
//===----------------------------------------------------------------------===//
14+
15+
#ifdef __cplusplus
16+
extern "C"
17+
{
18+
#endif
19+
20+
#include "dpctl_capi.h"
21+
#include "dpctl_sycl_interface.h"
22+
23+
/// Returns a kernel for the given SPIR-V, reusing a kernel cached under the
/// (ctx, dev, il_hash) triple when one exists and otherwise building it from
/// `il` and extracting `kernel_name` from the resulting kernel bundle.
///
/// \param ctx          Context the kernel is built for.
/// \param dev          Device the kernel is built for.
/// \param il_hash      Hash identifying the IL; part of the cache key.
/// \param il           SPIR-V binary.
/// \param il_length    Size of `il`.
/// \param compile_opts Options passed to the kernel bundle build.
/// \param kernel_name  Name of the kernel to fetch from the bundle. Note that
///                     neither this nor `compile_opts` is part of the key.
/// \return A kernel ref produced by DPCTLKernel_Copy.
///         NOTE(review): presumably the caller must release it with
///         DPCTLKernel_Delete -- confirm against the dpctl C API.
DPCTLSyclKernelRef DPEXRT_build_or_get_kernel(const DPCTLSyclContextRef ctx,
24+
const DPCTLSyclDeviceRef dev,
25+
size_t il_hash,
26+
const char *il,
27+
size_t il_length,
28+
const char *compile_opts,
29+
const char *kernel_name);
30+
#ifdef __cplusplus
31+
}
32+
#endif
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#pragma once
2+
3+
#include "syclinterface/dpctl_sycl_type_casters.hpp"
4+
#include <CL/sycl.hpp>
5+
6+
namespace std
7+
{
8+
template <> struct hash<DPCTLSyclDeviceRef>
9+
{
10+
size_t operator()(const DPCTLSyclDeviceRef &DRef) const
11+
{
12+
using dpctl::syclinterface::unwrap;
13+
return hash<sycl::device>()(*unwrap<sycl::device>(DRef));
14+
}
15+
};
16+
17+
template <> struct hash<DPCTLSyclContextRef>
18+
{
19+
size_t operator()(const DPCTLSyclContextRef &CRef) const
20+
{
21+
using dpctl::syclinterface::unwrap;
22+
return hash<sycl::context>()(*unwrap<sycl::context>(CRef));
23+
}
24+
};
25+
} // namespace std
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#pragma once
2+
3+
#include <tuple>
4+
// Helpers live in a named namespace rather than an anonymous namespace
// inside `std`: an anonymous namespace in a header gives every translation
// unit its own copy of the helpers, which the externally visible
// std::hash<std::tuple<...>> specialization below would then depend on.
namespace dpexrt_hash_detail
{

// Code from boost
// Reciprocal of the golden ratio helps spread entropy
// and handles duplicates.
// See Mike Seymour in magic-numbers-in-boosthash-combine:
// http://stackoverflow.com/questions/4948780

// Mixes the std::hash of `v` into `seed`.
template <class T> inline void hash_combine(std::size_t &seed, T const &v)
{
    seed ^= std::hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

} // namespace dpexrt_hash_detail

namespace std
{
// Hash for std::tuple: starts from a zero seed and combines the std::hash of
// every element, from the first element to the last. An empty tuple hashes
// to 0.
template <typename... TT> struct hash<std::tuple<TT...>>
{
    size_t operator()(std::tuple<TT...> const &tt) const
    {
        size_t seed = 0;
        // The comma fold visits the elements left to right, i.e. in index
        // order, and expands to nothing for an empty tuple.
        std::apply(
            [&seed](auto const &...elems) {
                (dpexrt_hash_detail::hash_combine(seed, elems), ...);
            },
            tt);
        return seed;
    }
};
} // namespace std

numba_dpex/experimental/launcher.py

Lines changed: 26 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ def __init__(
5454
):
5555
self.context = codegen_targetctx
5656
self.builder = builder
57+
# TODO: get dpex RT from cached property once the PR is merged
58+
# https://github.com/IntelPython/numba-dpex/pull/1027
59+
# and get rid of the global variable. Use self.context.dpexrt instead.
60+
self.dpexrt = DpexRTContext(self.context)
5761

5862
if config.DEBUG_KERNEL_LAUNCHER:
5963
cgutils.printf(
@@ -139,7 +143,7 @@ def get_queue_ref_val(
139143

140144
return ptr_to_queue_ref
141145

142-
def get_kernel(self, qref, kernel_module: _KernelModule):
146+
def get_kernel(self, queue_ref, kernel_module: _KernelModule):
143147
"""Returns the pointer to the sycl::kernel object in a passed in
144148
sycl::kernel_bundle wrapper object.
145149
"""
@@ -150,23 +154,32 @@ def get_kernel(self, qref, kernel_module: _KernelModule):
150154
bytes=kernel_module.kernel_bitcode,
151155
)
152156

153-
# Create a sycl::kernel_bundle object and return it as an opaque pointer
154-
# using dpctl's libsyclinterface.
155-
kbref = self.create_kernel_bundle_from_spirv(
156-
queue_ref=qref,
157-
kernel_bc=kernel_bc_byte_str,
158-
kernel_bc_size_in_bytes=len(kernel_module.kernel_bitcode),
159-
)
160-
161157
kernel_name = self.context.insert_const_string(
162158
self.builder.module, kernel_module.kernel_name
163159
)
164160

165-
kernel_ref = sycl.dpctl_kernel_bundle_get_kernel(
166-
self.builder, kbref, kernel_name
161+
context_ref = sycl.dpctl_queue_get_context(self.builder, queue_ref)
162+
device_ref = sycl.dpctl_queue_get_device(self.builder, queue_ref)
163+
164+
kernel_ref = self.dpexrt.build_or_get_kernel(
165+
self.builder,
166+
[
167+
context_ref,
168+
device_ref,
169+
llvmir.Constant(
170+
llvmir.IntType(64), hash(kernel_module.kernel_bitcode)
171+
),
172+
kernel_bc_byte_str,
173+
llvmir.Constant(
174+
llvmir.IntType(64), len(kernel_module.kernel_bitcode)
175+
),
176+
self.builder.load(create_null_ptr(self.builder, self.context)),
177+
kernel_name,
178+
],
167179
)
168180

169-
sycl.dpctl_kernel_bundle_delete(self.builder, kbref)
181+
sycl.dpctl_context_delete(self.builder, context_ref)
182+
sycl.dpctl_device_delete(self.builder, device_ref)
170183

171184
return kernel_ref
172185

@@ -210,36 +223,6 @@ def create_llvm_values_for_index_space(
210223

211224
return LLRange(global_range_extents, local_range_extents)
212225

213-
def create_kernel_bundle_from_spirv(
214-
self,
215-
queue_ref: llvmir.PointerType,
216-
kernel_bc: llvmir.Constant,
217-
kernel_bc_size_in_bytes: int,
218-
) -> llvmir.CallInstr:
219-
"""Calls DPCTLKernelBundle_CreateFromSpirv to create an opaque pointer
220-
to a sycl::kernel_bundle from the SPIR-V generated for a kernel.
221-
"""
222-
device_ref = sycl.dpctl_queue_get_device(self.builder, queue_ref)
223-
context_ref = sycl.dpctl_queue_get_context(self.builder, queue_ref)
224-
args = [
225-
context_ref,
226-
device_ref,
227-
kernel_bc,
228-
llvmir.Constant(llvmir.IntType(64), kernel_bc_size_in_bytes),
229-
self.builder.load(create_null_ptr(self.builder, self.context)),
230-
]
231-
kb_ref = sycl.dpctl_kernel_bundle_create_from_spirv(self.builder, *args)
232-
sycl.dpctl_context_delete(self.builder, context_ref)
233-
sycl.dpctl_device_delete(self.builder, device_ref)
234-
235-
if config.DEBUG_KERNEL_LAUNCHER:
236-
cgutils.printf(
237-
self.builder,
238-
"DPEX-DEBUG: Generated kernel_bundle from SPIR-V.\n",
239-
)
240-
241-
return kb_ref
242-
243226
def acquire_meminfo_and_schedule_release(
244227
self,
245228
queue_ref,
@@ -259,12 +242,7 @@ def acquire_meminfo_and_schedule_release(
259242
status_ptr = cgutils.alloca_once(
260243
self.builder, self.context.get_value_type(types.uint64)
261244
)
262-
# TODO: get dpex RT from cached property once the PR is merged
263-
# https://github.com/IntelPython/numba-dpex/pull/1027
264-
# host_eref = ctx.dpexrt.acquire_meminfo_and_schedule_release( # noqa: W0621
265-
host_eref = DpexRTContext(
266-
self.context
267-
).acquire_meminfo_and_schedule_release(
245+
host_eref = self.dpexrt.acquire_meminfo_and_schedule_release(
268246
self.builder,
269247
[
270248
self.context.nrt.get_nrt_api(self.builder),

0 commit comments

Comments
 (0)