@@ -97,7 +97,30 @@ def init(self):
9797 .SPIR_DATA_LAYOUT [utils .MACHINE_BITS ]))
9898 # Override data model manager to SPIR model
9999 self .data_model_manager = spirv_data_model_manager
100- self .done_once = False
100+
101+ from numba .np .ufunc_db import _ufunc_db as ufunc_db , _lazy_init_db
102+ import copy
103+ _lazy_init_db ()
104+ self .ufunc_db = copy .deepcopy (ufunc_db )
105+
106+
107+ def replace_numpy_ufunc_with_opencl_supported_functions (self ):
108+ from numba .dppy .ocl .mathimpl import lower_ocl_impl , sig_mapper
109+
110+ ufuncs = [("fabs" , np .fabs ), ("exp" , np .exp ), ("log" , np .log ),
111+ ("log10" , np .log10 ), ("expm1" , np .expm1 ), ("log1p" , np .log1p ),
112+ ("sqrt" , np .sqrt ), ("sin" , np .sin ), ("cos" , np .cos ),
113+ ("tan" , np .tan ), ("asin" , np .arcsin ), ("acos" , np .arccos ),
114+ ("atan" , np .arctan ), ("atan2" , np .arctan2 ), ("sinh" , np .sinh ),
115+ ("cosh" , np .cosh ), ("tanh" , np .tanh ), ("asinh" , np .arcsinh ),
116+ ("acosh" , np .arccosh ), ("atanh" , np .arctanh ), ("ldexp" , np .ldexp ),
117+ ("floor" , np .floor ), ("ceil" , np .ceil ), ("trunc" , np .trunc )]
118+
119+ for name , ufunc in ufuncs :
120+ for sig in self .ufunc_db [ufunc ].keys ():
121+ if sig in sig_mapper and (name , sig_mapper [sig ]) in lower_ocl_impl :
122+ self .ufunc_db [ufunc ][sig ] = lower_ocl_impl [(name , sig_mapper [sig ])]
123+
101124
102125 def load_additional_registries (self ):
103126 from .ocl import oclimpl , mathimpl
@@ -111,9 +134,7 @@ def load_additional_registries(self):
111134 functions we will redirect some of NUMBA's NumPy
112135 ufunc with OpenCL's.
113136 """
114- if not self .done_once :
115- _replace_numpy_ufunc_with_opencl_supported_functions ()
116- self .done_once = True
137+ self .replace_numpy_ufunc_with_opencl_supported_functions ()
117138
118139
119140 @cached_property
0 commit comments