 import torch
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCUDA, toleranceOverride, tol, skipMeta)
+from torch.testing._internal.common_dtype import get_all_dtypes
 from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck,
-    gradgradcheck, skipIfMps, skipIfTorchInductor)
+    gradgradcheck, skipIfTorchInductor)
 from unittest.mock import patch, call

+MPS_DTYPES = get_all_dtypes()
+for t in [torch.double, torch.cdouble, torch.cfloat, torch.int8, torch.bfloat16]:
+    del MPS_DTYPES[MPS_DTYPES.index(t)]
+
+def _get_mps_error_msg(device, dtype, op, mps_blocklist):
+    if torch.backends.mps.is_available() and device == "mps" and dtype not in MPS_DTYPES:
+        return f"MPS doesn't support {str(dtype)} datatype"
+    if op.name.startswith(tuple(mps_blocklist)):
+        return "MPS doesn't support op " + str(op.name)
+    return None

 class TestModule(TestCase):
     _do_cuda_memory_leak_check = True
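For orientation, every test touched below consumes the new helper the same way: build an optional per-test blocklist, ask `_get_mps_error_msg` for a skip reason, and call `self.skipTest` when one is returned. Two details of the helper are easy to miss, sketched below under the assumption of a PyTorch checkout where these internal test utilities are importable; the `nn.LSTMCell` name is only a hypothetical example of prefix matching, not taken from the diff.

```python
# Sketch only; mirrors the helper above rather than adding new behaviour.
import torch
from torch.testing._internal.common_dtype import get_all_dtypes

# The dtype filtering above is an element-wise removal from get_all_dtypes();
# list.remove(t) is equivalent to the del MPS_DTYPES[MPS_DTYPES.index(t)] idiom.
mps_dtypes = get_all_dtypes()
for t in [torch.double, torch.cdouble, torch.cfloat, torch.int8, torch.bfloat16]:
    mps_dtypes.remove(t)

# Blocklist entries act as name *prefixes*: str.startswith accepts a tuple, so
# "nn.LSTM" would also match a longer name such as "nn.LSTMCell" (hypothetical).
blocklist = ("nn.LSTM",)
assert "nn.LSTM".startswith(blocklist)
assert "nn.LSTMCell".startswith(blocklist)
```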
@@ -32,7 +43,8 @@ def _assert_module_parameters_and_buffer_are(self, module, device, dtype):
         def _check_module(items, name, device=device, dtype=dtype):
             for item_name, item in items:
                 self.assertEqual(
-                    item.device, device,
+                    # workaround: compare device types only, since MPS parameters report mps:0 while the expected device is mps
+                    item.device.type, device.type,
                     f'{name} {item_name} is on device {item.device} instead of the expected device {device}')
                 if item.dtype.is_floating_point:
                     self.assertEqual(
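A note on the workaround above: parameters created on MPS report the indexed device mps:0, while the expected device handed to the test is plain mps, and torch.device treats those two as unequal, so the assertion now compares only the device type. A minimal illustration, which only constructs device objects and so should run even without MPS hardware (assuming a PyTorch build recent enough to recognize the mps device type):

```python
import torch

# An indexed device is not equal to the bare device...
assert torch.device("mps:0") != torch.device("mps")
# ...but both share the same device type, which is what the check now compares.
assert torch.device("mps:0").type == torch.device("mps").type == "mps"
```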
@@ -41,9 +53,16 @@ def _check_module(items, name, device=device, dtype=dtype):
         _check_module(module.named_parameters(), "Parameter")
         _check_module(module.named_buffers(), "Buffer")

-    @skipIfMps  # the test doesn't work on MPS as double types are not supported
     @modules(module_db)
     def test_forward(self, device, dtype, module_info, training):
+        MPS_BLOCKLIST = [
+            "nn.LSTM"  # segfault
+        ]
+
+        msg = _get_mps_error_msg(device, dtype, module_info, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
+
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                        requires_grad=False, training=training)
@@ -83,6 +102,10 @@ def test_forward(self, device, dtype, module_info, training):
     # They should be applied to any created parameters and buffers.
     @modules(module_db)
     def test_factory_kwargs(self, device, dtype, module_info, training):
+        msg = _get_mps_error_msg(device, dtype, module_info, [])
+        if msg is not None:
+            self.skipTest(msg)
+
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                        requires_grad=False, training=training)
@@ -197,6 +220,11 @@ def _to_device1(objs):
     @modules(module_db)
     def test_repr(self, device, dtype, module_info, training):
         # Test module can be represented with repr and str without errors.
+
+        msg = _get_mps_error_msg(device, dtype, module_info, [])
+        if msg is not None:
+            self.skipTest(msg)
+
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                        requires_grad=False, training=training)
@@ -210,10 +238,19 @@ def test_repr(self, device, dtype, module_info, training):
             m.__repr__()
             str(m)

-    @skipIfMps
     @modules(module_db)
     def test_pickle(self, device, dtype, module_info, training):
         # Test that module can be pickled and unpickled.
+
+        MPS_BLOCKLIST = [
+            "nn.LSTM"  # hard crash
+        ]
+
+        msg = _get_mps_error_msg(device, dtype, module_info, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
+
+
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                        requires_grad=False, training=training)
@@ -248,6 +285,15 @@ def test_pickle(self, device, dtype, module_info, training):
     def test_check_inplace(self, device, dtype, module_info, training):
         # Check if the inplace variant of the module gives the same result as the out of place
         # variant.
+
+        MPS_BLOCKLIST = [
+            "nn.ELU"  # hard crash
+        ]
+
+        msg = _get_mps_error_msg(device, dtype, module_info, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
+
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                        requires_grad=True, training=training)
@@ -325,11 +371,21 @@ def inner_zero_grad(obj):
                     obj.grad = None
         self._traverse_obj(obj, inner_zero_grad)

-    @skipIfMps
     @modules(module_db)
     @skipIfTorchInductor("to be fixed")
     def test_non_contiguous_tensors(self, device, dtype, module_info, training):
         # Check modules work with non-contiguous tensors
+        MPS_BLOCKLIST = [
+            # hard crashes
+            "nn.GRU",
+            "nn.LSTM",
+            "nn.RNN"
+        ]
+
+        msg = _get_mps_error_msg(device, dtype, module_info, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
+

         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
@@ -580,10 +636,18 @@ def check_backward(cpu_output, gpu_output):
         for cpu_output, gpu_output in zip(flatten_cpu_outputs, flatten_gpu_outputs):
             check_backward(cpu_output, gpu_output)

-    @skipIfMps
     @modules(module_db)
     @skipIfTorchInductor("to be fixed")
     def test_memory_format(self, device, dtype, module_info, training):
+        MPS_BLOCKLIST = [
+            "nn.BatchNorm3d",  # failed assert
+            "nn.LSTM",  # segfault
+        ]
+
+        msg = _get_mps_error_msg(device, dtype, module_info, MPS_BLOCKLIST)
+        if msg is not None:
+            self.skipTest(msg)
+
         is_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability(0) == (8, 6)
         # TODO tighten it to a specific module
         atol, rtol = (3e-3, 7e-3) if is_sm86 else (None, None)
@@ -680,9 +744,12 @@ def inner_check_out_mem_format(output):

     # Test whether train and eval modes differ for each module. Use to verify
     # that the ModuleInfo entry flag is correct.
-    @skipIfMps  # the test doesn't work on MPS as double types are not supported
     @modules(module_db, train_eval_mode=TrainEvalMode.train_only)
     def test_if_train_and_eval_modes_differ(self, device, dtype, module_info, training):
+        msg = _get_mps_error_msg(device, dtype, module_info, [])
+        if msg is not None:
+            self.skipTest(msg)
+
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                        requires_grad=False, training=training)