Skip to content

Commit 5da25f1

Browse files
author
khaled
committed
dpnp.full() complete and test case added
Remove 'like' from dpnp.empty() overload ty_x --> ty_x1
1 parent 46c09ff commit 5da25f1

File tree

11 files changed

+401
-87
lines changed

11 files changed

+401
-87
lines changed

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,28 @@
2323
#include "_queuestruct.h"
2424
#include "numba/_arraystruct.h"
2525

26+
/* Debugging facilities - enabled at compile-time */
27+
/* #undef NDEBUG */
28+
#if 0
29+
#include <stdio.h>
30+
#define DPEXRT_DEBUG(X) \
31+
{ \
32+
X; \
33+
fflush(stdout); \
34+
}
35+
#else
36+
#define DPEXRT_DEBUG(X) \
37+
if (0) { \
38+
X; \
39+
}
40+
#endif
41+
42+
typedef union
43+
{
44+
float f_;
45+
uint32_t i_;
46+
} float32bits;
47+
2648
// forward declarations
2749
static struct PyUSMArrayObject *PyUSMNdArray_ARRAYOBJ(PyObject *obj);
2850
static npy_intp product_of_shape(npy_intp *shape, npy_intp ndim);
@@ -34,6 +56,12 @@ static NRT_ExternalAllocator *
3456
NRT_ExternalAllocator_new_for_usm(DPCTLSyclQueueRef qref, size_t usm_type);
3557
static void *DPEXRTQueue_CreateFromFilterString(const char *device);
3658
static MemInfoDtorInfo *MemInfoDtorInfo_new(NRT_MemInfo *mi, PyObject *owner);
59+
static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
60+
size_t itemsize,
61+
bool dest_is_float,
62+
bool value_is_float,
63+
int64_t value,
64+
const char *device);
3765
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
3866
void *data,
3967
npy_intp nitems,
@@ -521,8 +549,9 @@ DPEXRT_MemInfo_alloc(npy_intp size, size_t usm_type, const char *device)
521549
*/
522550
static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
523551
size_t itemsize,
524-
bool is_float,
525-
uint8_t value,
552+
bool dest_is_float,
553+
bool value_is_float,
554+
int64_t value,
526555
const char *device)
527556
{
528557
DPCTLSyclQueueRef qref = NULL;
@@ -553,12 +582,16 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
553582
case 3:
554583
{
555584
uint64_t value_assign = (uint64_t)value;
556-
if (is_float) {
557-
double const_val = (double)value;
585+
if (dest_is_float && !value_is_float) {
586+
double const_val = value;
558587
// To stop warning: dereferencing type-punned pointer
559588
// will break strict-aliasing rules [-Wstrict-aliasing]
560-
double *p = &const_val;
561-
value_assign = *((uint64_t *)(p));
589+
double *p = (double *)&const_val;
590+
value_assign = *((int64_t *)(p));
591+
}
592+
else if (!dest_is_float && value_is_float) {
593+
double *p = (double *)&value;
594+
value_assign = *p;
562595
}
563596
if (!(eref = DPCTLQueue_Fill64(qref, mi->data, value_assign, count)))
564597
goto error;
@@ -567,25 +600,53 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
567600
case 2:
568601
{
569602
uint32_t value_assign = (uint32_t)value;
570-
if (is_float) {
603+
if (dest_is_float && value_is_float) {
604+
float32bits fb;
605+
double *p = (double *)(&value);
606+
fb.f_ = *p;
607+
value_assign = fb.i_;
608+
}
609+
else if (dest_is_float && !value_is_float) {
571610
float const_val = (float)value;
572611
// To stop warning: dereferencing type-punned pointer
573612
// will break strict-aliasing rules [-Wstrict-aliasing]
574-
float *p = &const_val;
575-
value_assign = *((uint32_t *)(p));
613+
float *p = (float *)&const_val;
614+
value_assign = *((int32_t *)(p));
615+
}
616+
else if (!dest_is_float && value_is_float) {
617+
double *p = (double *)&value;
618+
value_assign = *p;
576619
}
577620
if (!(eref = DPCTLQueue_Fill32(qref, mi->data, value_assign, count)))
578621
goto error;
579622
break;
580623
}
581624
case 1:
582-
if (!(eref = DPCTLQueue_Fill16(qref, mi->data, value, count)))
625+
{
626+
if (dest_is_float)
627+
goto error;
628+
uint16_t value_assign = (uint16_t)value;
629+
if (value_is_float) {
630+
double *p = (double *)&value;
631+
value_assign = *p;
632+
}
633+
if (!(eref = DPCTLQueue_Fill16(qref, mi->data, value_assign, count)))
583634
goto error;
584635
break;
636+
}
585637
case 0:
586-
if (!(eref = DPCTLQueue_Fill8(qref, mi->data, value, count)))
638+
{
639+
if (dest_is_float)
640+
goto error;
641+
uint8_t value_assign = (uint8_t)value;
642+
if (value_is_float) {
643+
double *p = (double *)&value;
644+
value_assign = *p;
645+
}
646+
if (!(eref = DPCTLQueue_Fill8(qref, mi->data, value_assign, count)))
587647
goto error;
588648
break;
649+
}
589650
default:
590651
goto error;
591652
}

numba_dpex/core/runtime/context.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,25 @@ def meminfo_alloc(self, builder, size, usm_type, device):
3333
return self.meminfo_alloc_unchecked(builder, size, usm_type, device)
3434

3535
@_check_null_result
36-
def meminfo_fill(self, builder, meminfo, itemsize, is_float, value, device):
36+
def meminfo_fill(
37+
self,
38+
builder,
39+
meminfo,
40+
itemsize,
41+
dest_is_float,
42+
value_is_float,
43+
value,
44+
device,
45+
):
3746
"""A wrapped caller for meminfo_fill_unchecked() with null check."""
3847
return self.meminfo_fill_unchecked(
39-
builder, meminfo, itemsize, is_float, value, device
48+
builder,
49+
meminfo,
50+
itemsize,
51+
dest_is_float,
52+
value_is_float,
53+
value,
54+
device,
4055
)
4156

4257
def meminfo_alloc_unchecked(self, builder, size, usm_type, device):
@@ -71,7 +86,14 @@ def meminfo_alloc_unchecked(self, builder, size, usm_type, device):
7186
return ret
7287

7388
def meminfo_fill_unchecked(
74-
self, builder, meminfo, itemsize, is_float, value, device
89+
self,
90+
builder,
91+
meminfo,
92+
itemsize,
93+
dest_is_float,
94+
value_is_float,
95+
value,
96+
device,
7597
):
7698
"""Fills an allocated `MemInfo` with the value specified.
7799
@@ -96,12 +118,15 @@ def meminfo_fill_unchecked(
96118
b = llvmir.IntType(1)
97119
fnty = llvmir.FunctionType(
98120
cgutils.voidptr_t,
99-
[cgutils.voidptr_t, u64, b, cgutils.int8_t, cgutils.voidptr_t],
121+
[cgutils.voidptr_t, u64, b, b, cgutils.intp_t, cgutils.voidptr_t],
100122
)
101123
fn = cgutils.get_or_insert_function(mod, fnty, "DPEXRT_MemInfo_fill")
102124
fn.return_value.add_attribute("noalias")
103125

104-
ret = builder.call(fn, [meminfo, itemsize, is_float, value, device])
126+
ret = builder.call(
127+
fn,
128+
[meminfo, itemsize, dest_is_float, value_is_float, value, device],
129+
)
105130

106131
return ret
107132

numba_dpex/core/types/usm_ndarray_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(
7474

7575
if not dtype:
7676
dummy_tensor = dpctl.tensor.empty(
77-
shape=1, order=layout, usm_type=usm_type, sycl_queue=self.queue
77+
1, order=layout, usm_type=usm_type, sycl_queue=self.queue
7878
)
7979
# convert dpnp type to numba/numpy type
8080
_dtype = dummy_tensor.dtype

0 commit comments

Comments
 (0)