|
23 | 23 | convert_float_to_uint16, |
24 | 24 | get_device_place, |
25 | 25 | get_places, |
| 26 | + is_custom_device, |
26 | 27 | ) |
27 | 28 | from scipy.special import erf, expit |
28 | 29 | from utils import static_guard |
@@ -497,7 +498,8 @@ def init_shape(self): |
497 | 498 |
|
498 | 499 |
|
499 | 500 | @unittest.skipIf( |
500 | | - not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(), |
| 501 | + not (core.is_compiled_with_cuda() or is_custom_device()) |
| 502 | + or core.is_compiled_with_rocm(), |
501 | 503 | "core is not compiled with CUDA", |
502 | 504 | ) |
503 | 505 | class TestSigmoidBF16(OpTest): |
@@ -1765,7 +1767,8 @@ def init_dtype(self): |
1765 | 1767 |
|
1766 | 1768 |
|
1767 | 1769 | @unittest.skipIf( |
1768 | | - not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(), |
| 1770 | + not (core.is_compiled_with_cuda() or is_custom_device()) |
| 1771 | + or core.is_compiled_with_rocm(), |
1769 | 1772 | "core is not compiled with CUDA", |
1770 | 1773 | ) |
1771 | 1774 | class TestSqrtBF16(OpTest): |
@@ -2037,7 +2040,7 @@ def setUp(self): |
2037 | 2040 | self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)} |
2038 | 2041 | self.outputs = {'Out': out} |
2039 | 2042 | self.convert_input_output() |
2040 | | - if not core.is_compiled_with_cuda(): |
| 2043 | + if not (core.is_compiled_with_cuda() or is_custom_device()): |
2041 | 2044 | self.__class__.no_need_check_grad = True |
2042 | 2045 |
|
2043 | 2046 | def init_shape(self): |
@@ -2091,7 +2094,7 @@ def setUp(self): |
2091 | 2094 | self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)} |
2092 | 2095 | self.outputs = {'Out': out} |
2093 | 2096 | self.convert_input_output() |
2094 | | - if not core.is_compiled_with_cuda(): |
| 2097 | + if not (core.is_compiled_with_cuda() or is_custom_device()): |
2095 | 2098 | self.__class__.no_need_check_grad = True |
2096 | 2099 |
|
2097 | 2100 | def init_shape(self): |
@@ -4563,7 +4566,8 @@ def init_shape(self): |
4563 | 4566 |
|
4564 | 4567 |
|
4565 | 4568 | @unittest.skipIf( |
4566 | | - not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(), |
| 4569 | + not (core.is_compiled_with_cuda() or is_custom_device()) |
| 4570 | + or core.is_compiled_with_rocm(), |
4567 | 4571 | "core is not compiled with CUDA", |
4568 | 4572 | ) |
4569 | 4573 | class TestSquareBF16(OpTest): |
@@ -4917,7 +4921,8 @@ def init_shape(self): |
4917 | 4921 |
|
4918 | 4922 |
|
4919 | 4923 | @unittest.skipIf( |
4920 | | - not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(), |
| 4924 | + not (core.is_compiled_with_cuda() or is_custom_device()) |
| 4925 | + or core.is_compiled_with_rocm(), |
4921 | 4926 | "core is not compiled with CUDA", |
4922 | 4927 | ) |
4923 | 4928 | class TestSoftplusBF16(OpTest): |
@@ -5595,7 +5600,8 @@ def test_errors(self): |
5595 | 5600 | # ------------------ Test Cudnn Activation---------------------- |
5596 | 5601 | def create_test_act_cudnn_class(parent, atol=1e-3, grad_atol=1e-3): |
5597 | 5602 | @unittest.skipIf( |
5598 | | - not core.is_compiled_with_cuda(), "core is not compiled with CUDA" |
| 5603 | + not (core.is_compiled_with_cuda() or is_custom_device()), |
| 5604 | + "core is not compiled with CUDA", |
5599 | 5605 | ) |
5600 | 5606 | class TestActCudnn(parent): |
5601 | 5607 | def init_kernel_type(self): |
|
0 commit comments