1212///
1313//===----------------------------------------------------------------------===//
1414
15+ #include <inttypes.h>
16+ #include <stdio.h>
17+
1518#include "dpctl_capi.h"
1619#include "dpctl_sycl_interface.h"
1720
@@ -424,6 +427,123 @@ DPEXRT_MemInfo_alloc(npy_intp size, size_t usm_type, const char *device)
424427 return NULL ;
425428}
426429
430+ /**
431+ * @brief Interface for the core.runtime.context.DpexRTContext.meminfo_alloc.
432+ * This function takes an allocated memory as NRT_MemInfo and fills it with
433+ * the value specified by `value`.
434+ *
435+ * @param mi An NRT_MemInfo object, should be found from memory
436+ * allocation.
437+ * @param itemsize The itemsize, the size of each item in the array.
438+ * @param is_float Flag to specify if the data being float or not.
439+ * @param value The value to be used to fill an array.
440+ * @param device The device on which the memory was allocated.
441+ * @return NRT_MemInfo* A new NRT_MemInfo object, NULL if no NRT_MemInfo
442+ * object could be created.
443+ */
444+ static NRT_MemInfo * DPEXRT_MemInfo_fill (NRT_MemInfo * mi ,
445+ size_t itemsize ,
446+ bool is_float ,
447+ uint8_t value ,
448+ const char * device )
449+ {
450+ DPCTLSyclDeviceSelectorRef dselector = NULL ;
451+ DPCTLSyclDeviceRef dref = NULL ;
452+ DPCTLSyclQueueRef qref = NULL ;
453+ DPCTLSyclEventRef eref = NULL ;
454+ size_t count = 0 , size = 0 , exp = 0 ;
455+
456+ size = mi -> size ;
457+ while (itemsize >>= 1 )
458+ exp ++ ;
459+ count = (unsigned int )(size >> exp );
460+
461+ NRT_Debug (nrt_debug_print (
462+ "DPEXRT-DEBUG: mi->size = %u, itemsize = %u, count = %u, "
463+ "value = %u, Inside DPEXRT_MemInfo_fill %s, line %d\n" ,
464+ mi -> size , itemsize << exp , count , value , __FILE__ , __LINE__ ));
465+
466+ if (mi -> data == NULL ) {
467+ NRT_Debug (nrt_debug_print ("DPEXRT-DEBUG: mi->data is NULL, "
468+ "Inside DPEXRT_MemInfo_fill %s, line %d\n" ,
469+ __FILE__ , __LINE__ ));
470+ goto error ;
471+ }
472+
473+ if (!(dselector = DPCTLFilterSelector_Create (device ))) {
474+ NRT_Debug (nrt_debug_print (
475+ "DPEXRT-ERROR: Could not create a sycl::device_selector from "
476+ "filter string: %s at %s %d.\n" ,
477+ device , __FILE__ , __LINE__ ));
478+ goto error ;
479+ }
480+
481+ if (!(dref = DPCTLDevice_CreateFromSelector (dselector )))
482+ goto error ;
483+
484+ if (!(qref = DPCTLQueue_CreateForDevice (dref , NULL , 0 )))
485+ goto error ;
486+
487+ DPCTLDeviceSelector_Delete (dselector );
488+ DPCTLDevice_Delete (dref );
489+
490+ switch (exp ) {
491+ case 3 :
492+ {
493+ uint64_t value_assign = (uint64_t )value ;
494+ if (is_float ) {
495+ double const_val = (double )value ;
496+ // To stop warning: dereferencing type-punned pointer
497+ // will break strict-aliasing rules [-Wstrict-aliasing]
498+ double * p = & const_val ;
499+ value_assign = * ((uint64_t * )(p ));
500+ }
501+ if (!(eref = DPCTLQueue_Fill64 (qref , mi -> data , value_assign , count )))
502+ goto error ;
503+ break ;
504+ }
505+ case 2 :
506+ {
507+ uint32_t value_assign = (uint32_t )value ;
508+ if (is_float ) {
509+ float const_val = (float )value ;
510+ // To stop warning: dereferencing type-punned pointer
511+ // will break strict-aliasing rules [-Wstrict-aliasing]
512+ float * p = & const_val ;
513+ value_assign = * ((uint32_t * )(p ));
514+ }
515+ if (!(eref = DPCTLQueue_Fill32 (qref , mi -> data , value_assign , count )))
516+ goto error ;
517+ break ;
518+ }
519+ case 1 :
520+ if (!(eref = DPCTLQueue_Fill16 (qref , mi -> data , value , count )))
521+ goto error ;
522+ break ;
523+ case 0 :
524+ if (!(eref = DPCTLQueue_Fill8 (qref , mi -> data , value , count )))
525+ goto error ;
526+ break ;
527+ default :
528+ goto error ;
529+ }
530+
531+ DPCTLEvent_Wait (eref );
532+
533+ DPCTLQueue_Delete (qref );
534+ DPCTLEvent_Delete (eref );
535+
536+ return mi ;
537+
538+ error :
539+ DPCTLQueue_Delete (qref );
540+ DPCTLEvent_Delete (eref );
541+ DPCTLDeviceSelector_Delete (dselector );
542+ DPCTLDevice_Delete (dref );
543+
544+ return NULL ;
545+ }
546+
427547/*----------------------------------------------------------------------------*/
428548/*--------- Helpers to get attributes out of a dpnp.ndarray PyObject ---------*/
429549/*----------------------------------------------------------------------------*/
@@ -487,12 +607,13 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
487607 arystruct_t * arystruct )
488608{
489609 struct PyUSMArrayObject * arrayobj = NULL ;
490- int i , ndim ;
610+ int i = 0 , ndim = 0 , exp = 0 ;
491611 npy_intp * shape = NULL , * strides = NULL ;
492- npy_intp * p = NULL , nitems , itemsize ;
612+ npy_intp * p = NULL , nitems ;
493613 void * data = NULL ;
494614 DPCTLSyclQueueRef qref = NULL ;
495615 PyGILState_STATE gstate ;
616+ npy_intp itemsize = 0 ;
496617
497618 // Increment the ref count on obj to prevent CPython from garbage
498619 // collecting the array.
@@ -546,20 +667,29 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
546667
547668 p = arystruct -> shape_and_strides ;
548669
670+ // Calculate the exponent from the arystruct->itemsize as we know
671+ // itemsize is a power of two
672+ while (itemsize >>= 1 )
673+ exp ++ ;
674+
549675 for (i = 0 ; i < ndim ; ++ i , ++ p )
550676 * p = shape [i ];
551677
552- // DPCTL returns a NULL pointer if the array is contiguous
678+ // DPCTL returns a NULL pointer if the array is contiguous. dpctl stores
679+ // strides as number of elements and Numba stores strides as bytes, for
680+ // that reason we are multiplying stride by itemsize when unboxing the
681+ // external array.
682+
553683 // FIXME: Stride computation should check order and adjust how strides are
554684 // calculated. Right now strides are assuming that order is C contigous.
555685 if (strides ) {
556686 for (i = 0 ; i < ndim ; ++ i , ++ p ) {
557- * p = strides [i ];
687+ * p = strides [i ] << exp ;
558688 }
559689 }
560690 else {
561691 for (i = 1 ; i < ndim ; ++ i , ++ p ) {
562- * p = shape [i ];
692+ * p = shape [i ] << exp ;
563693 }
564694 * p = 1 ;
565695 }
@@ -598,11 +728,12 @@ static PyObject *box_from_arystruct_parent(arystruct_t *arystruct,
598728 int ndim ,
599729 PyArray_Descr * descr )
600730{
601- int i ;
602- npy_intp * p ;
731+ int i = 0 , exp = 0 ;
732+ npy_intp * p = NULL ;
603733 npy_intp * shape = NULL , * strides = NULL ;
604734 PyObject * array = arystruct -> parent ;
605735 struct PyUSMArrayObject * arrayobj = NULL ;
736+ npy_intp itemsize = 0 ;
606737
607738 NRT_Debug (nrt_debug_print ("DPEXRT-DEBUG: In try_to_return_parent.\n" ));
608739
@@ -623,9 +754,16 @@ static PyObject *box_from_arystruct_parent(arystruct_t *arystruct,
623754 if (shape [i ] != * p )
624755 return NULL ;
625756 }
626-
757+ // Calculate the exponent from the arystruct->itemsize as we know
758+ // itemsize is a power of two
759+ itemsize = arystruct -> itemsize ;
760+ while (itemsize >>= 1 )
761+ exp ++ ;
762+ // dpctl stores strides as number of elements and Numba stores strides as
763+ // bytes, for that reason we are multiplying stride by itemsize when
764+ // unboxing the external array.
627765 if (strides ) {
628- if (strides [i ] != * p )
766+ if (strides [i ] << exp != * p )
629767 return NULL ;
630768 }
631769 else {
@@ -680,6 +818,8 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(arystruct_t *arystruct,
680818 npy_intp * shape = NULL , * strides = NULL ;
681819 int typenum = 0 ;
682820 int status = 0 ;
821+ int exp = 0 ;
822+ npy_intp itemsize = 0 ;
683823
684824 NRT_Debug (nrt_debug_print (
685825 "DPEXRT-DEBUG: In DPEXRT_sycl_usm_ndarray_to_python_acqref.\n" ));
@@ -750,7 +890,20 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(arystruct_t *arystruct,
750890 }
751891
752892 shape = arystruct -> shape_and_strides ;
753- strides = shape + ndim ;
893+
894+ // Calculate the exponent from the arystruct->itemsize as we know
895+ // itemsize is a power of two
896+ itemsize = arystruct -> itemsize ;
897+ while (itemsize >>= 1 )
898+ exp ++ ;
899+
900+ // Numba internally stores strides as bytes and not as elements. Divide
901+ // the stride by itemsize to get number of elements.
902+ for (size_t idx = ndim ; idx < 2 * ((size_t )ndim ); ++ idx )
903+ arystruct -> shape_and_strides [idx ] =
904+ arystruct -> shape_and_strides [idx ] >> exp ;
905+ strides = (shape + ndim );
906+
754907 typenum = descr -> type_num ;
755908 usm_ndarr_obj = UsmNDArray_MakeFromPtr (
756909 ndim , shape , typenum , strides , (DPCTLSyclUSMRef )arystruct -> data ,
@@ -845,6 +998,7 @@ static PyObject *build_c_helpers_dict(void)
845998 _declpointer ("DPEXRT_sycl_usm_ndarray_to_python_acqref" ,
846999 & DPEXRT_sycl_usm_ndarray_to_python_acqref );
8471000 _declpointer ("DPEXRT_MemInfo_alloc" , & DPEXRT_MemInfo_alloc );
1001+ _declpointer ("DPEXRT_MemInfo_fill" , & DPEXRT_MemInfo_fill );
8481002 _declpointer ("NRT_ExternalAllocator_new_for_usm" ,
8491003 & NRT_ExternalAllocator_new_for_usm );
8501004
@@ -895,7 +1049,8 @@ MOD_INIT(_dpexrt_python)
8951049 PyLong_FromVoidPtr (& DPEXRT_sycl_usm_ndarray_to_python_acqref ));
8961050 PyModule_AddObject (m , "DPEXRT_MemInfo_alloc" ,
8971051 PyLong_FromVoidPtr (& DPEXRT_MemInfo_alloc ));
898-
1052+ PyModule_AddObject (m , "DPEXRT_MemInfo_fill" ,
1053+ PyLong_FromVoidPtr (& DPEXRT_MemInfo_fill ));
8991054 PyModule_AddObject (m , "c_helpers" , build_c_helpers_dict ());
9001055 return MOD_SUCCESS_VAL (m );
9011056}
0 commit comments