Skip to content

Commit ec5d238

Browse files
author
Diptorup Deb
committed
Adds support for specializing a device_func.
1 parent 800cd00 commit ec5d238

File tree

2 files changed

+79
-81
lines changed

2 files changed

+79
-81
lines changed

numba_dpex/experimental/decorators.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def _kernel_dispatcher(pyfunc):
9999
"Argument passed to the kernel decorator is neither a "
100100
"function object, nor a signature. If you are trying to "
101101
"specialize the kernel that takes a single argument, specify "
102-
"the return type as void explicitly."
102+
"the return type as None explicitly."
103103
)
104104
return _kernel_dispatcher(func)
105105

@@ -132,13 +132,28 @@ def device_func(func_or_sig=None, **options):
132132
)
133133
options["_compilation_mode"] = CompilationMode.DEVICE_FUNC
134134

135+
func, sigs = _parse_func_or_sig(func_or_sig)
136+
for sig in sigs:
137+
if isinstance(sig, str):
138+
raise NotImplementedError(
139+
"Specifying signatures as string is not yet supported"
140+
)
141+
135142
def _kernel_dispatcher(pyfunc):
136-
return dispatcher(
143+
disp: SPIRVKernelDispatcher = dispatcher(
137144
pyfunc=pyfunc,
138145
targetoptions=options,
139146
)
140147

141-
if func_or_sig is None:
148+
if len(sigs) > 0:
149+
with typeinfer.register_dispatcher(disp):
150+
for sig in sigs:
151+
disp.compile(sig)
152+
disp.disable_compile()
153+
154+
return disp
155+
156+
if func is None:
142157
return _kernel_dispatcher
143158

144159
return _kernel_dispatcher(func_or_sig)

numba_dpex/tests/kernel_tests/test_func_specialization.py

Lines changed: 61 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -4,105 +4,88 @@
44

55
import dpnp
66
import numpy as np
7-
import pytest
7+
from numba import int32, int64
88

9-
import numba_dpex as dpex
10-
from numba_dpex import float32, int32
9+
import numba_dpex.experimental as dpex
1110

12-
single_signature = dpex.func(int32(int32))
13-
list_signature = dpex.func([int32(int32), float32(float32)])
11+
i32_signature = dpex.device_func(int32(int32))
12+
i32i64_signature = dpex.device_func([int32(int32), int64(int64)])
1413

1514
# Array size
16-
N = 10
15+
N = 1024
1716

1817

1918
def increment(a):
20-
return a + dpnp.float32(1)
19+
return a + 1
2120

2221

23-
def test_basic():
24-
"""Basic test with device func"""
22+
fi32 = i32_signature(increment)
23+
fi32i64 = i32i64_signature(increment)
2524

26-
f = dpex.func(increment)
2725

28-
def kernel_function(a, b):
29-
"""Kernel function that applies f() in parallel"""
30-
i = dpex.get_global_id(0)
31-
b[i] = f(a[i])
26+
@dpex.kernel
27+
def kernel_function(item, a, b):
28+
"""Kernel function that calls fi32()"""
29+
i = item.get_id(0)
30+
b[i] = fi32(a[i])
3231

33-
k = dpex.kernel(kernel_function)
3432

35-
a = dpnp.ones(N)
36-
b = dpnp.ones(N)
33+
@dpex.kernel
34+
def kernel_function2(item, a, b):
35+
"""Kernel function that calls fi32i64()"""
36+
i = item.get_id(0)
37+
b[i] = fi32i64(a[i])
3738

38-
dpex.call_kernel(k, dpex.Range(N), a, b)
3939

40-
assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1)
41-
42-
43-
def test_single_signature():
44-
"""Basic test with single signature"""
45-
46-
fi32 = single_signature(increment)
47-
48-
def kernel_function(a, b):
49-
"""Kernel function that applies fi32() in parallel"""
50-
i = dpex.get_global_id(0)
51-
b[i] = fi32(a[i])
52-
53-
k = dpex.kernel(kernel_function)
54-
55-
# Test with int32, should work
56-
a = dpnp.ones(N, dtype=dpnp.int32)
57-
b = dpnp.ones(N, dtype=dpnp.int32)
58-
59-
dpex.call_kernel(k, dpex.Range(N), a, b)
60-
61-
assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1)
62-
63-
# Test with int64, should fail
64-
a = dpnp.ones(N, dtype=dpnp.int64)
65-
b = dpnp.ones(N, dtype=dpnp.int64)
66-
67-
with pytest.raises(Exception) as e:
68-
dpex.call_kernel(k, dpex.Range(N), a, b)
69-
70-
assert " >>> <unknown function>(int64)" in e.value.args[0]
71-
72-
73-
def test_list_signature():
74-
"""Basic test with list signature"""
75-
76-
fi32f32 = list_signature(increment)
77-
78-
def kernel_function(a, b):
79-
"""Kernel function that applies fi32f32() in parallel"""
80-
i = dpex.get_global_id(0)
81-
b[i] = fi32f32(a[i])
82-
83-
k = dpex.kernel(kernel_function)
84-
85-
# Test with int32, should work
40+
def test_calling_specialized_device_func():
41+
"""Tests if a specialized device_func gets called as expected from kernel"""
8642
a = dpnp.ones(N, dtype=dpnp.int32)
87-
b = dpnp.ones(N, dtype=dpnp.int32)
43+
b = dpnp.zeros(N, dtype=dpnp.int32)
8844

