@@ -34,6 +34,12 @@ static NRT_ExternalAllocator *
3434NRT_ExternalAllocator_new_for_usm (DPCTLSyclQueueRef qref , size_t usm_type );
3535static void * DPEXRTQueue_CreateFromFilterString (const char * device );
3636static MemInfoDtorInfo * MemInfoDtorInfo_new (NRT_MemInfo * mi , PyObject * owner );
37+ static NRT_MemInfo * DPEXRT_MemInfo_fill (NRT_MemInfo * mi ,
38+ size_t itemsize ,
39+ bool dest_is_float ,
40+ bool value_is_float ,
41+ int64_t value ,
42+ const char * device );
3743static NRT_MemInfo * NRT_MemInfo_new_from_usmndarray (PyObject * ndarrobj ,
3844 void * data ,
3945 npy_intp nitems ,
@@ -510,25 +516,47 @@ DPEXRT_MemInfo_alloc(npy_intp size, size_t usm_type, const char *device)
510516 * This function takes an allocated memory as NRT_MemInfo and fills it with
511517 * the value specified by `value`.
512518 *
513- * @param mi An NRT_MemInfo object, should be found from memory
514- * allocation.
515- * @param itemsize The itemsize, the size of each item in the array.
516- * @param is_float Flag to specify if the data being float or not.
517- * @param value The value to be used to fill an array.
518- * @param device The device on which the memory was allocated.
519- * @return NRT_MemInfo* A new NRT_MemInfo object, NULL if no NRT_MemInfo
520- * object could be created.
519+ * @param mi An NRT_MemInfo object, should be found from memory
520+ * allocation.
521+ * @param itemsize The itemsize, the size of each item in the array.
522+ * @param dest_is_float True if the destination array's dtype is float.
523+ * @param value_is_float True if the value to be filled is float.
524+ * @param value The value to be used to fill an array.
525+ * @param device The device on which the memory was allocated.
526+ * @return NRT_MemInfo* A new NRT_MemInfo object, NULL if no NRT_MemInfo
527+ * object could be created.
521528 */
522529static NRT_MemInfo * DPEXRT_MemInfo_fill (NRT_MemInfo * mi ,
523530 size_t itemsize ,
524- bool is_float ,
525- uint8_t value ,
531+ bool dest_is_float ,
532+ bool value_is_float ,
533+ int64_t value ,
526534 const char * device )
527535{
528536 DPCTLSyclQueueRef qref = NULL ;
529537 DPCTLSyclEventRef eref = NULL ;
530538 size_t count = 0 , size = 0 , exp = 0 ;
531539
540+ /**
541+ * @brief A union for bit conversion from the input int64_t value
542+ * to a uintX_t bit-pattern with appropriate type conversion when the
543+ * input value represents a float.
544+ */
545+ typedef union
546+ {
547+ float f_ ; /**< The float to be represented. */
548+ double d_ ;
549+ int8_t i8_ ;
550+ int16_t i16_ ;
551+ int32_t i32_ ;
552+ int64_t i64_ ;
553+ uint8_t ui8_ ;
554+ uint16_t ui16_ ;
555+ uint32_t ui32_ ; /**< The bit representation. */
556+ uint64_t ui64_ ; /**< The bit representation. */
557+ } bitcaster_t ;
558+
559+ bitcaster_t bc ;
532560 size = mi -> size ;
533561 while (itemsize >>= 1 )
534562 exp ++ ;
@@ -552,40 +580,86 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
552580 switch (exp ) {
553581 case 3 :
554582 {
555- uint64_t value_assign = (uint64_t )value ;
556- if (is_float ) {
557- double const_val = (double )value ;
583+ if (dest_is_float && value_is_float ) {
584+ double * p = (double * )(& value );
585+ bc .d_ = * p ;
586+ }
587+ else if (dest_is_float && !value_is_float ) {
558588 // To stop warning: dereferencing type-punned pointer
559589 // will break strict-aliasing rules [-Wstrict-aliasing]
560- double * p = & const_val ;
561- value_assign = * ((uint64_t * )(p ));
590+ double cd = (double )value ;
591+ bc .d_ = * ((double * )(& cd ));
592+ }
593+ else if (!dest_is_float && value_is_float ) {
594+ double * p = (double * )& value ;
595+ bc .i64_ = * p ;
596+ }
597+ else {
598+ bc .i64_ = value ;
562599 }
563- if (!(eref = DPCTLQueue_Fill64 (qref , mi -> data , value_assign , count )))
600+
601+ if (!(eref = DPCTLQueue_Fill64 (qref , mi -> data , bc .ui64_ , count )))
564602 goto error ;
565603 break ;
566604 }
567605 case 2 :
568606 {
569- uint32_t value_assign = (uint32_t )value ;
570- if (is_float ) {
571- float const_val = (float )value ;
607+ if (dest_is_float && value_is_float ) {
608+ double * p = (double * )(& value );
609+ bc .f_ = * p ;
610+ }
611+ else if (dest_is_float && !value_is_float ) {
572612 // To stop warning: dereferencing type-punned pointer
573613 // will break strict-aliasing rules [-Wstrict-aliasing]
574- float * p = & const_val ;
575- value_assign = * ((uint32_t * )(p ));
614+ float cf = (float )value ;
615+ bc .f_ = * ((float * )(& cf ));
616+ }
617+ else if (!dest_is_float && value_is_float ) {
618+ double * p = (double * )& value ;
619+ bc .i32_ = * p ;
576620 }
577- if (!(eref = DPCTLQueue_Fill32 (qref , mi -> data , value_assign , count )))
621+ else {
622+ bc .i32_ = (int32_t )value ;
623+ }
624+
625+ if (!(eref = DPCTLQueue_Fill32 (qref , mi -> data , bc .ui32_ , count )))
578626 goto error ;
579627 break ;
580628 }
581629 case 1 :
582- if (!(eref = DPCTLQueue_Fill16 (qref , mi -> data , value , count )))
630+ {
631+ if (dest_is_float )
632+ goto error ;
633+
634+ if (value_is_float ) {
635+ double * p = (double * )& value ;
636+ bc .i16_ = * p ;
637+ }
638+ else {
639+ bc .i16_ = (int16_t )value ;
640+ }
641+
642+ if (!(eref = DPCTLQueue_Fill16 (qref , mi -> data , bc .ui16_ , count )))
583643 goto error ;
584644 break ;
645+ }
585646 case 0 :
586- if (!(eref = DPCTLQueue_Fill8 (qref , mi -> data , value , count )))
647+ {
648+ if (dest_is_float )
649+ goto error ;
650+
651+ if (value_is_float ) {
652+ double * p = (double * )& value ;
653+ bc .i8_ = * p ;
654+ }
655+ else {
656+ bc .i8_ = (int8_t )value ;
657+ }
658+
659+ if (!(eref = DPCTLQueue_Fill8 (qref , mi -> data , bc .ui8_ , count )))
587660 goto error ;
588661 break ;
662+ }
589663 default :
590664 goto error ;
591665 }
0 commit comments