
Commit 9396aa7

TroyGarden authored and facebook-github-bot committed
fix RuntimeError: not allowed to set torch.backends.cudnn flags after disable_global_flags (#3343)
Summary:

# context

* after fixing the github CI workflow (GPU unit tests) we found lots of errors coming from the same root cause:

```
torchrec/test_utils/__init__.py:129: in _wrapper
    return wrapped_func(*args, **kwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torchrec/distributed/test_utils/multi_process.py:131: in setUp
    torch.backends.cudnn.allow_tf32 = False
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <torch.backends.ContextProp object at 0x7f4e8bb3ba10>
obj = <module 'torch.backends.cudnn' from '/opt/conda/envs/build_binary/lib/python3.11/site-packages/torch/backends/cudnn/__init__.py'>
val = False

    def __set__(self, obj, val):
        if not flags_frozen():
            self.setter(val)
        else:
>           raise RuntimeError(
                f"not allowed to set {obj.__name__} flags "
                "after disable_global_flags; please use flags() context manager instead"
            )
E   RuntimeError: not allowed to set torch.backends.cudnn flags after disable_global_flags; please use flags() context manager instead
```

* according to D77758554, the issue is due to D78326114 introducing `torch.testing._internal.common_utils`:

```
# torch/testing/_internal/common_utils.py calls `disable_global_flags()`
# workaround RuntimeError: not allowed to set ... after disable_global_flags
```

Differential Revision: D81529616
1 parent 0d67c37 commit 9396aa7
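As background for the traceback above, here is a minimal sketch of the failure mode and of the `flags()` context manager that the error message recommends. It is illustrative only and not part of this commit; it assumes a recent PyTorch where `torch.backends.disable_global_flags()`, `torch.backends.flags_frozen()`, and `torch.backends.cudnn.flags(...)` behave as shown in the traceback.

```python
# Illustrative sketch, not part of this commit: why the assignment fails and
# what the error message suggests instead. Assumes a recent PyTorch.
import torch

# Importing torch.testing._internal.common_utils ends up triggering this call,
# which freezes the torch.backends.* flags:
torch.backends.disable_global_flags()

try:
    torch.backends.cudnn.allow_tf32 = False  # direct assignment is now rejected
except RuntimeError as err:
    print(err)  # "not allowed to set torch.backends.cudnn flags after disable_global_flags ..."

# The sanctioned alternative is the flags() context manager, which temporarily
# applies the settings and restores the previous ones on exit. Flags not passed
# explicitly may fall back to the context manager's own defaults rather than the
# current values, so pass everything you rely on.
with torch.backends.cudnn.flags(enabled=True, allow_tf32=False):
    pass  # run the code that needs TF32 disabled here
```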

File tree

2 files changed: +15 -15 lines changed


torchrec/distributed/test_utils/multi_process.py

Lines changed: 0 additions & 8 deletions
@@ -90,15 +90,7 @@ def __exit__(self, exc_type, exc_instance, traceback) -> None:
         dist.destroy_process_group(self.pg)
         torch.use_deterministic_algorithms(False)
         if torch.cuda.is_available() and self.disable_cuda_tf_32:
-            # torch/testing/_internal/common_utils.py calls `disable_global_flags()`
-            # workaround RuntimeError: not allowed to set ... after disable_global_flags
-            setattr(  # noqa: B010
-                torch.backends, "__allow_nonbracketed_mutation_flag", True
-            )
             torch.backends.cudnn.allow_tf32 = True
-            setattr(  # noqa: B010
-                torch.backends, "__allow_nonbracketed_mutation_flag", False
-            )


 class MultiProcessTestBase(unittest.TestCase):
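For context on what the deleted lines did: they temporarily flipped the private `torch.backends.__allow_nonbracketed_mutation_flag` attribute so the direct `allow_tf32` assignment would not raise, then froze it again. Below is a sketch of that same pattern expressed as a context manager; `_unfrozen_backend_flags` is a hypothetical helper, not torchrec code, and its reliance on a private attribute is exactly why the workaround is being removed from this harness.

```python
# Illustrative sketch only: the pattern removed above, rewritten as a context
# manager. `_unfrozen_backend_flags` is a hypothetical helper, not torchrec code.
from contextlib import contextmanager
from typing import Iterator

import torch


@contextmanager
def _unfrozen_backend_flags() -> Iterator[None]:
    # Temporarily re-allow direct assignment to torch.backends.* flags ...
    setattr(torch.backends, "__allow_nonbracketed_mutation_flag", True)  # noqa: B010
    try:
        yield
    finally:
        # ... and freeze them again afterwards, mirroring the deleted setattr pair.
        setattr(torch.backends, "__allow_nonbracketed_mutation_flag", False)  # noqa: B010


# Usage equivalent to the lines this diff deletes:
# with _unfrozen_backend_flags():
#     torch.backends.cudnn.allow_tf32 = True
```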

torchrec/optim/tests/test_clipping.py

Lines changed: 15 additions & 7 deletions
@@ -12,17 +12,18 @@
 from unittest.mock import MagicMock, patch

 import torch
+
+from parameterized import parameterized
 from torch.autograd import Variable
 from torch.distributed import ProcessGroup
 from torch.distributed.tensor import distribute_tensor, DTensor, init_device_mesh, Shard
-from torch.testing._internal.common_utils import (
-    instantiate_parametrized_tests,
-    parametrize,
-)
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     with_comms,
 )
+
+setattr(torch.backends, "__allow_nonbracketed_mutation_flag", True)  # noqa: B010
+
 from torchrec.optim.clipping import GradientClipping, GradientClippingOptimizer
 from torchrec.optim.test_utils import DummyKeyedOptimizer

@@ -243,17 +244,24 @@ def test_clip_no_gradients_norm_meta_device(


 @unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
-@instantiate_parametrized_tests
 class TestGradientClippingDTensor(DTensorTestBase):
     """No tests for Replicated DTensors as handled prior to GradientClippingOptimizer"""

+    def setUp(self) -> None:
+        setattr(torch.backends, "__allow_nonbracketed_mutation_flag", False)  # noqa: B010
+        return super().setUp()
+
+    def tearDown(self) -> None:
+        setattr(torch.backends, "__allow_nonbracketed_mutation_flag", True)  # noqa: B010
+        return super().tearDown()
+
     def _get_params_to_pg(
         self, params: List[DTensor]
     ) -> Dict[DTensor, List[ProcessGroup]]:
         return {param: [param.device_mesh.get_group()] for param in params}

+    @parameterized.expand(["inf", 1, 2])
     @with_comms
-    @parametrize("norm_type", ("inf", 1, 2))
     def test_tensor_and_sharded_dtensor_clip_all_gradients_norm(
         self, norm_type: Union[float, str]
     ) -> None:

@@ -342,8 +350,8 @@ def test_tensor_and_sharded_dtensor_clip_all_gradients_norm(
                 f"Expect gradient to be the same. However, found {param_grad=}, {ref_param.grad=}",
             )

+    @parameterized.expand(["inf", 1, 2])
     @with_comms
-    @parametrize("norm_type", ("inf", 1, 2))
     def test_multiple_sharded_dtensors_clip_all_gradients_norm(
         self, norm_type: Union[float, str]
     ) -> None:
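The recurring change in this file is the switch from `parametrize`/`instantiate_parametrized_tests` in `torch.testing._internal.common_utils` (whose import is what triggers `disable_global_flags()`) to `parameterized.expand` from the `parameterized` package. A minimal, self-contained sketch of that decorator style, using a hypothetical test class unrelated to torchrec, might look like this:

```python
# Standalone sketch of the parameterization style adopted above; the test class
# and assertion are hypothetical, only the decorator usage mirrors the diff.
import unittest
from typing import Union

from parameterized import parameterized


class NormTypeParamTest(unittest.TestCase):
    # parameterized.expand generates one test method per value ("inf", 1, 2)
    # and passes the value as the norm_type argument, like the decorated tests above.
    @parameterized.expand(["inf", 1, 2])
    def test_norm_type(self, norm_type: Union[float, str]) -> None:
        norm = float(norm_type)  # float("inf") yields math.inf, so all three values work
        self.assertGreater(norm, 0)


if __name__ == "__main__":
    unittest.main()
```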
