Skip to content

Commit d6c00e3

Browse files
committed
fix skip
1 parent 9022c90 commit d6c00e3

File tree

1 file changed

+38
-13
lines changed

1 file changed

+38
-13
lines changed

test/legacy_test/test_narrow.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,11 @@ def check_narrow_alias(input_tensor, output_tensor, dim, start):
6464
return is_alias
6565

6666

67-
@unittest.skipIf(paddle.device.get_device() == "xpu", "Skip on XPU")
67+
@unittest.skipIf(paddle.device.get_device().startswith("xpu"), "Skip on XPU")
6868
class TestNarrowBase(unittest.TestCase):
69+
@unittest.skipIf(
70+
paddle.device.get_device().startswith("xpu"), "Skip on XPU"
71+
)
6972
def setUp(self):
7073
self.input_np = np.array([1, 2, 3, 4, 5], dtype='float32')
7174
self.input_shape = self.input_np.shape
@@ -110,7 +113,9 @@ def check_dygraph_result(self, place):
110113
f"narrow should be an alias! input={input.numpy()}, result={result.numpy()}",
111114
)
112115

113-
@unittest.skipIf(paddle.device.get_device() == "xpu", "Skip on XPU")
116+
@unittest.skipIf(
117+
paddle.device.get_device().startswith("xpu"), "Skip on XPU"
118+
)
114119
def test_dygraph(self):
115120
for place in self.places:
116121
self.check_dygraph_result(place=place)
@@ -277,33 +282,41 @@ def setUp(self):
277282
# self.length = 0
278283

279284

280-
@unittest.skipIf(paddle.device.get_device() == "xpu", "Skip on XPU")
285+
@unittest.skipIf(paddle.device.get_device().startswith("xpu"), "Skip on XPU")
281286
class TestNarrowExtra(unittest.TestCase):
282-
@unittest.skipIf(paddle.device.get_device() == "xpu", "Skip on XPU")
287+
@unittest.skipIf(
288+
paddle.device.get_device().startswith("xpu"), "Skip on XPU"
289+
)
283290
def test_start_tensor(self):
284291
arr = np.arange(10, dtype='int64')
285292
x = paddle.to_tensor(arr)
286293
s = paddle.to_tensor(3, dtype='int64')
287294
out = paddle.narrow(x, dim=0, start=s, length=2)
288295
np.testing.assert_array_equal(out.numpy(), arr[3:5])
289296

290-
@unittest.skipIf(paddle.device.get_device() == "xpu", "Skip on XPU")
297+
@unittest.skipIf(
298+
paddle.device.get_device().startswith("xpu"), "Skip on XPU"
299+
)
291300
def test_start_tensor_wrong_dtype(self):
292301
arr = np.arange(10, dtype='float32')
293302
x = paddle.to_tensor(arr)
294303
s = paddle.to_tensor(3.1, dtype='float32')
295304
with self.assertRaises(AssertionError):
296305
paddle.narrow(x, dim=0, start=s, length=2)
297306

298-
@unittest.skipIf(paddle.device.get_device() == "xpu", "Skip on XPU")
307+
@unittest.skipIf(
308+
paddle.device.get_device().startswith("xpu"), "Skip on XPU"
309+
)
299310
def test_start_tensor_wrong_shape(self):
300311
arr = np.arange(10, dtype='float32')
301312
x = paddle.to_tensor(arr)
302313
s = paddle.to_tensor([1, 2], dtype='int64')
303314
with self.assertRaises(AssertionError):
304315
paddle.narrow(x, dim=0, start=s, length=2)
305316

306-
@unittest.skipIf(paddle.device.get_device() == "xpu", "Skip on XPU")
317+
@unittest.skipIf(
318+
paddle.device.get_device().startswith("xpu"), "Skip on XPU"
319+
)
307320
def test_dim_out_of_range(self):
308321
arr = np.arange(10)
309322
x = paddle.to_tensor(arr)
@@ -312,7 +325,9 @@ def test_dim_out_of_range(self):
312325
with self.assertRaises(IndexError):
313326
paddle.narrow(x, dim=-2, start=0, length=1)
314327

315-
@unittest.skipIf(paddle.device.get_device() == "xpu", "Skip on XPU")
328+
@unittest.skipIf(
329+
paddle.device.get_device().startswith("xpu"), "Skip on XPU"
330+
)
316331
def test_start_out_of_range(self):
317332
arr = np.arange(5)
318333
x = paddle.to_tensor(arr)
@@ -321,34 +336,44 @@ def test_start_out_of_range(self):
321336
with self.assertRaises(AssertionError):
322337
paddle.narrow(x, dim=0, start=-6, length=1)
323338

324-
@unittest.skipIf(paddle.device.get_device() == "xpu", "Skip on XPU")
339+
@unittest.skipIf(
340+
paddle.device.get_device().startswith("xpu"), "Skip on XPU"
341+
)
325342
def test_length_negative(self):
326343
arr = np.arange(5)
327344
x = paddle.to_tensor(arr)
328345
with self.assertRaises(AssertionError):
329346
paddle.narrow(x, dim=0, start=1, length=-1)
330347

331-
@unittest.skipIf(paddle.device.get_device() == "xpu", "Skip on XPU")
348+
@unittest.skipIf(
349+
paddle.device.get_device().startswith("xpu"), "Skip on XPU"
350+
)
332351
def test_0_dim_tensor(self):
333352
x = paddle.to_tensor(111)
334353
with self.assertRaises(AssertionError):
335354
paddle.narrow(x, dim=0, start=0, length=1)
336355

337-
@unittest.skipIf(paddle.device.get_device() == "xpu", "Skip on XPU")
356+
@unittest.skipIf(
357+
paddle.device.get_device().startswith("xpu"), "Skip on XPU"
358+
)
338359
def test_start_plus_length_overflow(self):
339360
arr = np.arange(5)
340361
x = paddle.to_tensor(arr)
341362
with self.assertRaises(AssertionError):
342363
paddle.narrow(x, dim=0, start=3, length=3)
343364

344-
@unittest.skipIf(paddle.device.get_device() == "xpu", "Skip on XPU")
365+
@unittest.skipIf(
366+
paddle.device.get_device().startswith("xpu"), "Skip on XPU"
367+
)
345368
def test_negative_start(self):
346369
arr = np.arange(8)
347370
x = paddle.to_tensor(arr)
348371
out = paddle.narrow(x, dim=0, start=-3, length=2)
349372
np.testing.assert_array_equal(out.numpy(), arr[5:7])
350373

351-
@unittest.skipIf(paddle.device.get_device() == "xpu", "Skip on XPU")
374+
@unittest.skipIf(
375+
paddle.device.get_device().startswith("xpu"), "Skip on XPU"
376+
)
352377
def test_negative_dim(self):
353378
arr = np.arange(12).reshape(3, 4)
354379
x = paddle.to_tensor(arr)

0 commit comments

Comments
 (0)