|
4 | 4 |
|
5 | 5 | import dpnp |
6 | 6 | import numpy as np |
7 | | -import pytest |
| 7 | +from numba import int32, int64 |
8 | 8 |
|
9 | | -import numba_dpex as dpex |
10 | | -from numba_dpex import float32, int32 |
| 9 | +import numba_dpex.experimental as dpex |
11 | 10 |
|
12 | | -single_signature = dpex.func(int32(int32)) |
13 | | -list_signature = dpex.func([int32(int32), float32(float32)]) |
| 11 | +i32_signature = dpex.device_func(int32(int32)) |
| 12 | +i32i64_signature = dpex.device_func([int32(int32), int64(int64)]) |
14 | 13 |
|
15 | 14 | # Array size |
16 | | -N = 10 |
| 15 | +N = 1024 |
17 | 16 |
|
18 | 17 |
|
19 | 18 | def increment(a): |
20 | | - return a + dpnp.float32(1) |
| 19 | + return a + 1 |
21 | 20 |
|
22 | 21 |
|
23 | | -def test_basic(): |
24 | | - """Basic test with device func""" |
| 22 | +fi32 = i32_signature(increment) |
| 23 | +fi32i64 = i32i64_signature(increment) |
25 | 24 |
|
26 | | - f = dpex.func(increment) |
27 | 25 |
|
28 | | - def kernel_function(a, b): |
29 | | - """Kernel function that applies f() in parallel""" |
30 | | - i = dpex.get_global_id(0) |
31 | | - b[i] = f(a[i]) |
| 26 | +@dpex.kernel |
| 27 | +def kernel_function(item, a, b): |
| 28 | + """Kernel function that calls fi32()""" |
| 29 | + i = item.get_id(0) |
| 30 | + b[i] = fi32(a[i]) |
32 | 31 |
|
33 | | - k = dpex.kernel(kernel_function) |
34 | 32 |
|
35 | | - a = dpnp.ones(N) |
36 | | - b = dpnp.ones(N) |
| 33 | +@dpex.kernel |
| 34 | +def kernel_function2(item, a, b): |
| 35 | + """Kernel function that calls fi32i64()""" |
| 36 | + i = item.get_id(0) |
| 37 | + b[i] = fi32i64(a[i]) |
37 | 38 |
|
38 | | - dpex.call_kernel(k, dpex.Range(N), a, b) |
39 | 39 |
|
40 | | - assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1) |
41 | | - |
42 | | - |
43 | | -def test_single_signature(): |
44 | | - """Basic test with single signature""" |
45 | | - |
46 | | - fi32 = single_signature(increment) |
47 | | - |
48 | | - def kernel_function(a, b): |
49 | | - """Kernel function that applies fi32() in parallel""" |
50 | | - i = dpex.get_global_id(0) |
51 | | - b[i] = fi32(a[i]) |
52 | | - |
53 | | - k = dpex.kernel(kernel_function) |
54 | | - |
55 | | - # Test with int32, should work |
56 | | - a = dpnp.ones(N, dtype=dpnp.int32) |
57 | | - b = dpnp.ones(N, dtype=dpnp.int32) |
58 | | - |
59 | | - dpex.call_kernel(k, dpex.Range(N), a, b) |
60 | | - |
61 | | - assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1) |
62 | | - |
63 | | - # Test with int64, should fail |
64 | | - a = dpnp.ones(N, dtype=dpnp.int64) |
65 | | - b = dpnp.ones(N, dtype=dpnp.int64) |
66 | | - |
67 | | - with pytest.raises(Exception) as e: |
68 | | - dpex.call_kernel(k, dpex.Range(N), a, b) |
69 | | - |
70 | | - assert " >>> <unknown function>(int64)" in e.value.args[0] |
71 | | - |
72 | | - |
73 | | -def test_list_signature(): |
74 | | - """Basic test with list signature""" |
75 | | - |
76 | | - fi32f32 = list_signature(increment) |
77 | | - |
78 | | - def kernel_function(a, b): |
79 | | - """Kernel function that applies fi32f32() in parallel""" |
80 | | - i = dpex.get_global_id(0) |
81 | | - b[i] = fi32f32(a[i]) |
82 | | - |
83 | | - k = dpex.kernel(kernel_function) |
84 | | - |
85 | | - # Test with int32, should work |
| 40 | +def test_calling_specialized_device_func(): |
| 41 | + """Tests if a specialized device_func gets called as expected from kernel""" |
86 | 42 | a = dpnp.ones(N, dtype=dpnp.int32) |
87 | | - b = dpnp.ones(N, dtype=dpnp.int32) |
| 43 | + b = dpnp.zeros(N, dtype=dpnp.int32) |
88 | 44 |
|
89 | | - dpex.call_kernel(k, dpex.Range(N), a, b) |
| 45 | + dpex.call_kernel(kernel_function, dpex.Range(N), a, b) |
90 | 46 |
|
91 | 47 | assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1) |
92 | 48 |
|
93 | | - # Test with float32, should work |
94 | | - a = dpnp.ones(N, dtype=dpnp.float32) |
95 | | - b = dpnp.ones(N, dtype=dpnp.float32) |
96 | 49 |
|
97 | | - dpex.call_kernel(k, dpex.Range(N), a, b) |
| 50 | +def test_calling_specialized_device_func_wrong_signature(): |
| 51 | + """Tests that calling specialized signature with wrong signature does not |
| 52 | + trigger recompilation. |
98 | 53 |
|
99 | | - assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1) |
| 54 | + Tests kernel_function with float32. Numba will downcast float32 to int32 |
| 55 | + and call the specialized function. The implicit casting is a problem, but |
| 56 | + for the purpose of this test case, all we care is to check if the |
| 57 | + specialized function was called and we did not recompiled the device_func. |
| 58 | + Refer: https://github.com/numba/numba/issues/9506 |
100 | 59 |
|
| 60 | + """ |
101 | 61 | # Test with int64, should fail |
102 | | - a = dpnp.ones(N, dtype=dpnp.int64) |
103 | | - b = dpnp.ones(N, dtype=dpnp.int64) |
104 | | - |
105 | | - with pytest.raises(Exception) as e: |
106 | | - dpex.call_kernel(k, dpex.Range(N), a, b) |
107 | | - |
108 | | - assert " >>> <unknown function>(int64)" in e.value.args[0] |
| 62 | + a = dpnp.full(N, 1.5, dtype=dpnp.float32) |
| 63 | + b = dpnp.zeros(N, dtype=dpnp.float32) |
| 64 | + |
| 65 | + dpex.call_kernel(kernel_function, dpex.Range(N), a, b) |
| 66 | + |
| 67 | + # Since Numba is calling the i32 specialization of increment, the values in |
| 68 | + # `a` are first down converted to int32, *i.e.*, 1.5 to 1 and then |
| 69 | + # incremented. Thus, the output is 2 instead of 2.5. |
| 70 | + # The implicit down casting is a dangerous thing for Numba to do, but we use |
| 71 | + # to our advantage to test if re compilation did not happen for a |
| 72 | + # specialized device function. |
| 73 | + assert np.all(dpnp.asnumpy(b) == 2) |
| 74 | + assert not np.all(dpnp.asnumpy(b) == 2.5) |
| 75 | + |
| 76 | + |
| 77 | +def test_multi_specialized_device_func(): |
| 78 | + """Tests if a device_func with multiple specialization can be called |
| 79 | + in a kernel |
| 80 | + """ |
| 81 | + # Test with int32, i64 should work |
| 82 | + ai32 = dpnp.ones(N, dtype=dpnp.int32) |
| 83 | + bi32 = dpnp.ones(N, dtype=dpnp.int32) |
| 84 | + ai64 = dpnp.ones(N, dtype=dpnp.int64) |
| 85 | + bi64 = dpnp.ones(N, dtype=dpnp.int64) |
| 86 | + |
| 87 | + dpex.call_kernel(kernel_function2, dpex.Range(N), ai32, bi32) |
| 88 | + dpex.call_kernel(kernel_function2, dpex.Range(N), ai64, bi64) |
| 89 | + |
| 90 | + assert np.array_equal(dpnp.asnumpy(bi32), dpnp.asnumpy(ai32) + 1) |
| 91 | + assert np.array_equal(dpnp.asnumpy(bi64), dpnp.asnumpy(ai64) + 1) |
0 commit comments