Skip to content

Commit d18b2a2

Browse files
Totmenina, ElenaDeb, Diptorup
authored andcommitted
(numba/dppy) Add caching (#60)
* Add caching * Del comment about errors in caching * Add test for caching kernel * Remove old commented out code.
1 parent e57e5d3 commit d18b2a2

File tree

2 files changed

+57
-51
lines changed

2 files changed

+57
-51
lines changed

compiler.py

Lines changed: 4 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -300,37 +300,6 @@ def __getitem__(self, args):
300300
return self.configure(device_env, gs, ls)
301301

302302

303-
#_CacheEntry = namedtuple("_CachedEntry", ['symbol', 'executable',
304-
# 'kernarg_region'])
305-
306-
307-
# class _CachedProgram(object):
308-
# def __init__(self, entry_name, binary):
309-
# self._entry_name = entry_name
310-
# self._binary = binary
311-
# # key: ocl context
312-
# self._cache = {}
313-
#
314-
# def get(self, device):
315-
# context = device.get_context()
316-
# result = self._cache.get(context)
317-
#
318-
# if result is not None:
319-
# program = result[1]
320-
# kernel = result[2]
321-
# else:
322-
# # First-time compilation
323-
# spirv_bc = spirv.llvm_to_spirv(self._binary)
324-
# program = context.create_program_from_il(spirv_bc)
325-
# program.build()
326-
# kernel = program.create_kernel(self._entry_name)
327-
#
328-
# # Cache the just built cl_program, its cl_device and a cl_kernel
329-
# self._cache[context] = (device, program, kernel)
330-
#
331-
# return context, device, program, kernel
332-
333-
334303
class DPPyKernel(DPPyKernelBase):
335304
"""
336305
A OCL kernel object
@@ -345,24 +314,15 @@ def __init__(self, device_env, llvm_module, name, argtypes,
345314
self.argument_types = tuple(argtypes)
346315
self.ordered_arg_access_types = ordered_arg_access_types
347316
self._argloc = []
348-
# cached finalized program
349-
# self._cacheprog = _CachedProgram(entry_name=self.entry_name,
350-
# binary=self.binary)
351317
# First-time compilation using SPIRV-Tools
352318
if DEBUG:
353319
with open("llvm_kernel.ll", "w") as f:
354320
f.write(self.binary)
355321
self.spirv_bc = spirv_generator.llvm_to_spirv(self.binary)
356-
#print("DPPyKernel:", self.spirv_bc, type(self.spirv_bc))
357322
# create a program
358323
self.program = driver.Program(device_env, self.spirv_bc)
359324
# create a kernel
360325
self.kernel = driver.Kernel(device_env, self.program, self.entry_name)
361-
# def bind(self):
362-
# """
363-
# Bind kernel to device
364-
# """
365-
# return self._cacheprog.get(self.device)
366326

367327
def __call__(self, *args):
368328

@@ -425,10 +385,6 @@ def _unpack_argument(self, ty, val, device_env, retr, kernelargs,
425385
"""
426386
Convert arguments to ctypes and append to kernelargs
427387
"""
428-
# DRD : Check if the val is of type driver.DeviceArray before checking
429-
# if ty is of type ndarray. Argtypes returns ndarray for both
430-
# DeviceArray and ndarray. This is a hack to get around the issue,
431-
# till I understand the typing infrastructure of NUMBA better.
432388
device_arrs.append(None)
433389
if isinstance(val, driver.DeviceArray):
434390
self._unpack_device_array_argument(val, kernelargs)
@@ -499,10 +455,7 @@ def __init__(self, func, access_types):
499455
super(JitDPPyKernel, self).__init__()
500456

501457
self.py_func = func
502-
# DRD: Caching definitions this way can lead to unexpected consequences
503-
# E.g. A kernel compiled for a given device would not get recompiled
504-
# and lead to OpenCL runtime errors.
505-
#self.definitions = {}
458+
self.definitions = {}
506459
self.access_types = access_types
507460

508461
from .descriptor import dppy_target
@@ -525,10 +478,10 @@ def __call__(self, *args, **kwargs):
525478
def specialize(self, *args):
526479
argtypes = tuple([self.typingctx.resolve_argument_type(a)
527480
for a in args])
528-
kernel = None #self.definitions.get(argtypes)
529-
481+
key_definitions = (self.device_env._env_ptr, argtypes)
482+
kernel = self.definitions.get(key_definitions)
530483
if kernel is None:
531484
kernel = compile_kernel(self.device_env, self.py_func, argtypes,
532485
self.access_types)
533-
#self.definitions[argtypes] = kernel
486+
self.definitions[key_definitions] = kernel
534487
return kernel

tests/dppy/test_caching.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from __future__ import print_function
2+
from timeit import default_timer as time
3+
4+
import sys
5+
import numpy as np
6+
from numba import dppy
7+
import dppy.core as ocldrv
8+
from numba.dppy.testing import unittest
9+
from numba.dppy.testing import DPPYTestCase
10+
11+
12+
def data_parallel_sum(a, b, c):
13+
i = dppy.get_global_id(0)
14+
c[i] = a[i] + b[i]
15+
16+
17+
class TestCaching(DPPYTestCase):
18+
def test_caching_kernel(self):
19+
global_size = 10
20+
N = global_size
21+
22+
a = np.array(np.random.random(N), dtype=np.float32)
23+
b = np.array(np.random.random(N), dtype=np.float32)
24+
c = np.ones_like(a)
25+
26+
device_env = None
27+
28+
try:
29+
device_env = ocldrv.runtime.get_gpu_device()
30+
print("Selected GPU device")
31+
except:
32+
try:
33+
device_env = ocldrv.runtime.get_cpu_device()
34+
print("Selected CPU device")
35+
except:
36+
print("No OpenCL devices found on the system")
37+
raise SystemExit()
38+
39+
# Copy the data to the device
40+
dA = device_env.copy_array_to_device(a)
41+
dB = device_env.copy_array_to_device(b)
42+
dC = ocldrv.DeviceArray(device_env.get_env_ptr(), c)
43+
44+
func = dppy.kernel(data_parallel_sum)
45+
caching_kernel = func[device_env, global_size].specialize(dA, dB, dC)
46+
47+
for i in range(10):
48+
cached_kernel = func[device_env, global_size].specialize(dA, dB, dC)
49+
self.assertIs(caching_kernel, cached_kernel)
50+
51+
52+
if __name__ == '__main__':
53+
unittest.main()

0 commit comments

Comments
 (0)