Skip to content

Commit bd98f6b

Browse files
authored
Add support for nested test case names (#12)
1 parent e7eb74e commit bd98f6b

File tree

3 files changed

+116
-3
lines changed

3 files changed

+116
-3
lines changed

pytest_pytorch/plugin.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,86 @@
1+
import re
12
import unittest.mock
3+
from typing import Pattern
24

35
from _pytest.unittest import TestCaseFunction, UnitTestCase
46

7+
from torch.testing._internal.common_device_type import get_device_type_test_bases
58
from torch.testing._internal.common_utils import TestCase as PyTorchTestCaseTemplate
69

710

11+
class PytestPyTorchInternalError(Exception):
12+
def __init__(self, msg):
13+
super().__init__(
14+
f"{msg}\n"
15+
f"This is an internal error of the pytest plugin 'pytest-pytorch'."
16+
f"If you encounter this during normal operation, please file an issue "
17+
f"https://github.com/Quansight/pytest-pytorch/issues."
18+
)
19+
20+
21+
TEST_BASE_DEVICE_PATTERN = re.compile(r"(?P<device>\w*?)TestBase$")
22+
23+
24+
def _get_devices():
25+
devices = []
26+
for test_base in get_device_type_test_bases():
27+
match = TEST_BASE_DEVICE_PATTERN.match(test_base.__name__)
28+
if not match:
29+
raise PytestPyTorchInternalError(
30+
f"Unable to extract device name from {test_base.__name__}"
31+
)
32+
33+
devices.append(match.group("device"))
34+
35+
return devices
36+
37+
38+
DEVICES = _get_devices()
39+
40+
841
class TemplateName(str):
42+
_TEMPLATE_NAME_PATTERN: Pattern
43+
44+
def __init__(self, _):
45+
super().__init__()
46+
match = self._TEMPLATE_NAME_PATTERN.match(self)
47+
if not match:
48+
raise PytestPyTorchInternalError(
49+
f"Unable to extract template name from {self}"
50+
)
51+
self._template_name = match.group("template_name")
52+
953
def __eq__(self, other):
10-
return self.startswith(str(other))
54+
return str.__eq__(self, other) or str.__eq__(self._template_name, other)
1155

1256
def __hash__(self) -> int:
1357
return super().__hash__()
1458

1559

60+
class TestCaseFunctionTemplateName(TemplateName):
61+
_TEMPLATE_NAME_PATTERN = re.compile(
62+
fr"(?P<template_name>\w*?)_({'|'.join([device.lower() for device in DEVICES])})"
63+
)
64+
65+
1666
class PyTorchTestCaseFunction(TestCaseFunction):
1767
@classmethod
1868
def from_parent(cls, parent, *, name, **kw):
19-
return super().from_parent(parent, name=TemplateName(name), **kw)
69+
return super().from_parent(
70+
parent, name=TestCaseFunctionTemplateName(name), **kw
71+
)
72+
73+
74+
class TestCaseTemplateName(TemplateName):
75+
_TEMPLATE_NAME_PATTERN = re.compile(
76+
fr"(?P<template_name>\w*?)({'|'.join([device.upper() for device in DEVICES])})"
77+
)
2078

2179

2280
class PyTorchTestCase(UnitTestCase):
2381
@classmethod
2482
def from_parent(cls, parent, *, name, obj=None):
25-
return super().from_parent(parent, name=TemplateName(name), obj=obj)
83+
return super().from_parent(parent, name=TestCaseTemplateName(name), obj=obj)
2684

2785
def collect(self):
2886
# Yes, this is a bad practice. Unfortunately, there is no other option to

tests/assets/test_nested_names.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from torch.testing._internal.common_device_type import (
2+
instantiate_device_type_tests,
3+
onlyOn,
4+
)
5+
from torch.testing._internal.common_utils import TestCase
6+
7+
8+
class TestFoo(TestCase):
9+
# fails for meta, passes for cpu
10+
def test_baz(self, device):
11+
assert device != "meta"
12+
13+
14+
instantiate_device_type_tests(TestFoo, globals())
15+
16+
17+
class TestFooBar(TestCase):
18+
# passes for meta, skips for cpu
19+
@onlyOn("meta")
20+
def test_baz(self, device):
21+
assert True
22+
23+
24+
instantiate_device_type_tests(TestFooBar, globals())

tests/test_plugin.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,34 @@ def test_dtype(testdir, file, cmds, outcomes):
180180
testdir.copy_example(file)
181181
result = testdir.runpytest(*cmds)
182182
result.assert_outcomes(**outcomes)
183+
184+
185+
@make_params(
186+
"test_nested_names.py",
187+
Config(
188+
"*testcase-*test-*device",
189+
new_cmds=(),
190+
legacy_cmds=(),
191+
passed=2,
192+
skipped=1,
193+
failed=1,
194+
),
195+
Config(
196+
"1testcase1-*test-*device",
197+
new_cmds="::TestFoo",
198+
legacy_cmds=("-k", "TestFoo and not TestFooBar"),
199+
passed=1,
200+
failed=1,
201+
),
202+
Config(
203+
"1testcase2-*test-*device",
204+
new_cmds="::TestFooBar",
205+
legacy_cmds=("-k", "TestFooBar"),
206+
passed=1,
207+
skipped=1,
208+
),
209+
)
210+
def test_nested_names(testdir, file, cmds, outcomes):
211+
testdir.copy_example(file)
212+
result = testdir.runpytest(*cmds)
213+
result.assert_outcomes(**outcomes)

0 commit comments

Comments
 (0)