2323#include "_queuestruct.h"
2424#include "numba/_arraystruct.h"
2525
26+ /* Debugging facilities - enabled at compile-time */
27+ /* #undef NDEBUG */
28+ #if 0
29+ #include <stdio.h>
30+ #define DPEXRT_DEBUG (X ) \
31+ { \
32+ X; \
33+ fflush(stdout); \
34+ }
35+ #else
36+ #define DPEXRT_DEBUG (X ) \
37+ if (0) { \
38+ X; \
39+ }
40+ #endif
41+
42+ typedef union
43+ {
44+ float f_ ;
45+ uint32_t i_ ;
46+ } float32bits ;
47+
2648// forward declarations
2749static struct PyUSMArrayObject * PyUSMNdArray_ARRAYOBJ (PyObject * obj );
2850static npy_intp product_of_shape (npy_intp * shape , npy_intp ndim );
@@ -34,6 +56,12 @@ static NRT_ExternalAllocator *
3456NRT_ExternalAllocator_new_for_usm (DPCTLSyclQueueRef qref , size_t usm_type );
3557static void * DPEXRTQueue_CreateFromFilterString (const char * device );
3658static MemInfoDtorInfo * MemInfoDtorInfo_new (NRT_MemInfo * mi , PyObject * owner );
59+ static NRT_MemInfo * DPEXRT_MemInfo_fill (NRT_MemInfo * mi ,
60+ size_t itemsize ,
61+ bool dest_is_float ,
62+ bool value_is_float ,
63+ int64_t value ,
64+ const char * device );
3765static NRT_MemInfo * NRT_MemInfo_new_from_usmndarray (PyObject * ndarrobj ,
3866 void * data ,
3967 npy_intp nitems ,
@@ -521,8 +549,9 @@ DPEXRT_MemInfo_alloc(npy_intp size, size_t usm_type, const char *device)
521549 */
522550static NRT_MemInfo * DPEXRT_MemInfo_fill (NRT_MemInfo * mi ,
523551 size_t itemsize ,
524- bool is_float ,
525- uint8_t value ,
552+ bool dest_is_float ,
553+ bool value_is_float ,
554+ int64_t value ,
526555 const char * device )
527556{
528557 DPCTLSyclQueueRef qref = NULL ;
@@ -553,12 +582,16 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
553582 case 3 :
554583 {
555584 uint64_t value_assign = (uint64_t )value ;
556- if (is_float ) {
557- double const_val = ( double ) value ;
585+ if (dest_is_float && ! value_is_float ) {
586+ double const_val = value ;
558587 // To stop warning: dereferencing type-punned pointer
559588 // will break strict-aliasing rules [-Wstrict-aliasing]
560- double * p = & const_val ;
561- value_assign = * ((uint64_t * )(p ));
589+ double * p = (double * )& const_val ;
590+ value_assign = * ((int64_t * )(p ));
591+ }
592+ else if (!dest_is_float && value_is_float ) {
593+ double * p = (double * )& value ;
594+ value_assign = * p ;
562595 }
563596 if (!(eref = DPCTLQueue_Fill64 (qref , mi -> data , value_assign , count )))
564597 goto error ;
@@ -567,25 +600,53 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
567600 case 2 :
568601 {
569602 uint32_t value_assign = (uint32_t )value ;
570- if (is_float ) {
603+ if (dest_is_float && value_is_float ) {
604+ float32bits fb ;
605+ double * p = (double * )(& value );
606+ fb .f_ = * p ;
607+ value_assign = fb .i_ ;
608+ }
609+ else if (dest_is_float && !value_is_float ) {
571610 float const_val = (float )value ;
572611 // To stop warning: dereferencing type-punned pointer
573612 // will break strict-aliasing rules [-Wstrict-aliasing]
574- float * p = & const_val ;
575- value_assign = * ((uint32_t * )(p ));
613+ float * p = (float * )& const_val ;
614+ value_assign = * ((int32_t * )(p ));
615+ }
616+ else if (!dest_is_float && value_is_float ) {
617+ double * p = (double * )& value ;
618+ value_assign = * p ;
576619 }
577620 if (!(eref = DPCTLQueue_Fill32 (qref , mi -> data , value_assign , count )))
578621 goto error ;
579622 break ;
580623 }
581624 case 1 :
582- if (!(eref = DPCTLQueue_Fill16 (qref , mi -> data , value , count )))
625+ {
626+ if (dest_is_float )
627+ goto error ;
628+ uint16_t value_assign = (uint16_t )value ;
629+ if (value_is_float ) {
630+ double * p = (double * )& value ;
631+ value_assign = * p ;
632+ }
633+ if (!(eref = DPCTLQueue_Fill16 (qref , mi -> data , value_assign , count )))
583634 goto error ;
584635 break ;
636+ }
585637 case 0 :
586- if (!(eref = DPCTLQueue_Fill8 (qref , mi -> data , value , count )))
638+ {
639+ if (dest_is_float )
640+ goto error ;
641+ uint8_t value_assign = (uint8_t )value ;
642+ if (value_is_float ) {
643+ double * p = (double * )& value ;
644+ value_assign = * p ;
645+ }
646+ if (!(eref = DPCTLQueue_Fill8 (qref , mi -> data , value_assign , count )))
587647 goto error ;
588648 break ;
649+ }
589650 default :
590651 goto error ;
591652 }
0 commit comments