|  | 
| 4 | 4 | # This source code is licensed under the BSD 3-Clause license found in the | 
| 5 | 5 | # LICENSE file in the root directory of this source tree. | 
| 6 | 6 | import unittest | 
| 7 |  | -import warnings | 
| 8 | 7 | from unittest.mock import patch | 
| 9 | 8 | 
 | 
| 10 | 9 | import torch | 
| @@ -37,55 +36,6 @@ def test_torch_version_at_least(self): | 
| 37 | 36 |                     f"Failed for torch.__version__={torch_version}, comparing with {compare_version}", | 
| 38 | 37 |                 ) | 
| 39 | 38 | 
 | 
| 40 |  | -    def test_torch_version_deprecation(self): | 
| 41 |  | -        """ | 
| 42 |  | -        Test that TORCH_VERSION_AT_LEAST* and TORCH_VERSION_AFTER* | 
| 43 |  | -        trigger deprecation warnings on use, not on import. | 
| 44 |  | -        """ | 
| 45 |  | -        # Reset deprecation warning state, otherwise we won't log warnings here | 
| 46 |  | -        warnings.resetwarnings() | 
| 47 |  | - | 
| 48 |  | -        # Importing and referencing should not trigger deprecation warning | 
| 49 |  | -        with warnings.catch_warnings(record=True) as _warnings: | 
| 50 |  | -            from torchao.utils import ( | 
| 51 |  | -                TORCH_VERSION_AFTER_2_2, | 
| 52 |  | -                TORCH_VERSION_AFTER_2_3, | 
| 53 |  | -                TORCH_VERSION_AFTER_2_4, | 
| 54 |  | -                TORCH_VERSION_AFTER_2_5, | 
| 55 |  | -                TORCH_VERSION_AT_LEAST_2_2, | 
| 56 |  | -                TORCH_VERSION_AT_LEAST_2_3, | 
| 57 |  | -                TORCH_VERSION_AT_LEAST_2_4, | 
| 58 |  | -                TORCH_VERSION_AT_LEAST_2_5, | 
| 59 |  | -                TORCH_VERSION_AT_LEAST_2_6, | 
| 60 |  | -                TORCH_VERSION_AT_LEAST_2_7, | 
| 61 |  | -                TORCH_VERSION_AT_LEAST_2_8, | 
| 62 |  | -            ) | 
| 63 |  | - | 
| 64 |  | -            deprecated_api_to_name = [ | 
| 65 |  | -                (TORCH_VERSION_AT_LEAST_2_8, "TORCH_VERSION_AT_LEAST_2_8"), | 
| 66 |  | -                (TORCH_VERSION_AT_LEAST_2_7, "TORCH_VERSION_AT_LEAST_2_7"), | 
| 67 |  | -                (TORCH_VERSION_AT_LEAST_2_6, "TORCH_VERSION_AT_LEAST_2_6"), | 
| 68 |  | -                (TORCH_VERSION_AT_LEAST_2_5, "TORCH_VERSION_AT_LEAST_2_5"), | 
| 69 |  | -                (TORCH_VERSION_AT_LEAST_2_4, "TORCH_VERSION_AT_LEAST_2_4"), | 
| 70 |  | -                (TORCH_VERSION_AT_LEAST_2_3, "TORCH_VERSION_AT_LEAST_2_3"), | 
| 71 |  | -                (TORCH_VERSION_AT_LEAST_2_2, "TORCH_VERSION_AT_LEAST_2_2"), | 
| 72 |  | -                (TORCH_VERSION_AFTER_2_5, "TORCH_VERSION_AFTER_2_5"), | 
| 73 |  | -                (TORCH_VERSION_AFTER_2_4, "TORCH_VERSION_AFTER_2_4"), | 
| 74 |  | -                (TORCH_VERSION_AFTER_2_3, "TORCH_VERSION_AFTER_2_3"), | 
| 75 |  | -                (TORCH_VERSION_AFTER_2_2, "TORCH_VERSION_AFTER_2_2"), | 
| 76 |  | -            ] | 
| 77 |  | -            self.assertEqual(len(_warnings), 0) | 
| 78 |  | - | 
| 79 |  | -        # Accessing the boolean value should trigger deprecation warning | 
| 80 |  | -        with warnings.catch_warnings(record=True) as _warnings: | 
| 81 |  | -            for api, name in deprecated_api_to_name: | 
| 82 |  | -                num_warnings_before = len(_warnings) | 
| 83 |  | -                if api: | 
| 84 |  | -                    pass | 
| 85 |  | -                regex = f"{name} is deprecated and will be removed" | 
| 86 |  | -                self.assertEqual(len(_warnings), num_warnings_before + 1) | 
| 87 |  | -                self.assertIn(regex, str(_warnings[-1].message)) | 
| 88 |  | - | 
| 89 | 39 | 
 | 
| 90 | 40 | class TestTorchAOBaseTensor(unittest.TestCase): | 
| 91 | 41 |     def test_print_arg_types(self): | 
|  | 
0 commit comments