2323from  numba .core .typing .templates  import  CallableTemplate 
2424from  numba .np .arrayobj  import  _array_copy 
2525
26- import  dpctl .dptensor .numpy_usm_shared  as  numpy_usm_shared 
27- from  dpctl .dptensor .numpy_usm_shared  import  ndarray , functions_list 
26+ import  dpctl .dptensor .numpy_usm_shared  as  nus 
27+ from  dpctl .dptensor .numpy_usm_shared  import  ndarray , functions_list ,  class_list 
2828
2929
3030debug  =  config .DEBUG 
@@ -233,7 +233,7 @@ def numba_register_lower_builtin():
233233    cur_mod  =  importlib .import_module (__name__ )
234234    for  impl , func , types  in  todo  +  todo_builtin :
235235        try :
236-             usmarray_func  =  eval ("numpy_usm_shared." + func .__name__ )
236+             usmarray_func  =  eval ("dpctl.dptensor. numpy_usm_shared."   +   func .__name__ )
237237        except :
238238            dprint ("failed to eval" , func .__name__ )
239239            continue 
@@ -260,28 +260,44 @@ def numba_register_typing():
260260    # For all Numpy identifiers that have been registered for typing in Numba... 
261261    for  ig  in  typing_registry .globals :
262262        val , typ  =  ig 
263+         dprint ("Numpy registered:" , val , type (val ), typ , type (typ ))
263264        # If it is a Numpy function... 
264265        if  isinstance (val , (ftype , bftype )):
265266            # If we have overloaded that function in the usmarray module (always True right now)... 
266267            if  val .__name__  in  functions_list :
267268                todo .append (ig )
268269        if  isinstance (val , type ):
269-             todo_classes .append (ig )
270+             if  isinstance (typ , numba .core .types .functions .Function ):
271+                 todo .append (ig )
272+             elif  isinstance (typ , numba .core .types .functions .NumberClass ):
273+                 pass 
274+                 #todo_classes.append(ig) 
270275
271276    for  tgetattr  in  templates_registry .attributes :
272277        if  tgetattr .key  ==  types .Array :
273278            todo_getattr .append (tgetattr )
274279
280+     for  val , typ  in  todo_classes :
281+         dprint ("todo_classes:" , val , typ , type (typ ))
282+ 
283+         try :
284+             dptype  =  eval ("dpctl.dptensor.numpy_usm_shared."  +  val .__name__ )
285+         except :
286+             dprint ("failed to eval" , val .__name__ )
287+             continue 
288+ 
289+         typing_registry .register_global (dptype , numba .core .types .NumberClass (typ .instance_type ))
290+ 
275291    for  val , typ  in  todo :
276292        assert  len (typ .templates ) ==  1 
277293        # template is the typing class to invoke generic() upon. 
278294        template  =  typ .templates [0 ]
295+         dprint ("need to re-register for usmarray" , val , typ , typ .typing_key )
279296        try :
280-             dpval  =  eval ("numpy_usm_shared." + val .__name__ )
297+             dpval  =  eval ("dpctl.dptensor. numpy_usm_shared."   +   val .__name__ )
281298        except :
282299            dprint ("failed to eval" , val .__name__ )
283300            continue 
284-         dprint ("need to re-register for usmarray" , val , typ , typ .typing_key )
285301        """ 
286302        if debug: 
287303            print("--------------------------------------------------------------") 
@@ -307,9 +323,7 @@ def set_key_original(cls, key, original):
307323        def  generic_impl (self ):
308324            original_typer  =  self .__class__ .original .generic (self .__class__ .original )
309325            ot_argspec  =  inspect .getfullargspec (original_typer )
310-             # print("ot_argspec:", ot_argspec) 
311326            astr  =  argspec_to_string (ot_argspec )
312-             # print("astr:", astr) 
313327
314328            typer_func  =  """def typer({}): 
315329                                original_res = original_typer({}) 
@@ -321,8 +335,6 @@ def generic_impl(self):
321335                astr , "," .join (ot_argspec .args )
322336            )
323337
324-             # print("typer_func:", typer_func) 
325- 
326338            try :
327339                gs  =  globals ()
328340                ls  =  locals ()
@@ -344,7 +356,6 @@ def generic_impl(self):
344356                print ("eval failed!" , sys .exc_info ()[0 ])
345357                sys .exit (0 )
346358
347-             # print("exec_res:", exec_res) 
348359            return  exec_res 
349360
350361        new_usmarray_template  =  type (
@@ -370,7 +381,6 @@ def set_key(cls, key):
370381
371382        def  getattr_impl (self , attr ):
372383            if  attr .startswith ("resolve_" ):
373-                 # print("getattr_impl starts with resolve_:", self, type(self), attr) 
374384                def  wrapper (* args , ** kwargs ):
375385                    attr_res  =  tgetattr .__getattribute__ (self , attr )(* args , ** kwargs )
376386                    if  isinstance (attr_res , types .Array ):
@@ -394,15 +404,7 @@ def wrapper(*args, **kwargs):
394404        templates_registry .register_attr (new_usmarray_template )
395405
396406
397- def  from_ndarray (x ):
398-     return  copy (x )
399- 
400- 
401- def  as_ndarray (x ):
402-     return  np .copy (x )
403- 
404- 
405- @typing_registry .register_global (as_ndarray ) 
407+ @typing_registry .register_global (nus .as_ndarray ) 
406408class  DparrayAsNdarray (CallableTemplate ):
407409    def  generic (self ):
408410        def  typer (arg ):
@@ -411,7 +413,7 @@ def typer(arg):
411413        return  typer 
412414
413415
414- @typing_registry .register_global (from_ndarray ) 
416+ @typing_registry .register_global (nus . from_ndarray ) 
415417class  DparrayFromNdarray (CallableTemplate ):
416418    def  generic (self ):
417419        def  typer (arg ):
@@ -420,11 +422,11 @@ def typer(arg):
420422        return  typer 
421423
422424
423- @lower_registry .lower (as_ndarray , UsmSharedArrayType ) 
425+ @lower_registry .lower (nus . as_ndarray , UsmSharedArrayType ) 
424426def  usmarray_conversion_as (context , builder , sig , args ):
425427    return  _array_copy (context , builder , sig , args )
426428
427429
428- @lower_registry .lower (from_ndarray , types .Array ) 
430+ @lower_registry .lower (nus . from_ndarray , types .Array ) 
429431def  usmarray_conversion_from (context , builder , sig , args ):
430432    return  _array_copy (context , builder , sig , args )
0 commit comments