Skip to content

Commit 6a7fbc1

Browse files
authored
Controllable fallback (#40)
Add configuration variable `NUMBA_DPPY_FALLBACK_ON_CPU`
1 parent c5c3381 commit 6a7fbc1

File tree

3 files changed

+76
-1
lines changed

3 files changed

+76
-1
lines changed

numba_dppy/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,5 @@ def _readenv(): ...
3434

3535
# Turn SPIRV-VALIDATION ON/OFF switch
3636
SPIRV_VAL = _readenv("NUMBA_DPPY_SPIRV_VAL", int, 0)
37+
38+
FALLBACK_ON_CPU = _readenv("NUMBA_DPPY_FALLBACK_ON_CPU", int, 1)

numba_dppy/dppy_lowerer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1171,7 +1171,8 @@ def lower(self):
11711171
if numba_dppy.compiler.DEBUG:
11721172
print("Failed to lower parfor on DPPY-device. Due to:\n", e)
11731173
lowering.lower_extensions[parfor.Parfor].pop()
1174-
if (lowering.lower_extensions[parfor.Parfor][-1] == numba.parfors.parfor_lowering._lower_parfor_parallel):
1174+
if ((lowering.lower_extensions[parfor.Parfor][-1] == numba.parfors.parfor_lowering._lower_parfor_parallel) and
1175+
numba_dppy.config.FALLBACK_ON_CPU == 1):
11751176
self.cpu_lower.lower()
11761177
self.base_lower = self.cpu_lower
11771178
else:
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import numpy as np
2+
3+
import numba
4+
import numba_dppy
5+
from numba_dppy.testing import unittest
6+
from numba_dppy.testing import DPPYTestCase
7+
from numba.tests.support import captured_stderr
8+
import dpctl
9+
10+
11+
@unittest.skipUnless(dpctl.has_gpu_queues(), 'test only on GPU system')
12+
class TestDPPYFallback(DPPYTestCase):
13+
def test_dppy_fallback_true(self):
14+
@numba.jit
15+
def fill_value(i):
16+
return i
17+
18+
def inner_call_fallback():
19+
x = 10
20+
a = np.empty(shape=x, dtype=np.float32)
21+
22+
for i in numba.prange(x):
23+
a[i] = fill_value(i)
24+
25+
return a
26+
27+
numba_dppy.compiler.DEBUG = 1
28+
with captured_stderr() as msg_fallback_true:
29+
with dpctl.device_context("opencl:gpu") as gpu_queue:
30+
dppy = numba.njit(parallel=True)(inner_call_fallback)
31+
dppy_fallback_true = dppy()
32+
33+
ref_result = inner_call_fallback()
34+
numba_dppy.compiler.DEBUG = 0
35+
36+
np.testing.assert_array_equal(dppy_fallback_true, ref_result)
37+
self.assertTrue('Failed to lower parfor on DPPY-device' in msg_fallback_true.getvalue())
38+
39+
@unittest.expectedFailure
40+
def test_dppy_fallback_false(self):
41+
@numba.jit
42+
def fill_value(i):
43+
return i
44+
45+
def inner_call_fallback():
46+
x = 10
47+
a = np.empty(shape=x, dtype=np.float32)
48+
49+
for i in numba.prange(x):
50+
a[i] = fill_value(i)
51+
52+
return a
53+
54+
try:
55+
numba_dppy.compiler.DEBUG = 1
56+
numba_dppy.config.FALLBACK_ON_CPU = 0
57+
with captured_stderr() as msg_fallback_true:
58+
with dpctl.device_context("opencl:gpu") as gpu_queue:
59+
dppy = numba.njit(parallel=True)(inner_call_fallback)
60+
dppy_fallback_false = dppy()
61+
62+
finally:
63+
ref_result = inner_call_fallback()
64+
numba_dppy.config.FALLBACK_ON_CPU = 1
65+
numba_dppy.compiler.DEBUG = 0
66+
67+
not np.testing.assert_array_equal(dppy_fallback_false, ref_result)
68+
not self.assertTrue('Failed to lower parfor on DPPY-device' in msg_fallback_true.getvalue())
69+
70+
71+
if __name__ == '__main__':
72+
unittest.main()

0 commit comments

Comments
 (0)