Skip to content

Commit 3ded18e

Browse files
author
Diptorup Deb
authored
Merge pull request #991 from chudur-budur/feature/dpnp.full
Implementation for dpnp.full()
2 parents bbf978e + 751da72 commit 3ded18e

File tree

11 files changed

+629
-150
lines changed

11 files changed

+629
-150
lines changed

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 98 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ static NRT_ExternalAllocator *
3434
NRT_ExternalAllocator_new_for_usm(DPCTLSyclQueueRef qref, size_t usm_type);
3535
static void *DPEXRTQueue_CreateFromFilterString(const char *device);
3636
static MemInfoDtorInfo *MemInfoDtorInfo_new(NRT_MemInfo *mi, PyObject *owner);
37+
static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
38+
size_t itemsize,
39+
bool dest_is_float,
40+
bool value_is_float,
41+
int64_t value,
42+
const char *device);
3743
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
3844
void *data,
3945
npy_intp nitems,
@@ -510,25 +516,47 @@ DPEXRT_MemInfo_alloc(npy_intp size, size_t usm_type, const char *device)
510516
* This function takes an allocated memory as NRT_MemInfo and fills it with
511517
* the value specified by `value`.
512518
*
513-
* @param mi An NRT_MemInfo object, should be found from memory
514-
* allocation.
515-
* @param itemsize The itemsize, the size of each item in the array.
516-
* @param is_float Flag to specify if the data being float or not.
517-
* @param value The value to be used to fill an array.
518-
* @param device The device on which the memory was allocated.
519-
* @return NRT_MemInfo* A new NRT_MemInfo object, NULL if no NRT_MemInfo
520-
* object could be created.
519+
* @param mi An NRT_MemInfo object obtained from a prior memory
520+
* allocation.
521+
* @param itemsize The itemsize, the size of each item in the array.
522+
* @param dest_is_float True if the destination array's dtype is float.
523+
* @param value_is_float True if the value to be filled is float.
524+
* @param value The value to be used to fill an array.
525+
* @param device The device on which the memory was allocated.
526+
* @return NRT_MemInfo* A new NRT_MemInfo object, NULL if no NRT_MemInfo
527+
* object could be created.
521528
*/
522529
static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
523530
size_t itemsize,
524-
bool is_float,
525-
uint8_t value,
531+
bool dest_is_float,
532+
bool value_is_float,
533+
int64_t value,
526534
const char *device)
527535
{
528536
DPCTLSyclQueueRef qref = NULL;
529537
DPCTLSyclEventRef eref = NULL;
530538
size_t count = 0, size = 0, exp = 0;
531539

540+
/**
541+
* @brief A union for bit conversion from the input int64_t value
542+
* to a uintX_t bit-pattern with appropriate type conversion when the
543+
* input value represents a float.
544+
*/
545+
typedef union
546+
{
547+
float f_; /**< The float to be represented. */
548+
double d_;
549+
int8_t i8_;
550+
int16_t i16_;
551+
int32_t i32_;
552+
int64_t i64_;
553+
uint8_t ui8_;
554+
uint16_t ui16_;
555+
uint32_t ui32_; /**< The bit representation. */
556+
uint64_t ui64_; /**< The bit representation. */
557+
} bitcaster_t;
558+
559+
bitcaster_t bc;
532560
size = mi->size;
533561
while (itemsize >>= 1)
534562
exp++;
@@ -552,40 +580,86 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
552580
switch (exp) {
553581
case 3:
554582
{
555-
uint64_t value_assign = (uint64_t)value;
556-
if (is_float) {
557-
double const_val = (double)value;
583+
if (dest_is_float && value_is_float) {
584+
double *p = (double *)(&value);
585+
bc.d_ = *p;
586+
}
587+
else if (dest_is_float && !value_is_float) {
558588
// To stop warning: dereferencing type-punned pointer
559589
// will break strict-aliasing rules [-Wstrict-aliasing]
560-
double *p = &const_val;
561-
value_assign = *((uint64_t *)(p));
590+
double cd = (double)value;
591+
bc.d_ = *((double *)(&cd));
592+
}
593+
else if (!dest_is_float && value_is_float) {
594+
double *p = (double *)&value;
595+
bc.i64_ = *p;
596+
}
597+
else {
598+
bc.i64_ = value;
562599
}
563-
if (!(eref = DPCTLQueue_Fill64(qref, mi->data, value_assign, count)))
600+
601+
if (!(eref = DPCTLQueue_Fill64(qref, mi->data, bc.ui64_, count)))
564602
goto error;
565603
break;
566604
}
567605
case 2:
568606
{
569-
uint32_t value_assign = (uint32_t)value;
570-
if (is_float) {
571-
float const_val = (float)value;
607+
if (dest_is_float && value_is_float) {
608+
double *p = (double *)(&value);
609+
bc.f_ = *p;
610+
}
611+
else if (dest_is_float && !value_is_float) {
572612
// To stop warning: dereferencing type-punned pointer
573613
// will break strict-aliasing rules [-Wstrict-aliasing]
574-
float *p = &const_val;
575-
value_assign = *((uint32_t *)(p));
614+
float cf = (float)value;
615+
bc.f_ = *((float *)(&cf));
616+
}
617+
else if (!dest_is_float && value_is_float) {
618+
double *p = (double *)&value;
619+
bc.i32_ = *p;
576620
}
577-
if (!(eref = DPCTLQueue_Fill32(qref, mi->data, value_assign, count)))
621+
else {
622+
bc.i32_ = (int32_t)value;
623+
}
624+
625+
if (!(eref = DPCTLQueue_Fill32(qref, mi->data, bc.ui32_, count)))
578626
goto error;
579627
break;
580628
}
581629
case 1:
582-
if (!(eref = DPCTLQueue_Fill16(qref, mi->data, value, count)))
630+
{
631+
if (dest_is_float)
632+
goto error;
633+
634+
if (value_is_float) {
635+
double *p = (double *)&value;
636+
bc.i16_ = *p;
637+
}
638+
else {
639+
bc.i16_ = (int16_t)value;
640+
}
641+
642+
if (!(eref = DPCTLQueue_Fill16(qref, mi->data, bc.ui16_, count)))
583643
goto error;
584644
break;
645+
}
585646
case 0:
586-
if (!(eref = DPCTLQueue_Fill8(qref, mi->data, value, count)))
647+
{
648+
if (dest_is_float)
649+
goto error;
650+
651+
if (value_is_float) {
652+
double *p = (double *)&value;
653+
bc.i8_ = *p;
654+
}
655+
else {
656+
bc.i8_ = (int8_t)value;
657+
}
658+
659+
if (!(eref = DPCTLQueue_Fill8(qref, mi->data, bc.ui8_, count)))
587660
goto error;
588661
break;
662+
}
589663
default:
590664
goto error;
591665
}

numba_dpex/core/runtime/context.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,35 @@ def wrap(self, builder, *args, **kwargs):
2929

3030
@_check_null_result
3131
def meminfo_alloc(self, builder, size, usm_type, device):
32-
"""A wrapped caller for meminfo_alloc_unchecked() with null check."""
32+
"""
33+
Wrapper to call :func:`~context.DpexRTContext.meminfo_alloc_unchecked`
34+
with null checking of the returned value.
35+
"""
3336
return self.meminfo_alloc_unchecked(builder, size, usm_type, device)
3437

3538
@_check_null_result
36-
def meminfo_fill(self, builder, meminfo, itemsize, is_float, value, device):
37-
"""A wrapped caller for meminfo_fill_unchecked() with null check."""
39+
def meminfo_fill(
40+
self,
41+
builder,
42+
meminfo,
43+
itemsize,
44+
dest_is_float,
45+
value_is_float,
46+
value,
47+
device,
48+
):
49+
"""
50+
Wrapper to call :func:`~context.DpexRTContext.meminfo_fill_unchecked`
51+
with null checking of the returned value.
52+
"""
3853
return self.meminfo_fill_unchecked(
39-
builder, meminfo, itemsize, is_float, value, device
54+
builder,
55+
meminfo,
56+
itemsize,
57+
dest_is_float,
58+
value_is_float,
59+
value,
60+
device,
4061
)
4162

4263
def meminfo_alloc_unchecked(self, builder, size, usm_type, device):
@@ -71,7 +92,14 @@ def meminfo_alloc_unchecked(self, builder, size, usm_type, device):
7192
return ret
7293

7394
def meminfo_fill_unchecked(
74-
self, builder, meminfo, itemsize, is_float, value, device
95+
self,
96+
builder,
97+
meminfo,
98+
itemsize,
99+
dest_is_float,
100+
value_is_float,
101+
value,
102+
device,
75103
):
76104
"""Fills an allocated `MemInfo` with the value specified.
77105
@@ -96,12 +124,15 @@ def meminfo_fill_unchecked(
96124
b = llvmir.IntType(1)
97125
fnty = llvmir.FunctionType(
98126
cgutils.voidptr_t,
99-
[cgutils.voidptr_t, u64, b, cgutils.int8_t, cgutils.voidptr_t],
127+
[cgutils.voidptr_t, u64, b, b, u64, cgutils.voidptr_t],
100128
)
101129
fn = cgutils.get_or_insert_function(mod, fnty, "DPEXRT_MemInfo_fill")
102130
fn.return_value.add_attribute("noalias")
103131

104-
ret = builder.call(fn, [meminfo, itemsize, is_float, value, device])
132+
ret = builder.call(
133+
fn,
134+
[meminfo, itemsize, dest_is_float, value_is_float, value, device],
135+
)
105136

106137
return ret
107138

numba_dpex/core/types/usm_ndarray_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(
7474

7575
if not dtype:
7676
dummy_tensor = dpctl.tensor.empty(
77-
shape=1, order=layout, usm_type=usm_type, sycl_queue=self.queue
77+
1, order=layout, usm_type=usm_type, sycl_queue=self.queue
7878
)
7979
# convert dpnp type to numba/numpy type
8080
_dtype = dummy_tensor.dtype

0 commit comments

Comments
 (0)