Skip to content

Commit 518406a

Browse files
author
Diptorup Deb
committed
Small improvements.
- Move the union for bit-casting into the DPEXRT_MemInfo_fill function, as it is the only place it is used. - Use the union for bit conversions for all case to make the code easier to follow. - Changes to numba_dpex.dpnp_iface._intrinsic.fill_arrayobj: - Ensure a bitcast is added by unconditionally and not just for floats. - Change assignment of value_is_float and dest_is_float objects to if-else statements, otherwise we will have two constants in the LLVM IR whene the type is a Float.
1 parent 54ae6ee commit 518406a

File tree

2 files changed

+57
-47
lines changed

2 files changed

+57
-47
lines changed

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 50 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -23,26 +23,6 @@
2323
#include "_queuestruct.h"
2424
#include "numba/_arraystruct.h"
2525

26-
/**
27-
* @brief A union for bit representations of float.
28-
* This is useful in DPEXRT_MemInfo_fill() function.
29-
*/
30-
typedef union
31-
{
32-
float f_; /**< The float to be represented. */
33-
uint32_t i_; /**< The bit representation. */
34-
} float_uint32_t;
35-
36-
/**
37-
* @brief A union for bit representations of double.
38-
* This is useful in DPEXRT_MemInfo_fill() function.
39-
*/
40-
typedef union
41-
{
42-
double d_; /**< The double to be represented. */
43-
uint64_t i_; /**< The bit representation. */
44-
} double_uint64_t;
45-
4626
// forward declarations
4727
static struct PyUSMArrayObject *PyUSMNdArray_ARRAYOBJ(PyObject *obj);
4828
static npy_intp product_of_shape(npy_intp *shape, npy_intp ndim);
@@ -557,6 +537,26 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
557537
DPCTLSyclEventRef eref = NULL;
558538
size_t count = 0, size = 0, exp = 0;
559539

540+
/**
541+
* @brief A union for bit conversion from the input int64_t value
542+
* to a uintX_t bit-pattern with appropriate type conversion when the
543+
* input value represents a float.
544+
*/
545+
typedef union
546+
{
547+
float f_; /**< The float to be represented. */
548+
double d_;
549+
int8_t i8_;
550+
int16_t i16_;
551+
int32_t i32_;
552+
int64_t i64_;
553+
uint8_t ui8_;
554+
uint16_t ui16_;
555+
uint32_t ui32_; /**< The bit representation. */
556+
uint64_t ui64_; /**< The bit representation. */
557+
} bitcaster_t;
558+
559+
bitcaster_t bc;
560560
size = mi->size;
561561
while (itemsize >>= 1)
562562
exp++;
@@ -580,77 +580,83 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
580580
switch (exp) {
581581
case 3:
582582
{
583-
int64_t value_assign = (int64_t)value;
584583
if (dest_is_float && value_is_float) {
585-
double_uint64_t du;
586584
double *p = (double *)(&value);
587-
du.d_ = *p;
588-
value_assign = du.i_;
585+
bc.d_ = *p;
589586
}
590587
else if (dest_is_float && !value_is_float) {
591-
double_uint64_t du;
592588
// To stop warning: dereferencing type-punned pointer
593589
// will break strict-aliasing rules [-Wstrict-aliasing]
594590
double cd = (double)value;
595-
du.d_ = *((double *)(&cd));
596-
value_assign = du.i_;
591+
bc.d_ = *((double *)(&cd));
597592
}
598593
else if (!dest_is_float && value_is_float) {
599594
double *p = (double *)&value;
600-
value_assign = *p;
595+
bc.i64_ = *p;
596+
}
597+
else {
598+
bc.i64_ = value;
601599
}
602-
if (!(eref = DPCTLQueue_Fill64(qref, mi->data, value_assign, count)))
600+
601+
if (!(eref = DPCTLQueue_Fill64(qref, mi->data, bc.ui64_, count)))
603602
goto error;
604603
break;
605604
}
606605
case 2:
607606
{
608-
int32_t value_assign = (int32_t)value;
609607
if (dest_is_float && value_is_float) {
610-
float_uint32_t fu;
611608
double *p = (double *)(&value);
612-
fu.f_ = *p;
613-
value_assign = fu.i_;
609+
bc.f_ = *p;
614610
}
615611
else if (dest_is_float && !value_is_float) {
616-
float_uint32_t fu;
617612
// To stop warning: dereferencing type-punned pointer
618613
// will break strict-aliasing rules [-Wstrict-aliasing]
619614
float cf = (float)value;
620-
fu.f_ = *((float *)(&cf));
621-
value_assign = fu.i_;
615+
bc.f_ = *((float *)(&cf));
622616
}
623617
else if (!dest_is_float && value_is_float) {
624618
double *p = (double *)&value;
625-
value_assign = *p;
619+
bc.i32_ = *p;
620+
}
621+
else {
622+
bc.i32_ = (int32_t)value;
626623
}
627-
if (!(eref = DPCTLQueue_Fill32(qref, mi->data, value_assign, count)))
624+
625+
if (!(eref = DPCTLQueue_Fill32(qref, mi->data, bc.ui32_, count)))
628626
goto error;
629627
break;
630628
}
631629
case 1:
632630
{
633631
if (dest_is_float)
634632
goto error;
635-
int16_t value_assign = (int16_t)value;
633+
636634
if (value_is_float) {
637635
double *p = (double *)&value;
638-
value_assign = *p;
636+
bc.i16_ = *p;
637+
}
638+
else {
639+
bc.i16_ = (int16_t)value;
639640
}
640-
if (!(eref = DPCTLQueue_Fill16(qref, mi->data, value_assign, count)))
641+
642+
if (!(eref = DPCTLQueue_Fill16(qref, mi->data, bc.ui16_, count)))
641643
goto error;
642644
break;
643645
}
644646
case 0:
645647
{
646648
if (dest_is_float)
647649
goto error;
648-
int8_t value_assign = (int8_t)value;
650+
649651
if (value_is_float) {
650652
double *p = (double *)&value;
651-
value_assign = *p;
653+
bc.i8_ = *p;
652654
}
653-
if (!(eref = DPCTLQueue_Fill8(qref, mi->data, value_assign, count)))
655+
else {
656+
bc.i8_ = (int8_t)value;
657+
}
658+
659+
if (!(eref = DPCTLQueue_Fill8(qref, mi->data, bc.ui8_, count)))
654660
goto error;
655661
break;
656662
}

numba_dpex/dpnp_iface/_intrinsic.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,18 @@ def fill_arrayobj(context, builder, sig, llargs, value, is_like=False):
7878
)
7979
device = context.insert_const_string(builder.module, arrtype[0].device)
8080

81-
value_is_float = context.get_constant(types.boolean, 0)
81+
# Do a bitcast of the input to a 64-bit int.
82+
value = builder.bitcast(value, llvmir.IntType(64))
83+
8284
if isinstance(sig.args[1], types.scalars.Float):
83-
value = builder.bitcast(value, llvmir.IntType(64))
8485
value_is_float = context.get_constant(types.boolean, 1)
86+
else:
87+
value_is_float = context.get_constant(types.boolean, 0)
8588

86-
dest_is_float = context.get_constant(types.boolean, 0)
8789
if isinstance(arrtype[0].dtype, types.scalars.Float):
8890
dest_is_float = context.get_constant(types.boolean, 1)
91+
else:
92+
dest_is_float = context.get_constant(types.boolean, 0)
8993

9094
dpexrtCtx = dpexrt.DpexRTContext(context)
9195
dpexrtCtx.meminfo_fill(

0 commit comments

Comments
 (0)