From 76c742f94a5f9a3489dc16bcb6dbf2989fe6cacb Mon Sep 17 00:00:00 2001 From: Denis Vieriu Date: Mon, 18 Jul 2022 14:31:28 -0700 Subject: [PATCH] Add back support for PYTORCH_TEST_WITH_MPS --- torch/testing/_internal/common_device_type.py | 8 +++----- torch/testing/_internal/common_utils.py | 1 + 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index c48890ea651ff..1b44cfe6ffe76 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -14,7 +14,7 @@ skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \ IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, DeterministicGuard, \ _TestParametrizer, compose_parametrize_fns, dtype_name, \ - TEST_WITH_MIOPEN_SUGGEST_NHWC, NATIVE_DEVICES + TEST_WITH_MIOPEN_SUGGEST_NHWC, NATIVE_DEVICES, TEST_WITH_MPS from torch.testing._internal.common_cuda import _get_torch_cuda_version, TEST_CUSPARSE_GENERIC from torch.testing._internal.common_dtype import get_all_dtypes @@ -530,10 +530,8 @@ def get_device_type_test_bases(): test_bases.append(CPUTestBase) if torch.cuda.is_available(): test_bases.append(CUDATestBase) - # Disable MPS testing in generic device testing temporarily while we're - # ramping up support. - # elif torch.backends.mps.is_available(): - # test_bases.append(MPSTestBase) + elif torch.backends.mps.is_available(): + test_bases.append(MPSTestBase) return test_bases diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 92a7f31470987..e0cc85827eafc 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -787,6 +787,7 @@ def _check_module_exists(name: str) -> bool: TEST_WITH_TSAN = os.getenv('PYTORCH_TEST_WITH_TSAN', '0') == '1' TEST_WITH_UBSAN = os.getenv('PYTORCH_TEST_WITH_UBSAN', '0') == '1' TEST_WITH_ROCM = os.getenv('PYTORCH_TEST_WITH_ROCM', '0') == '1' +TEST_WITH_MPS = os.getenv('PYTORCH_TEST_WITH_MPS', '0') == '1' # TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen # See #64427