
Commit 193f1c4

TroyGarden authored and facebook-github-bot committed
fix RuntimeError: not allowed to set torch.backends.cudnn flags after disable_global_flags (#3343)
Summary:

# context
* After fixing the GitHub CI workflow (GPU unit tests), we found that lots of errors come from the same root cause:
```
torchrec/test_utils/__init__.py:129: in _wrapper
    return wrapped_func(*args, **kwargs)
torchrec/distributed/test_utils/multi_process.py:131: in setUp
    torch.backends.cudnn.allow_tf32 = False

self = <torch.backends.ContextProp object at 0x7f4e8bb3ba10>
obj = <module 'torch.backends.cudnn' from '/opt/conda/envs/build_binary/lib/python3.11/site-packages/torch/backends/cudnn/__init__.py'>
val = False

    def __set__(self, obj, val):
        if not flags_frozen():
            self.setter(val)
        else:
>           raise RuntimeError(
                f"not allowed to set {obj.__name__} flags "
                "after disable_global_flags; please use flags() context manager instead"
            )
E   RuntimeError: not allowed to set torch.backends.cudnn flags after disable_global_flags; please use flags() context manager instead
```
* According to D77758554, the issue is due to D78326114 introducing `torch.testing._internal.common_utils`:
```
# torch/testing/_internal/common_utils.py calls `disable_global_flags()`
# workaround RuntimeError: not allowed to set ... after disable_global_flags
```

Differential Revision: D81529616
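For readers hitting this for the first time, here is a minimal illustrative sketch, not part of the commit, of how the failure arises and of the `flags()` context manager the error message points to. It assumes a recent PyTorch where `torch.backends.flags_frozen()` and `torch.backends.cudnn.flags()` behave as described in the traceback:

```python
import torch

# Importing this internal test helper calls torch.backends.disable_global_flags(),
# which freezes direct mutation of torch.backends.* flags.
import torch.testing._internal.common_utils  # noqa: F401

print(torch.backends.flags_frozen())  # True once the import above has run

try:
    # The same assignment that fails in multi_process.py's setUp.
    torch.backends.cudnn.allow_tf32 = False
except RuntimeError as err:
    print(err)  # "not allowed to set torch.backends.cudnn flags ..."

# The error message recommends the flags() context manager, which applies the
# setting temporarily instead of mutating global state.
with torch.backends.cudnn.flags(enabled=True, allow_tf32=False):
    pass  # code that should run with cuDNN TF32 disabled goes here
```

The commit takes the other route: it removes the import that freezes the flags (see the `test_clipping.py` diff below), so the plain assignments in `multi_process.py` work again without the `__allow_nonbracketed_mutation_flag` workaround.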
1 parent 0d67c37 · commit 193f1c4

File tree

2 files changed: +4, -15 lines

torchrec/distributed/test_utils/multi_process.py

Lines changed: 0 additions & 8 deletions

```diff
@@ -90,15 +90,7 @@ def __exit__(self, exc_type, exc_instance, traceback) -> None:
         dist.destroy_process_group(self.pg)
         torch.use_deterministic_algorithms(False)
         if torch.cuda.is_available() and self.disable_cuda_tf_32:
-            # torch/testing/_internal/common_utils.py calls `disable_global_flags()`
-            # workaround RuntimeError: not allowed to set ... after disable_global_flags
-            setattr(  # noqa: B010
-                torch.backends, "__allow_nonbracketed_mutation_flag", True
-            )
             torch.backends.cudnn.allow_tf32 = True
-            setattr(  # noqa: B010
-                torch.backends, "__allow_nonbracketed_mutation_flag", False
-            )


 class MultiProcessTestBase(unittest.TestCase):
```
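With `torch.testing._internal.common_utils` no longer imported by the test module (see the `test_clipping.py` diff below), the global backend flags are never frozen in these processes, so the remaining plain assignment is legal and the `__allow_nonbracketed_mutation_flag` workaround can go. Purely as an illustrative sketch, not the committed code, a defensive version of the same teardown step could guard on the frozen state explicitly:

```python
import torch


def restore_cudnn_tf32() -> None:
    """Re-enable cuDNN TF32 after a test that disabled it (illustrative only)."""
    if not torch.cuda.is_available():
        return
    if torch.backends.flags_frozen():
        # Some import (e.g. torch.testing._internal.common_utils) froze the
        # global flags; a direct assignment would raise the RuntimeError above,
        # so skip the restore here and rely on torch.backends.cudnn.flags()
        # at the call sites instead.
        return
    # Mirrors the line kept in __exit__ above.
    torch.backends.cudnn.allow_tf32 = True
```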

torchrec/optim/tests/test_clipping.py

Lines changed: 4 additions & 7 deletions

```diff
@@ -12,13 +12,11 @@
 from unittest.mock import MagicMock, patch

 import torch
+
+from parameterized import parameterized
 from torch.autograd import Variable
 from torch.distributed import ProcessGroup
 from torch.distributed.tensor import distribute_tensor, DTensor, init_device_mesh, Shard
-from torch.testing._internal.common_utils import (
-    instantiate_parametrized_tests,
-    parametrize,
-)
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     with_comms,
@@ -243,7 +241,6 @@ def test_clip_no_gradients_norm_meta_device(


 @unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
-@instantiate_parametrized_tests
 class TestGradientClippingDTensor(DTensorTestBase):
     """No tests for Replicated DTensors as handled prior to GradientClippingOptimizer"""

@@ -252,8 +249,8 @@ def _get_params_to_pg(
     ) -> Dict[DTensor, List[ProcessGroup]]:
         return {param: [param.device_mesh.get_group()] for param in params}

+    @parameterized.expand(["inf", 1, 2])
     @with_comms
-    @parametrize("norm_type", ("inf", 1, 2))
     def test_tensor_and_sharded_dtensor_clip_all_gradients_norm(
         self, norm_type: Union[float, str]
     ) -> None:
@@ -342,8 +339,8 @@ def test_tensor_and_sharded_dtensor_clip_all_gradients_norm(
                 f"Expect gradient to be the same. However, found {param_grad=}, {ref_param.grad=}",
             )

+    @parameterized.expand(["inf", 1, 2])
     @with_comms
-    @parametrize("norm_type", ("inf", 1, 2))
     def test_multiple_sharded_dtensors_clip_all_gradients_norm(
         self, norm_type: Union[float, str]
     ) -> None:
```
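The decorator swap above replaces the torch-internal `parametrize` / `instantiate_parametrized_tests` pair (whose module import is what called `disable_global_flags()` in the first place) with the standalone `parameterized` package. As a hedged, self-contained illustration of the new pattern (the test class and body here are invented for demonstration and are not from the commit):

```python
import math
import unittest
from typing import Union

from parameterized import parameterized


class NormTypeExample(unittest.TestCase):
    # parameterized.expand generates one test method per value, passing each
    # value as the positional argument, mirroring the decorators used in
    # test_clipping.py.
    @parameterized.expand(["inf", 1, 2])
    def test_norm_type(self, norm_type: Union[float, str]) -> None:
        value = float(norm_type)  # "inf" -> math.inf, 1 -> 1.0, 2 -> 2.0
        self.assertTrue(value == math.inf or value >= 1.0)


if __name__ == "__main__":
    unittest.main()
```

With `@parameterized.expand`, no class-level decorator is needed to materialize the test variants, which is why `@instantiate_parametrized_tests` is dropped from `TestGradientClippingDTensor`.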
