@@ -300,37 +300,6 @@ def __getitem__(self, args):
300300 return self .configure (device_env , gs , ls )
301301
302302
303- #_CacheEntry = namedtuple("_CachedEntry", ['symbol', 'executable',
304- # 'kernarg_region'])
305-
306-
307- # class _CachedProgram(object):
308- # def __init__(self, entry_name, binary):
309- # self._entry_name = entry_name
310- # self._binary = binary
311- # # key: ocl context
312- # self._cache = {}
313- #
314- # def get(self, device):
315- # context = device.get_context()
316- # result = self._cache.get(context)
317- #
318- # if result is not None:
319- # program = result[1]
320- # kernel = result[2]
321- # else:
322- # # First-time compilation
323- # spirv_bc = spirv.llvm_to_spirv(self._binary)
324- # program = context.create_program_from_il(spirv_bc)
325- # program.build()
326- # kernel = program.create_kernel(self._entry_name)
327- #
328- # # Cache the just built cl_program, its cl_device and a cl_kernel
329- # self._cache[context] = (device, program, kernel)
330- #
331- # return context, device, program, kernel
332-
333-
334303class DPPyKernel (DPPyKernelBase ):
335304 """
336305 An OCL kernel object
@@ -345,24 +314,15 @@ def __init__(self, device_env, llvm_module, name, argtypes,
345314 self .argument_types = tuple (argtypes )
346315 self .ordered_arg_access_types = ordered_arg_access_types
347316 self ._argloc = []
348- # cached finalized program
349- # self._cacheprog = _CachedProgram(entry_name=self.entry_name,
350- # binary=self.binary)
351317 # First-time compilation using SPIRV-Tools
352318 if DEBUG :
353319 with open ("llvm_kernel.ll" , "w" ) as f :
354320 f .write (self .binary )
355321 self .spirv_bc = spirv_generator .llvm_to_spirv (self .binary )
356- #print("DPPyKernel:", self.spirv_bc, type(self.spirv_bc))
357322 # create a program
358323 self .program = driver .Program (device_env , self .spirv_bc )
359324 # create a kernel
360325 self .kernel = driver .Kernel (device_env , self .program , self .entry_name )
361- # def bind(self):
362- # """
363- # Bind kernel to device
364- # """
365- # return self._cacheprog.get(self.device)
366326
367327 def __call__ (self , * args ):
368328
@@ -425,10 +385,6 @@ def _unpack_argument(self, ty, val, device_env, retr, kernelargs,
425385 """
426386 Convert arguments to ctypes and append to kernelargs
427387 """
428- # DRD : Check if the val is of type driver.DeviceArray before checking
429- # if ty is of type ndarray. Argtypes returns ndarray for both
430- # DeviceArray and ndarray. This is a hack to get around the issue,
431- # till I understand the typing infrastructure of NUMBA better.
432388 device_arrs .append (None )
433389 if isinstance (val , driver .DeviceArray ):
434390 self ._unpack_device_array_argument (val , kernelargs )
@@ -499,10 +455,7 @@ def __init__(self, func, access_types):
499455 super (JitDPPyKernel , self ).__init__ ()
500456
501457 self .py_func = func
502- # DRD: Caching definitions this way can lead to unexpected consequences
503- # E.g. A kernel compiled for a given device would not get recompiled
504- # and lead to OpenCL runtime errors.
505- #self.definitions = {}
458+ self .definitions = {}
506459 self .access_types = access_types
507460
508461 from .descriptor import dppy_target
@@ -525,10 +478,10 @@ def __call__(self, *args, **kwargs):
def specialize(self, *args):
    """
    Return the kernel compiled for the types of ``args``, compiling it
    on first use.

    Each argument is mapped to a type through
    ``self.typingctx.resolve_argument_type``. The resulting tuple of types
    is combined with ``self.device_env._env_ptr`` to form the lookup key
    into ``self.definitions``, so a kernel built for one device
    environment is never handed back for a different one.

    On a cache miss the kernel is built with ``compile_kernel`` using
    ``self.device_env``, ``self.py_func``, the resolved argument types and
    ``self.access_types``, and the result is stored under the same key.
    """
    argtypes = tuple(self.typingctx.resolve_argument_type(a) for a in args)

    # The device environment is part of the key: caching on argument types
    # alone would reuse a kernel compiled for another device environment.
    key_definitions = (self.device_env._env_ptr, argtypes)

    kernel = self.definitions.get(key_definitions)
    if kernel is None:
        kernel = compile_kernel(self.device_env, self.py_func, argtypes,
                                self.access_types)
        self.definitions[key_definitions] = kernel
    return kernel
0 commit comments