2323#include "_queuestruct.h"
2424#include "numba/_arraystruct.h"
2525
26+ /**
27+ * @brief A union for bit representations of float.
28+ * This is useful in DPEXRT_MemInfo_fill() function.
29+ */
30+ typedef union
31+ {
32+ float f_ ; /**< The float to be represented. */
33+ uint32_t i_ ; /**< The bit representation. */
34+ } float_uint32_t ;
35+
36+ /**
37+ * @brief A union for bit representations of double.
38+ * This is useful in DPEXRT_MemInfo_fill() function.
39+ */
40+ typedef union
41+ {
42+ double d_ ; /**< The double to be represented. */
43+ uint64_t i_ ; /**< The bit representation. */
44+ } double_uint64_t ;
45+
2646// forward declarations
2747static struct PyUSMArrayObject * PyUSMNdArray_ARRAYOBJ (PyObject * obj );
2848static npy_intp product_of_shape (npy_intp * shape , npy_intp ndim );
@@ -34,6 +54,12 @@ static NRT_ExternalAllocator *
3454NRT_ExternalAllocator_new_for_usm (DPCTLSyclQueueRef qref , size_t usm_type );
3555static void * DPEXRTQueue_CreateFromFilterString (const char * device );
3656static MemInfoDtorInfo * MemInfoDtorInfo_new (NRT_MemInfo * mi , PyObject * owner );
57+ static NRT_MemInfo * DPEXRT_MemInfo_fill (NRT_MemInfo * mi ,
58+ size_t itemsize ,
59+ bool dest_is_float ,
60+ bool value_is_float ,
61+ int64_t value ,
62+ const char * device );
3763static NRT_MemInfo * NRT_MemInfo_new_from_usmndarray (PyObject * ndarrobj ,
3864 void * data ,
3965 npy_intp nitems ,
@@ -510,19 +536,21 @@ DPEXRT_MemInfo_alloc(npy_intp size, size_t usm_type, const char *device)
510536 * This function takes an allocated memory as NRT_MemInfo and fills it with
511537 * the value specified by `value`.
512538 *
513- * @param mi An NRT_MemInfo object, should be found from memory
514- * allocation.
515- * @param itemsize The itemsize, the size of each item in the array.
516- * @param is_float Flag to specify if the data being float or not.
517- * @param value The value to be used to fill an array.
518- * @param device The device on which the memory was allocated.
519- * @return NRT_MemInfo* A new NRT_MemInfo object, NULL if no NRT_MemInfo
520- * object could be created.
539+ * @param mi An NRT_MemInfo object, should be found from memory
540+ * allocation.
541+ * @param itemsize The itemsize, the size of each item in the array.
542+ * @param dest_is_float True if the destination array's dtype is float.
543+ * @param value_is_float True if the value to be filled is float.
544+ * @param value The value to be used to fill an array.
545+ * @param device The device on which the memory was allocated.
546+ * @return NRT_MemInfo* A new NRT_MemInfo object, NULL if no NRT_MemInfo
547+ * object could be created.
521548 */
522549static NRT_MemInfo * DPEXRT_MemInfo_fill (NRT_MemInfo * mi ,
523550 size_t itemsize ,
524- bool is_float ,
525- uint8_t value ,
551+ bool dest_is_float ,
552+ bool value_is_float ,
553+ int64_t value ,
526554 const char * device )
527555{
528556 DPCTLSyclQueueRef qref = NULL ;
@@ -552,40 +580,80 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
552580 switch (exp ) {
553581 case 3 :
554582 {
555- uint64_t value_assign = (uint64_t )value ;
556- if (is_float ) {
557- double const_val = (double )value ;
583+ int64_t value_assign = (int64_t )value ;
584+ if (dest_is_float && value_is_float ) {
585+ double_uint64_t du ;
586+ double * p = (double * )(& value );
587+ du .d_ = * p ;
588+ value_assign = du .i_ ;
589+ }
590+ else if (dest_is_float && !value_is_float ) {
591+ double_uint64_t du ;
558592 // To stop warning: dereferencing type-punned pointer
559593 // will break strict-aliasing rules [-Wstrict-aliasing]
560- double * p = & const_val ;
561- value_assign = * ((uint64_t * )(p ));
594+ double cd = (double )value ;
595+ du .d_ = * ((double * )(& cd ));
596+ value_assign = du .i_ ;
597+ }
598+ else if (!dest_is_float && value_is_float ) {
599+ double * p = (double * )& value ;
600+ value_assign = * p ;
562601 }
563602 if (!(eref = DPCTLQueue_Fill64 (qref , mi -> data , value_assign , count )))
564603 goto error ;
565604 break ;
566605 }
567606 case 2 :
568607 {
569- uint32_t value_assign = (uint32_t )value ;
570- if (is_float ) {
571- float const_val = (float )value ;
608+ int32_t value_assign = (int32_t )value ;
609+ if (dest_is_float && value_is_float ) {
610+ float_uint32_t fu ;
611+ double * p = (double * )(& value );
612+ fu .f_ = * p ;
613+ value_assign = fu .i_ ;
614+ }
615+ else if (dest_is_float && !value_is_float ) {
616+ float_uint32_t fu ;
572617 // To stop warning: dereferencing type-punned pointer
573618 // will break strict-aliasing rules [-Wstrict-aliasing]
574- float * p = & const_val ;
575- value_assign = * ((uint32_t * )(p ));
619+ float cf = (float )value ;
620+ fu .f_ = * ((float * )(& cf ));
621+ value_assign = fu .i_ ;
622+ }
623+ else if (!dest_is_float && value_is_float ) {
624+ double * p = (double * )& value ;
625+ value_assign = * p ;
576626 }
577627 if (!(eref = DPCTLQueue_Fill32 (qref , mi -> data , value_assign , count )))
578628 goto error ;
579629 break ;
580630 }
581631 case 1 :
582- if (!(eref = DPCTLQueue_Fill16 (qref , mi -> data , value , count )))
632+ {
633+ if (dest_is_float )
634+ goto error ;
635+ int16_t value_assign = (int16_t )value ;
636+ if (value_is_float ) {
637+ double * p = (double * )& value ;
638+ value_assign = * p ;
639+ }
640+ if (!(eref = DPCTLQueue_Fill16 (qref , mi -> data , value_assign , count )))
583641 goto error ;
584642 break ;
643+ }
585644 case 0 :
586- if (!(eref = DPCTLQueue_Fill8 (qref , mi -> data , value , count )))
645+ {
646+ if (dest_is_float )
647+ goto error ;
648+ int8_t value_assign = (int8_t )value ;
649+ if (value_is_float ) {
650+ double * p = (double * )& value ;
651+ value_assign = * p ;
652+ }
653+ if (!(eref = DPCTLQueue_Fill8 (qref , mi -> data , value_assign , count )))
587654 goto error ;
588655 break ;
656+ }
589657 default :
590658 goto error ;
591659 }
0 commit comments