89-
dpex.call_kernel(k, dpex.Range(N), a, b)
45+
dpex.call_kernel(kernel_function, dpex.Range(N), a, b)
9046

9147
assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1)
9248

93-
# Test with float32, should work
94-
a = dpnp.ones(N, dtype=dpnp.float32)
95-
b = dpnp.ones(N, dtype=dpnp.float32)
9649

97-
dpex.call_kernel(k, dpex.Range(N), a, b)
50+
def test_calling_specialized_device_func_wrong_signature():
51+
"""Tests that calling specialized signature with wrong signature does not
52+
trigger recompilation.
9853
99-
assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1)
54+
Tests kernel_function with float32. Numba will downcast float32 to int32
55+
and call the specialized function. The implicit casting is a problem, but
56+
for the purpose of this test case, all we care is to check if the
57+
specialized function was called and we did not recompiled the device_func.
58+
Refer: https://github.com/numba/numba/issues/9506
10059
60+
"""
10161
# Test with int64, should fail
102-
a = dpnp.ones(N, dtype=dpnp.int64)
103-
b = dpnp.ones(N, dtype=dpnp.int64)
104-
105-
with pytest.raises(Exception) as e:
106-
dpex.call_kernel(k, dpex.Range(N), a, b)
107-
108-
assert " >>> <unknown function>(int64)" in e.value.args[0]
62+
a = dpnp.full(N, 1.5, dtype=dpnp.float32)
63+
b = dpnp.zeros(N, dtype=dpnp.float32)
64+
65+
dpex.call_kernel(kernel_function, dpex.Range(N), a, b)
66+
67+
# Since Numba is calling the i32 specialization of increment, the values in
68+
# `a` are first down converted to int32, *i.e.*, 1.5 to 1 and then
69+
# incremented. Thus, the output is 2 instead of 2.5.
70+
# The implicit down casting is a dangerous thing for Numba to do, but we use
71+
# to our advantage to test if re compilation did not happen for a
72+
# specialized device function.
73+
assert np.all(dpnp.asnumpy(b) == 2)
74+
assert not np.all(dpnp.asnumpy(b) == 2.5)
75+
76+
77+
def test_multi_specialized_device_func():
78+
"""Tests if a device_func with multiple specialization can be called
79+
in a kernel
80+
"""
81+
# Test with int32, i64 should work
82+
ai32 = dpnp.ones(N, dtype=dpnp.int32)
83+
bi32 = dpnp.ones(N, dtype=dpnp.int32)
84+
ai64 = dpnp.ones(N, dtype=dpnp.int64)
85+
bi64 = dpnp.ones(N, dtype=dpnp.int64)
86+
87+
dpex.call_kernel(kernel_function2, dpex.Range(N), ai32, bi32)
88+
dpex.call_kernel(kernel_function2, dpex.Range(N), ai64, bi64)
89+
90+
assert np.array_equal(dpnp.asnumpy(bi32), dpnp.asnumpy(ai32) + 1)
91+
assert np.array_equal(dpnp.asnumpy(bi64), dpnp.asnumpy(ai64) + 1)

0 commit comments

Comments
 (0)