
Commit 737ed0d

Enable test modules on MPS and CI runners (#305)
* Enable test modules on MPS and CI runners
* Update lint.yml
* Update comments
* Retrigger CI
* Retrigger CI #2
* Remove comment
1 parent bfc669e commit 737ed0d

3 files changed: +89 -8 lines changed

.github/workflows/_mac-test-mps.yml

Lines changed: 14 additions & 0 deletions
@@ -82,6 +82,20 @@ jobs:
 
           ${CONDA_RUN} python3 test/run_test.py --mps --verbose
 
+      - name: Run MPS Test Modules
+        id: test_2
+        env:
+          ENV_NAME: conda-test-env-${{ github.run_id }}
+        shell: arch -arch arm64 bash {0}
+        # During bring up of test_modules don't show this as an error.
+        continue-on-error: true
+        run: |
+          # shellcheck disable=SC1090
+          set -ex
+          # TODO(https://github.com/pytorch/pytorch/issues/79293)
+
+          ${CONDA_RUN} python3 test/test_modules.py -k mps --verbose
+
       - name: Print remaining test logs
         shell: bash
         if: always()
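For local debugging outside CI, the new step amounts to running the module tests with an MPS keyword filter. A minimal sketch, assuming a PyTorch checkout with MPS support as the working directory and that the conda wrapper used in CI (`CONDA_RUN`) is not needed locally:

# Minimal local equivalent of the "Run MPS Test Modules" CI step (a sketch;
# assumes the pytorch repository root is the current working directory).
import subprocess
import sys

# Mirror the workflow command: run only MPS-related module tests, verbosely.
# check=False mirrors the step's continue-on-error: failures do not raise here.
subprocess.run(
    [sys.executable, "test/test_modules.py", "-k", "mps", "--verbose"],
    check=False,
)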

.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ jobs:
           # shellcheck disable=SC1090
           set -ex
           set +e
-          if ! ${CONDA_RUN} lintrunner --force-color aten/src/ATen/native/mps/operations/* test/test_mps.py; then
+          if ! ${CONDA_RUN} lintrunner --force-color aten/src/ATen/native/mps/operations/* test/test_mps.py test/test_modules.py; then
             echo ""
             echo -e "\e[1m\e[36mYou can reproduce these results locally by using \`lintrunner\`.\e[0m"
             echo -e "\e[1m\e[36mSee https://github.com/pytorch/pytorch/wiki/lintrunner for setup instructions.\e[0m"

test/test_modules.py

Lines changed: 74 additions & 7 deletions
@@ -9,12 +9,23 @@
 import torch
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCUDA, toleranceOverride, tol, skipMeta)
+from torch.testing._internal.common_dtype import get_all_dtypes
 from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck,
-    gradgradcheck, skipIfMps, skipIfTorchInductor)
+    gradgradcheck, skipIfTorchInductor)
 from unittest.mock import patch, call
 
+MPS_DTYPES = get_all_dtypes()
+for t in [torch.double, torch.cdouble, torch.cfloat, torch.int8, torch.bfloat16]:
+    del MPS_DTYPES[MPS_DTYPES.index(t)]
+
+def _get_mps_error_msg(device, dtype, op, mps_blocklist):
+    if torch.backends.mps.is_available() and device == "mps" and dtype not in MPS_DTYPES:
+        return f"MPS doesn't support {str(dtype)} datatype"
+    if op.name.startswith(tuple(mps_blocklist)):
+        return "MPS doesn't support op " + str(op.name)
+    return None
 
 class TestModule(TestCase):
     _do_cuda_memory_leak_check = True
@@ -32,7 +43,8 @@ def _assert_module_parameters_and_buffer_are(self, module, device, dtype):
         def _check_module(items, name, device=device, dtype=dtype):
             for item_name, item in items:
                 self.assertEqual(
-                    item.device, device,
+                    # workaround for the tests checking the device (mps:0 with mps)
+                    item.device.type, device.type,
                     f'{name} {item_name} is on device {item.device} instead of the expected device {device}')
                 if item.dtype.is_floating_point:
                     self.assertEqual(
@@ -41,9 +53,16 @@ def _check_module(items, name, device=device, dtype=dtype):
         _check_module(module.named_parameters(), "Parameter")
         _check_module(module.named_buffers(), "Buffer")
 
-    @skipIfMps  # the test doesn't work on MPS as double types are not supported
     @modules(module_db)
     def test_forward(self, device, dtype, module_info, training):
+        MPS_BLOCKLIST = [
+            "nn.LSTM"  # segfault
+        ]
+
+        msg = _get_mps_error_msg(device, dtype, module_info, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
+
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                        requires_grad=False, training=training)
@@ -83,6 +102,10 @@ def test_forward(self, device, dtype, module_info, training):
     # They should be applied to any created parameters and buffers.
     @modules(module_db)
     def test_factory_kwargs(self, device, dtype, module_info, training):
+        msg = _get_mps_error_msg(device, dtype, module_info, [])
+        if msg is not None:
+            self.skipTest(msg)
+
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                        requires_grad=False, training=training)
@@ -197,6 +220,11 @@ def _to_device1(objs):
     @modules(module_db)
     def test_repr(self, device, dtype, module_info, training):
         # Test module can be represented with repr and str without errors.
+
+        msg = _get_mps_error_msg(device, dtype, module_info, [])
+        if msg is not None:
+            self.skipTest(msg)
+
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                        requires_grad=False, training=training)
@@ -210,10 +238,19 @@ def test_repr(self, device, dtype, module_info, training):
         m.__repr__()
         str(m)
 
-    @skipIfMps
     @modules(module_db)
     def test_pickle(self, device, dtype, module_info, training):
         # Test that module can be pickled and unpickled.
+
+        MPS_BLOCKLIST = [
+            "nn.LSTM"  # hard crash
+        ]
+
+        msg = _get_mps_error_msg(device, dtype, module_info, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
+
+
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                        requires_grad=False, training=training)
@@ -248,6 +285,15 @@ def test_pickle(self, device, dtype, module_info, training):
     def test_check_inplace(self, device, dtype, module_info, training):
         # Check if the inplace variant of the module gives the same result as the out of place
        # variant.
+
+        MPS_BLOCKLIST = [
+            "nn.ELU"  # hard crash
+        ]
+
+        msg = _get_mps_error_msg(device, dtype, module_info, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
+
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                        requires_grad=True, training=training)
@@ -325,11 +371,21 @@ def inner_zero_grad(obj):
                 obj.grad = None
         self._traverse_obj(obj, inner_zero_grad)
 
-    @skipIfMps
     @modules(module_db)
     @skipIfTorchInductor("to be fixed")
     def test_non_contiguous_tensors(self, device, dtype, module_info, training):
         # Check modules work with non-contiguous tensors
+        MPS_BLOCKLIST = [
+            # hard crashes
+            "nn.GRU",
+            "nn.LSTM",
+            "nn.RNN"
+        ]
+
+        msg = _get_mps_error_msg(device, dtype, module_info, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
+
 
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
@@ -580,10 +636,18 @@ def check_backward(cpu_output, gpu_output):
         for cpu_output, gpu_output in zip(flatten_cpu_outputs, flatten_gpu_outputs):
             check_backward(cpu_output, gpu_output)
 
-    @skipIfMps
     @modules(module_db)
     @skipIfTorchInductor("to be fixed")
     def test_memory_format(self, device, dtype, module_info, training):
+        MPS_BLOCKLIST = [
+            "nn.BatchNorm3d",  # failed assert
+            "nn.LSTM",  # segfault
+        ]
+
+        msg = _get_mps_error_msg(device, dtype, module_info, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
+
         is_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability(0) == (8, 6)
         # TODO tighten it to a specific module
         atol, rtol = (3e-3, 7e-3) if is_sm86 else (None, None)
@@ -680,9 +744,12 @@ def inner_check_out_mem_format(output):
 
     # Test whether train and eval modes differ for each module. Use to verify
     # that the ModuleInfo entry flag is correct.
-    @skipIfMps  # the test doesn't work on MPS as double types are not supported
     @modules(module_db, train_eval_mode=TrainEvalMode.train_only)
     def test_if_train_and_eval_modes_differ(self, device, dtype, module_info, training):
+        msg = _get_mps_error_msg(device, dtype, module_info, [])
+        if msg is not None:
+            self.skipTest(msg)
+
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                        requires_grad=False, training=training)
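The gating pattern added to each test above is the same throughout: compute a skip message from the device, dtype, and a per-test blocklist, then call self.skipTest when one is returned. Below is a standalone sketch of that logic, with a hypothetical SimpleModuleInfo standing in for the ModuleInfo entries from module_db and an illustrative MPS_DTYPES subset; it is not the test harness itself.

# Standalone sketch of the MPS gating helper added above (illustrative only;
# SimpleModuleInfo and this MPS_DTYPES subset are stand-ins, not the real ones).
import torch

MPS_DTYPES = [torch.float32, torch.float16, torch.int64]  # illustrative subset

class SimpleModuleInfo:
    def __init__(self, name):
        self.name = name

def get_mps_error_msg(device, dtype, op, mps_blocklist):
    # Dtype gate first, then a prefix match against the blocklist; because the
    # check uses str.startswith, an entry such as "nn.LSTM" also matches longer
    # names sharing that prefix.
    if device == "mps" and dtype not in MPS_DTYPES:
        return f"MPS doesn't support {dtype} datatype"
    if op.name.startswith(tuple(mps_blocklist)):
        return "MPS doesn't support op " + str(op.name)
    return None

# float64 is not in the illustrative dtype list, so a message is returned:
print(get_mps_error_msg("mps", torch.float64, SimpleModuleInfo("nn.Linear"), []))
# Supported dtype but blocklisted op name: a message is returned as well:
print(get_mps_error_msg("mps", torch.float32, SimpleModuleInfo("nn.LSTM"), ["nn.LSTM"]))
# Supported dtype, no blocklist hit: returns None, so the test would run:
print(get_mps_error_msg("mps", torch.float32, SimpleModuleInfo("nn.Linear"), []))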
