Skip to content

Commit 6206b29

Browse files
author
khaled
committed
Working on pytest
1 parent bac5cee commit 6206b29

File tree

2 files changed

+97
-0
lines changed

2 files changed

+97
-0
lines changed

numba_dpex/dpnp_iface/arrayobj.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,7 @@ def impl(
821821
):
822822
return impl_dpnp_full_like(
823823
x1,
824+
fill_value,
824825
_dtype,
825826
_order,
826827
subok,
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Tests for dpnp ndarray constructors."""
6+
7+
import math
8+
9+
import dpctl
10+
import dpctl.tensor as dpt
11+
import dpnp
12+
import numpy
13+
import pytest
14+
from numba import errors
15+
16+
from numba_dpex import dpjit
17+
18+
shapes = [11, (3, 7)]
19+
dtypes = [dpnp.int32, dpnp.int64, dpnp.float32, dpnp.float64]
20+
usm_types = ["device", "shared", "host"]
21+
devices = ["cpu", "unknown"]
22+
# TODO: test with 3.4028237e38, 4294967295 etc.
23+
fill_values = [7, -7, 7.1, -7.1, math.pi, math.e]
24+
25+
26+
@pytest.mark.parametrize("shape", shapes)
27+
@pytest.mark.parametrize("fill_value", fill_values)
28+
@pytest.mark.parametrize("dtype", dtypes)
29+
@pytest.mark.parametrize("usm_type", usm_types)
30+
@pytest.mark.parametrize("device", devices)
31+
def test_dpnp_full_like(shape, fill_value, dtype, usm_type, device):
32+
@dpjit
33+
def func(a, v):
34+
c = dpnp.full_like(a, v, dtype=dtype, usm_type=usm_type, device=device)
35+
return c
36+
37+
if isinstance(shape, int):
38+
NZ = numpy.random.rand(shape)
39+
else:
40+
NZ = numpy.random.rand(*shape)
41+
42+
try:
43+
c = func(NZ, fill_value)
44+
except Exception:
45+
pytest.fail("Calling dpnp.zeros_like inside dpjit failed")
46+
47+
if len(c.shape) == 1:
48+
assert c.shape[0] == NZ.shape[0]
49+
else:
50+
assert c.shape == NZ.shape
51+
52+
assert c.dtype == dtype
53+
assert c.usm_type == usm_type
54+
if device != "unknown":
55+
assert (
56+
c.sycl_device.filter_string
57+
== dpctl.SyclDevice(device).filter_string
58+
)
59+
else:
60+
c.sycl_device.filter_string == dpctl.SyclDevice().filter_string
61+
62+
assert numpy.array_equal(
63+
dpt.asnumpy(c._array_obj), numpy.full_like(c._array_obj, fill_value)
64+
)
65+
66+
67+
def test_dpnp_full_like_exceptions():
68+
@dpjit
69+
def func1(a):
70+
c = dpnp.full_like(a, shape=(3, 3))
71+
return c
72+
73+
try:
74+
func1(numpy.random.rand(5, 5))
75+
except Exception as e:
76+
assert isinstance(e, errors.TypingError)
77+
assert (
78+
"No implementation of function Function(<function full_like"
79+
in str(e)
80+
)
81+
82+
queue = dpctl.SyclQueue()
83+
84+
@dpjit
85+
def func2(a):
86+
c = dpnp.full_like(a, sycl_queue=queue)
87+
return c
88+
89+
try:
90+
func2(numpy.random.rand(5, 5))
91+
except Exception as e:
92+
assert isinstance(e, errors.TypingError)
93+
assert (
94+
"No implementation of function Function(<function full_like"
95+
in str(e)
96+
)

0 commit comments

Comments
 (0)