 from unittest.mock import MagicMock, patch
 
 import torch
+
+from parameterized import parameterized
 from torch.autograd import Variable
 from torch.distributed import ProcessGroup
 from torch.distributed.tensor import distribute_tensor, DTensor, init_device_mesh, Shard
-from torch.testing._internal.common_utils import (
-    instantiate_parametrized_tests,
-    parametrize,
-)
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     with_comms,
 )
+
+setattr(torch.backends, "__allow_nonbracketed_mutation_flag", True)  # noqa: B010
+
 from torchrec.optim.clipping import GradientClipping, GradientClippingOptimizer
 from torchrec.optim.test_utils import DummyKeyedOptimizer
 
@@ -243,17 +244,24 @@ def test_clip_no_gradients_norm_meta_device(
 
 
 @unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
-@instantiate_parametrized_tests
 class TestGradientClippingDTensor(DTensorTestBase):
     """No tests for Replicated DTensors as handled prior to GradientClippingOptimizer"""
 
+    def setUp(self) -> None:
+        setattr(torch.backends, "__allow_nonbracketed_mutation_flag", False)  # noqa: B010
+        return super().setUp()
+
+    def tearDown(self) -> None:
+        setattr(torch.backends, "__allow_nonbracketed_mutation_flag", True)  # noqa: B010
+        return super().tearDown()
+
     def _get_params_to_pg(
         self, params: List[DTensor]
     ) -> Dict[DTensor, List[ProcessGroup]]:
         return {param: [param.device_mesh.get_group()] for param in params}
 
+    @parameterized.expand(["inf", 1, 2])
     @with_comms
-    @parametrize("norm_type", ("inf", 1, 2))
     def test_tensor_and_sharded_dtensor_clip_all_gradients_norm(
         self, norm_type: Union[float, str]
     ) -> None:
@@ -342,8 +350,8 @@ def test_tensor_and_sharded_dtensor_clip_all_gradients_norm(
                     f"Expect gradient to be the same. However, found {param_grad=}, {ref_param.grad=}",
                 )
 
+    @parameterized.expand(["inf", 1, 2])
     @with_comms
-    @parametrize("norm_type", ("inf", 1, 2))
     def test_multiple_sharded_dtensors_clip_all_gradients_norm(
         self, norm_type: Union[float, str]
     ) -> None:
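Note (not part of the diff): the removed `parametrize` decorator from PyTorch's internal test utilities needs a class-level `instantiate_parametrized_tests` decorator to generate the parametrized variants, while `parameterized.expand` from the external `parameterized` package generates one test method per argument on its own, which is why the class decorator is dropped above. A minimal sketch of the new style, assuming a hypothetical `ExampleNormTest` class and norm values chosen only for illustration:

```python
import unittest

import torch
from parameterized import parameterized


class ExampleNormTest(unittest.TestCase):
    """Hypothetical example: parameterized.expand generates one method per norm_type."""

    @parameterized.expand(["inf", 1, 2])
    def test_vector_norm(self, norm_type) -> None:
        # "inf" arrives as the string given in the decorator; convert it before use.
        p = float(norm_type)
        grads = torch.tensor([3.0, -4.0])
        # Hand-computed norms of [3, -4]: max-abs, L1, and L2.
        expected = {float("inf"): 4.0, 1.0: 7.0, 2.0: 5.0}[p]
        self.assertAlmostEqual(torch.linalg.vector_norm(grads, ord=p).item(), expected)


if __name__ == "__main__":
    unittest.main()
```

On the `setattr(torch.backends, "__allow_nonbracketed_mutation_flag", ...)` lines: inside a class body (as in `setUp`/`tearDown`), a plain attribute assignment to a dunder-prefixed name would be name-mangled by Python, so `setattr` with a string attribute name is used instead; the `# noqa: B010` comment suppresses flake8-bugbear's warning about calling `setattr` with a constant attribute.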