Skip to content

Commit 54ae6ee

Browse files
author
khaled
committed
dpnp.full() complete and test case added
Remove 'like' from dpnp.empty() overload. ty_x --> ty_x1. Bitcasts are done through unions. Addressed all review comments.
1 parent bbf978e commit 54ae6ee

File tree

11 files changed

+620
-149
lines changed

11 files changed

+620
-149
lines changed

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 90 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,26 @@
2323
#include "_queuestruct.h"
2424
#include "numba/_arraystruct.h"
2525

26+
/**
27+
* @brief A union for bit representations of float.
28+
* This is useful in the DPEXRT_MemInfo_fill() function.
29+
*/
30+
typedef union
31+
{
32+
float f_; /**< The float to be represented. */
33+
uint32_t i_; /**< The bit representation. */
34+
} float_uint32_t;
35+
36+
/**
37+
* @brief A union for bit representations of double.
38+
* This is useful in the DPEXRT_MemInfo_fill() function.
39+
*/
40+
typedef union
41+
{
42+
double d_; /**< The double to be represented. */
43+
uint64_t i_; /**< The bit representation. */
44+
} double_uint64_t;
45+
2646
// forward declarations
2747
static struct PyUSMArrayObject *PyUSMNdArray_ARRAYOBJ(PyObject *obj);
2848
static npy_intp product_of_shape(npy_intp *shape, npy_intp ndim);
@@ -34,6 +54,12 @@ static NRT_ExternalAllocator *
3454
NRT_ExternalAllocator_new_for_usm(DPCTLSyclQueueRef qref, size_t usm_type);
3555
static void *DPEXRTQueue_CreateFromFilterString(const char *device);
3656
static MemInfoDtorInfo *MemInfoDtorInfo_new(NRT_MemInfo *mi, PyObject *owner);
57+
static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
58+
size_t itemsize,
59+
bool dest_is_float,
60+
bool value_is_float,
61+
int64_t value,
62+
const char *device);
3763
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
3864
void *data,
3965
npy_intp nitems,
@@ -510,19 +536,21 @@ DPEXRT_MemInfo_alloc(npy_intp size, size_t usm_type, const char *device)
510536
* This function takes an allocated memory as NRT_MemInfo and fills it with
511537
* the value specified by `value`.
512538
*
513-
* @param mi An NRT_MemInfo object, should be found from memory
514-
* allocation.
515-
* @param itemsize The itemsize, the size of each item in the array.
516-
* @param is_float Flag to specify if the data being float or not.
517-
* @param value The value to be used to fill an array.
518-
* @param device The device on which the memory was allocated.
519-
* @return NRT_MemInfo* A new NRT_MemInfo object, NULL if no NRT_MemInfo
520-
* object could be created.
539+
* @param mi An NRT_MemInfo object, as obtained from a memory
540+
* allocation.
541+
* @param itemsize The itemsize, the size of each item in the array.
542+
* @param dest_is_float True if the destination array's dtype is float.
543+
* @param value_is_float True if the value to be filled is float.
544+
* @param value The value to be used to fill an array.
545+
* @param device The device on which the memory was allocated.
546+
* @return NRT_MemInfo* A new NRT_MemInfo object, NULL if no NRT_MemInfo
547+
* object could be created.
521548
*/
522549
static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
523550
size_t itemsize,
524-
bool is_float,
525-
uint8_t value,
551+
bool dest_is_float,
552+
bool value_is_float,
553+
int64_t value,
526554
const char *device)
527555
{
528556
DPCTLSyclQueueRef qref = NULL;
@@ -552,40 +580,80 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
552580
switch (exp) {
553581
case 3:
554582
{
555-
uint64_t value_assign = (uint64_t)value;
556-
if (is_float) {
557-
double const_val = (double)value;
583+
int64_t value_assign = (int64_t)value;
584+
if (dest_is_float && value_is_float) {
585+
double_uint64_t du;
586+
double *p = (double *)(&value);
587+
du.d_ = *p;
588+
value_assign = du.i_;
589+
}
590+
else if (dest_is_float && !value_is_float) {
591+
double_uint64_t du;
558592
// To stop warning: dereferencing type-punned pointer
559593
// will break strict-aliasing rules [-Wstrict-aliasing]
560-
double *p = &const_val;
561-
value_assign = *((uint64_t *)(p));
594+
double cd = (double)value;
595+
du.d_ = *((double *)(&cd));
596+
value_assign = du.i_;
597+
}
598+
else if (!dest_is_float && value_is_float) {
599+
double *p = (double *)&value;
600+
value_assign = *p;
562601
}
563602
if (!(eref = DPCTLQueue_Fill64(qref, mi->data, value_assign, count)))
564603
goto error;
565604
break;
566605
}
567606
case 2:
568607
{
569-
uint32_t value_assign = (uint32_t)value;
570-
if (is_float) {
571-
float const_val = (float)value;
608+
int32_t value_assign = (int32_t)value;
609+
if (dest_is_float && value_is_float) {
610+
float_uint32_t fu;
611+
double *p = (double *)(&value);
612+
fu.f_ = *p;
613+
value_assign = fu.i_;
614+
}
615+
else if (dest_is_float && !value_is_float) {
616+
float_uint32_t fu;
572617
// To stop warning: dereferencing type-punned pointer
573618
// will break strict-aliasing rules [-Wstrict-aliasing]
574-
float *p = &const_val;
575-
value_assign = *((uint32_t *)(p));
619+
float cf = (float)value;
620+
fu.f_ = *((float *)(&cf));
621+
value_assign = fu.i_;
622+
}
623+
else if (!dest_is_float && value_is_float) {
624+
double *p = (double *)&value;
625+
value_assign = *p;
576626
}
577627
if (!(eref = DPCTLQueue_Fill32(qref, mi->data, value_assign, count)))
578628
goto error;
579629
break;
580630
}
581631
case 1:
582-
if (!(eref = DPCTLQueue_Fill16(qref, mi->data, value, count)))
632+
{
633+
if (dest_is_float)
634+
goto error;
635+
int16_t value_assign = (int16_t)value;
636+
if (value_is_float) {
637+
double *p = (double *)&value;
638+
value_assign = *p;
639+
}
640+
if (!(eref = DPCTLQueue_Fill16(qref, mi->data, value_assign, count)))
583641
goto error;
584642
break;
643+
}
585644
case 0:
586-
if (!(eref = DPCTLQueue_Fill8(qref, mi->data, value, count)))
645+
{
646+
if (dest_is_float)
647+
goto error;
648+
int8_t value_assign = (int8_t)value;
649+
if (value_is_float) {
650+
double *p = (double *)&value;
651+
value_assign = *p;
652+
}
653+
if (!(eref = DPCTLQueue_Fill8(qref, mi->data, value_assign, count)))
587654
goto error;
588655
break;
656+
}
589657
default:
590658
goto error;
591659
}

numba_dpex/core/runtime/context.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,37 @@ def wrap(self, builder, *args, **kwargs):
2929

3030
@_check_null_result
3131
def meminfo_alloc(self, builder, size, usm_type, device):
32-
"""A wrapped caller for meminfo_alloc_unchecked() with null check."""
32+
"""
33+
A wrapped caller for :func:`~context.DpexRTContext.meminfo_alloc_unchecked`
34+
with null check. Please refer to that function for the details on how the
35+
null check is done.
36+
"""
3337
return self.meminfo_alloc_unchecked(builder, size, usm_type, device)
3438

3539
@_check_null_result
36-
def meminfo_fill(self, builder, meminfo, itemsize, is_float, value, device):
37-
"""A wrapped caller for meminfo_fill_unchecked() with null check."""
40+
def meminfo_fill(
41+
self,
42+
builder,
43+
meminfo,
44+
itemsize,
45+
dest_is_float,
46+
value_is_float,
47+
value,
48+
device,
49+
):
50+
"""
51+
A wrapped caller for :func:`~context.DpexRTContext.meminfo_fill_unchecked`
52+
with null check. Please refer to that function for the details on how the
53+
null check is done.
54+
"""
3855
return self.meminfo_fill_unchecked(
39-
builder, meminfo, itemsize, is_float, value, device
56+
builder,
57+
meminfo,
58+
itemsize,
59+
dest_is_float,
60+
value_is_float,
61+
value,
62+
device,
4063
)
4164

4265
def meminfo_alloc_unchecked(self, builder, size, usm_type, device):
@@ -71,7 +94,14 @@ def meminfo_alloc_unchecked(self, builder, size, usm_type, device):
7194
return ret
7295

7396
def meminfo_fill_unchecked(
74-
self, builder, meminfo, itemsize, is_float, value, device
97+
self,
98+
builder,
99+
meminfo,
100+
itemsize,
101+
dest_is_float,
102+
value_is_float,
103+
value,
104+
device,
75105
):
76106
"""Fills an allocated `MemInfo` with the value specified.
77107
@@ -96,12 +126,15 @@ def meminfo_fill_unchecked(
96126
b = llvmir.IntType(1)
97127
fnty = llvmir.FunctionType(
98128
cgutils.voidptr_t,
99-
[cgutils.voidptr_t, u64, b, cgutils.int8_t, cgutils.voidptr_t],
129+
[cgutils.voidptr_t, u64, b, b, cgutils.intp_t, cgutils.voidptr_t],
100130
)
101131
fn = cgutils.get_or_insert_function(mod, fnty, "DPEXRT_MemInfo_fill")
102132
fn.return_value.add_attribute("noalias")
103133

104-
ret = builder.call(fn, [meminfo, itemsize, is_float, value, device])
134+
ret = builder.call(
135+
fn,
136+
[meminfo, itemsize, dest_is_float, value_is_float, value, device],
137+
)
105138

106139
return ret
107140

numba_dpex/core/types/usm_ndarray_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(
7474

7575
if not dtype:
7676
dummy_tensor = dpctl.tensor.empty(
77-
shape=1, order=layout, usm_type=usm_type, sycl_queue=self.queue
77+
1, order=layout, usm_type=usm_type, sycl_queue=self.queue
7878
)
7979
# convert dpnp type to numba/numpy type
8080
_dtype = dummy_tensor.dtype

0 commit comments

Comments
 (0)