From a768b6372ef4616d5ea0eb1fef63554382804af8 Mon Sep 17 00:00:00 2001
From: simonsays1980
Date: Thu, 22 Aug 2024 11:36:09 +0200
Subject: [PATCH 1/2] Added kwargs to torch.nn.parallel.DistributedDataParallel.

Signed-off-by: simonsays1980
---
 rllib/algorithms/algorithm_config.py      | 11 +++++++++++
 rllib/core/learner/torch/torch_learner.py | 14 +++++++++++---
 2 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py
index 63e9aafd71ef..24e51a8c0c96 100644
--- a/rllib/algorithms/algorithm_config.py
+++ b/rllib/algorithms/algorithm_config.py
@@ -312,6 +312,8 @@ def __init__(self, algo_class: Optional[type] = None):
             "aot_eager" if sys.platform == "darwin" else "onnxrt"
         )
         self.torch_compile_worker_dynamo_mode = None
+        # Default kwargs for `torch.nn.parallel.DistributedDataParallel`.
+        self.torch_ddp_kwargs = {}
 
         # `self.api_stack()`
         self.enable_rl_module_and_learner = False
@@ -1378,6 +1380,7 @@ def framework(
         torch_compile_worker: Optional[bool] = NotProvided,
         torch_compile_worker_dynamo_backend: Optional[str] = NotProvided,
         torch_compile_worker_dynamo_mode: Optional[str] = NotProvided,
+        torch_ddp_kwargs: Optional[Dict[str, Any]] = NotProvided,
     ) -> "AlgorithmConfig":
         """Sets the config's DL framework settings.
 
@@ -1417,6 +1420,12 @@ def framework(
                 the workers.
             torch_compile_worker_dynamo_mode: The torch dynamo mode to use on
                 the workers.
+            torch_ddp_kwargs: The kwargs to pass into
+                `torch.nn.parallel.DistributedDataParallel` when using `num_learners
+                > 1. This is specifically helpful when searching for unused parameters
+                that are not used in the backward pass. This can give hints for errors
+                in custom models where some parameters do not get touched in the
+                backward pass although they should.
 
         Returns:
             This updated AlgorithmConfig object.
@@ -1458,6 +1467,8 @@ def framework(
             )
         if torch_compile_worker_dynamo_mode is not NotProvided:
             self.torch_compile_worker_dynamo_mode = torch_compile_worker_dynamo_mode
+        if torch_ddp_kwargs is not NotProvided:
+            self.torch_ddp_kwargs = torch_ddp_kwargs
 
         return self
 
diff --git a/rllib/core/learner/torch/torch_learner.py b/rllib/core/learner/torch/torch_learner.py
index ff643ffb6098..a560f8a7b1f4 100644
--- a/rllib/core/learner/torch/torch_learner.py
+++ b/rllib/core/learner/torch/torch_learner.py
@@ -251,7 +251,9 @@ def add_module(
                 "torch compile."
             )
             self._module.add_module(
-                module_id, TorchDDPRLModule(module), override=True
+                module_id,
+                TorchDDPRLModule(module, **self.config.torch_ddp_kwargs),
+                override=True,
             )
 
         return marl_spec
@@ -406,7 +408,9 @@ def _make_modules_ddp_if_necessary(self) -> None:
         if self._distributed:
             # Single agent module: Convert to `TorchDDPRLModule`.
             if isinstance(self._module, TorchRLModule):
-                self._module = TorchDDPRLModule(self._module)
+                self._module = TorchDDPRLModule(
+                    self._module, **self.config.torch_ddp_kwargs
+                )
             # Multi agent module: Convert each submodule to `TorchDDPRLModule`.
             else:
                 assert isinstance(self._module, MultiRLModule)
@@ -415,7 +419,11 @@ def _make_modules_ddp_if_necessary(self) -> None:
                 if isinstance(sub_module, TorchRLModule):
                     # Wrap and override the module ID key in self._module.
                     self._module.add_module(
-                        key, TorchDDPRLModule(sub_module), override=True
+                        key,
+                        TorchDDPRLModule(
+                            sub_module, **self.config.torch_ddp_kwargs
+                        ),
+                        override=True,
                     )
 
     def _is_module_compatible_with_learner(self, module: RLModule) -> bool:

From 7e57b8e298490ad2b7c1170e99223daa9f740660 Mon Sep 17 00:00:00 2001
From: simonsays1980
Date: Sat, 24 Aug 2024 12:02:45 +0200
Subject: [PATCH 2/2] Added missing literal string in docs.

Signed-off-by: simonsays1980
---
 rllib/algorithms/algorithm_config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py
index 560ff349bd8d..41cc4874e25c 100644
--- a/rllib/algorithms/algorithm_config.py
+++ b/rllib/algorithms/algorithm_config.py
@@ -1422,7 +1422,7 @@ def framework(
                 workers.
             torch_ddp_kwargs: The kwargs to pass into
                 `torch.nn.parallel.DistributedDataParallel` when using `num_learners
-                > 1. This is specifically helpful when searching for unused parameters
+                > 1`. This is specifically helpful when searching for unused parameters
                 that are not used in the backward pass. This can give hints for errors
                 in custom models where some parameters do not get touched in the
                 backward pass although they should.
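For illustration only (not part of the patch): a minimal sketch of how the new option could be used from an algorithm config. Only `torch_ddp_kwargs` comes from this change; `PPOConfig`, the environment, and the learner count are arbitrary example choices from the existing RLlib builder API.

# Minimal usage sketch, assuming the new API stack and multi-learner setup.
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .framework(
        "torch",
        # Forwarded as **kwargs into `torch.nn.parallel.DistributedDataParallel`
        # when the RLModule gets wrapped for multi-learner training.
        # `find_unused_parameters=True` makes DDP report parameters that never
        # receive gradients, which helps debug custom models.
        torch_ddp_kwargs={"find_unused_parameters": True},
    )
    # DDP wrapping (and hence the kwargs) only kicks in with more than one learner.
    .learners(num_learners=2)
)
algo = config.build()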