@@ -424,6 +424,123 @@ DPEXRT_MemInfo_alloc(npy_intp size, size_t usm_type, const char *device)
424424 return NULL ;
425425}
426426
427+ /**
428+ * @brief Interface for the core.runtime.context.DpexRTContext.meminfo_alloc.
429+ * This function takes an allocated memory as NRT_MemInfo and fills it with
430+ * the value specified by `value`.
431+ *
432+ * @param mi An NRT_MemInfo object, should be found from memory
433+ * allocation.
434+ * @param itemsize The itemsize, the size of each item in the array.
435+ * @param is_float Flag to specify if the data being float or not.
436+ * @param value The value to be used to fill an array.
437+ * @param device The device on which the memory was allocated.
438+ * @return NRT_MemInfo* A new NRT_MemInfo object, NULL if no NRT_MemInfo
439+ * object could be created.
440+ */
441+ static NRT_MemInfo * DPEXRT_MemInfo_fill (NRT_MemInfo * mi ,
442+ size_t itemsize ,
443+ bool is_float ,
444+ uint8_t value ,
445+ const char * device )
446+ {
447+ DPCTLSyclDeviceSelectorRef dselector = NULL ;
448+ DPCTLSyclDeviceRef dref = NULL ;
449+ DPCTLSyclQueueRef qref = NULL ;
450+ DPCTLSyclEventRef eref = NULL ;
451+ size_t count = 0 , size = 0 , exp = 0 ;
452+
453+ size = mi -> size ;
454+ while (itemsize >>= 1 )
455+ exp ++ ;
456+ count = (unsigned int )(size >> exp );
457+
458+ NRT_Debug (nrt_debug_print (
459+ "DPEXRT-DEBUG: mi->size = %u, itemsize = %u, count = %u, "
460+ "value = %u, Inside DPEXRT_MemInfo_fill %s, line %d\n" ,
461+ mi -> size , itemsize << exp , count , value , __FILE__ , __LINE__ ));
462+
463+ if (mi -> data == NULL ) {
464+ NRT_Debug (nrt_debug_print ("DPEXRT-DEBUG: mi->data is NULL, "
465+ "Inside DPEXRT_MemInfo_fill %s, line %d\n" ,
466+ __FILE__ , __LINE__ ));
467+ goto error ;
468+ }
469+
470+ if (!(dselector = DPCTLFilterSelector_Create (device ))) {
471+ NRT_Debug (nrt_debug_print (
472+ "DPEXRT-ERROR: Could not create a sycl::device_selector from "
473+ "filter string: %s at %s %d.\n" ,
474+ device , __FILE__ , __LINE__ ));
475+ goto error ;
476+ }
477+
478+ if (!(dref = DPCTLDevice_CreateFromSelector (dselector )))
479+ goto error ;
480+
481+ if (!(qref = DPCTLQueue_CreateForDevice (dref , NULL , 0 )))
482+ goto error ;
483+
484+ DPCTLDeviceSelector_Delete (dselector );
485+ DPCTLDevice_Delete (dref );
486+
487+ switch (exp ) {
488+ case 3 :
489+ {
490+ uint64_t value_assign = (uint64_t )value ;
491+ if (is_float ) {
492+ double const_val = (double )value ;
493+ // To stop warning: dereferencing type-punned pointer
494+ // will break strict-aliasing rules [-Wstrict-aliasing]
495+ double * p = & const_val ;
496+ value_assign = * ((uint64_t * )(p ));
497+ }
498+ if (!(eref = DPCTLQueue_Fill64 (qref , mi -> data , value_assign , count )))
499+ goto error ;
500+ break ;
501+ }
502+ case 2 :
503+ {
504+ uint32_t value_assign = (uint32_t )value ;
505+ if (is_float ) {
506+ float const_val = (float )value ;
507+ // To stop warning: dereferencing type-punned pointer
508+ // will break strict-aliasing rules [-Wstrict-aliasing]
509+ float * p = & const_val ;
510+ value_assign = * ((uint32_t * )(p ));
511+ }
512+ if (!(eref = DPCTLQueue_Fill32 (qref , mi -> data , value_assign , count )))
513+ goto error ;
514+ break ;
515+ }
516+ case 1 :
517+ if (!(eref = DPCTLQueue_Fill16 (qref , mi -> data , value , count )))
518+ goto error ;
519+ break ;
520+ case 0 :
521+ if (!(eref = DPCTLQueue_Fill8 (qref , mi -> data , value , count )))
522+ goto error ;
523+ break ;
524+ default :
525+ goto error ;
526+ }
527+
528+ DPCTLEvent_Wait (eref );
529+
530+ DPCTLQueue_Delete (qref );
531+ DPCTLEvent_Delete (eref );
532+
533+ return mi ;
534+
535+ error :
536+ DPCTLQueue_Delete (qref );
537+ DPCTLEvent_Delete (eref );
538+ DPCTLDeviceSelector_Delete (dselector );
539+ DPCTLDevice_Delete (dref );
540+
541+ return NULL ;
542+ }
543+
427544/*----------------------------------------------------------------------------*/
428545/*--------- Helpers to get attributes out of a dpnp.ndarray PyObject ---------*/
429546/*----------------------------------------------------------------------------*/
@@ -487,12 +604,13 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
487604 arystruct_t * arystruct )
488605{
489606 struct PyUSMArrayObject * arrayobj = NULL ;
490- int i , ndim ;
607+ int i = 0 , ndim = 0 , exp = 0 ;
491608 npy_intp * shape = NULL , * strides = NULL ;
492- npy_intp * p = NULL , nitems , itemsize ;
609+ npy_intp * p = NULL , nitems ;
493610 void * data = NULL ;
494611 DPCTLSyclQueueRef qref = NULL ;
495612 PyGILState_STATE gstate ;
613+ npy_intp itemsize = 0 ;
496614
497615 // Increment the ref count on obj to prevent CPython from garbage
498616 // collecting the array.
@@ -546,20 +664,29 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
546664
547665 p = arystruct -> shape_and_strides ;
548666
667+ // Calculate the exponent from the arystruct->itemsize as we know
668+ // itemsize is a power of two
669+ while (itemsize >>= 1 )
670+ exp ++ ;
671+
549672 for (i = 0 ; i < ndim ; ++ i , ++ p )
550673 * p = shape [i ];
551674
552- // DPCTL returns a NULL pointer if the array is contiguous
675+ // DPCTL returns a NULL pointer if the array is contiguous. dpctl stores
676+ // strides as number of elements and Numba stores strides as bytes, for
677+ // that reason we are multiplying stride by itemsize when unboxing the
678+ // external array.
679+
553680 // FIXME: Stride computation should check order and adjust how strides are
554681 // calculated. Right now strides are assuming that order is C contigous.
555682 if (strides ) {
556683 for (i = 0 ; i < ndim ; ++ i , ++ p ) {
557- * p = strides [i ];
684+ * p = strides [i ] << exp ;
558685 }
559686 }
560687 else {
561688 for (i = 1 ; i < ndim ; ++ i , ++ p ) {
562- * p = shape [i ];
689+ * p = shape [i ] << exp ;
563690 }
564691 * p = 1 ;
565692 }
@@ -598,11 +725,12 @@ static PyObject *box_from_arystruct_parent(arystruct_t *arystruct,
598725 int ndim ,
599726 PyArray_Descr * descr )
600727{
601- int i ;
602- npy_intp * p ;
728+ int i = 0 , exp = 0 ;
729+ npy_intp * p = NULL ;
603730 npy_intp * shape = NULL , * strides = NULL ;
604731 PyObject * array = arystruct -> parent ;
605732 struct PyUSMArrayObject * arrayobj = NULL ;
733+ npy_intp itemsize = 0 ;
606734
607735 NRT_Debug (nrt_debug_print ("DPEXRT-DEBUG: In try_to_return_parent.\n" ));
608736
@@ -623,9 +751,16 @@ static PyObject *box_from_arystruct_parent(arystruct_t *arystruct,
623751 if (shape [i ] != * p )
624752 return NULL ;
625753 }
626-
754+ // Calculate the exponent from the arystruct->itemsize as we know
755+ // itemsize is a power of two
756+ itemsize = arystruct -> itemsize ;
757+ while (itemsize >>= 1 )
758+ exp ++ ;
759+ // dpctl stores strides as number of elements and Numba stores strides as
760+ // bytes, for that reason we are multiplying stride by itemsize when
761+ // unboxing the external array.
627762 if (strides ) {
628- if (strides [i ] != * p )
763+ if (strides [i ] << exp != * p )
629764 return NULL ;
630765 }
631766 else {
@@ -680,6 +815,8 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(arystruct_t *arystruct,
680815 npy_intp * shape = NULL , * strides = NULL ;
681816 int typenum = 0 ;
682817 int status = 0 ;
818+ int exp = 0 ;
819+ npy_intp itemsize = 0 ;
683820
684821 NRT_Debug (nrt_debug_print (
685822 "DPEXRT-DEBUG: In DPEXRT_sycl_usm_ndarray_to_python_acqref.\n" ));
@@ -750,7 +887,20 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(arystruct_t *arystruct,
750887 }
751888
752889 shape = arystruct -> shape_and_strides ;
753- strides = shape + ndim ;
890+
891+ // Calculate the exponent from the arystruct->itemsize as we know
892+ // itemsize is a power of two
893+ itemsize = arystruct -> itemsize ;
894+ while (itemsize >>= 1 )
895+ exp ++ ;
896+
897+ // Numba internally stores strides as bytes and not as elements. Divide
898+ // the stride by itemsize to get number of elements.
899+ for (size_t idx = ndim ; idx < 2 * ((size_t )ndim ); ++ idx )
900+ arystruct -> shape_and_strides [idx ] =
901+ arystruct -> shape_and_strides [idx ] >> exp ;
902+ strides = (shape + ndim );
903+
754904 typenum = descr -> type_num ;
755905 usm_ndarr_obj = UsmNDArray_MakeFromPtr (
756906 ndim , shape , typenum , strides , (DPCTLSyclUSMRef )arystruct -> data ,
@@ -845,6 +995,7 @@ static PyObject *build_c_helpers_dict(void)
845995 _declpointer ("DPEXRT_sycl_usm_ndarray_to_python_acqref" ,
846996 & DPEXRT_sycl_usm_ndarray_to_python_acqref );
847997 _declpointer ("DPEXRT_MemInfo_alloc" , & DPEXRT_MemInfo_alloc );
998+ _declpointer ("DPEXRT_MemInfo_fill" , & DPEXRT_MemInfo_fill );
848999 _declpointer ("NRT_ExternalAllocator_new_for_usm" ,
8491000 & NRT_ExternalAllocator_new_for_usm );
8501001
@@ -895,7 +1046,8 @@ MOD_INIT(_dpexrt_python)
8951046 PyLong_FromVoidPtr (& DPEXRT_sycl_usm_ndarray_to_python_acqref ));
8961047 PyModule_AddObject (m , "DPEXRT_MemInfo_alloc" ,
8971048 PyLong_FromVoidPtr (& DPEXRT_MemInfo_alloc ));
898-
1049+ PyModule_AddObject (m , "DPEXRT_MemInfo_fill" ,
1050+ PyLong_FromVoidPtr (& DPEXRT_MemInfo_fill ));
8991051 PyModule_AddObject (m , "c_helpers" , build_c_helpers_dict ());
9001052 return MOD_SUCCESS_VAL (m );
9011053}
0 commit comments