From 1e057dad030cbbb4c9ab42d63ca39b34f97407c9 Mon Sep 17 00:00:00 2001
From: Erik Wijmans
Date: Mon, 8 Feb 2021 18:05:41 -0500
Subject: [PATCH] Fix DDP Reduction in PyTorch 1.7 when the model has unused parameters (#586)

* Hacky thing to interface with ddp.forward

* State guard

* Add unit test and comment in readme about it.
---
 habitat_baselines/rl/ddppo/README.md     |   8 ++
 habitat_baselines/rl/ddppo/algo/ddppo.py |  46 ++++++---
 habitat_baselines/rl/ppo/ppo.py          |  12 ++-
 test/test_ddppo_reduce.py                | 126 +++++++++++++++++++++++
 4 files changed, 176 insertions(+), 16 deletions(-)
 create mode 100644 test/test_ddppo_reduce.py

diff --git a/habitat_baselines/rl/ddppo/README.md b/habitat_baselines/rl/ddppo/README.md
index 1c24e4a02d..380a4aef64 100644
--- a/habitat_baselines/rl/ddppo/README.md
+++ b/habitat_baselines/rl/ddppo/README.md
@@ -15,6 +15,14 @@ The two recommended backends are GLOO and NCCL. Use NCCL if your system has it,
 See [pytorch's distributed docs](https://pytorch.org/docs/stable/distributed.html#backends-that-come-with-pytorch)
 and [pytorch's distributed tutorial](https://pytorch.org/tutorials/intermediate/dist_tuto.html)
 for more information.
 
+### Verifying gradient reduction
+
+Because RL differs from supervised learning, the way DD-PPO interfaces with PyTorch's DistributedDataParallel is slightly off the beaten path, and while it is reasonably robust, new versions of PyTorch have broken it in the past. Our CI does not test against every version of PyTorch, so if there is ever concern that gradient reduction may not be working, run the unit test locally:
+
+```
+pytest test/test_ddppo_reduce.py
+```
+
 
 ## Pretrained Models (PointGoal Navigation with GPS+Compass)
 
diff --git a/habitat_baselines/rl/ddppo/algo/ddppo.py b/habitat_baselines/rl/ddppo/algo/ddppo.py
index b8340a722b..dcddc280b4 100644
--- a/habitat_baselines/rl/ddppo/algo/ddppo.py
+++ b/habitat_baselines/rl/ddppo/algo/ddppo.py
@@ -7,7 +7,6 @@
 from typing import Tuple
 
 import torch
-from torch import Tensor
 from torch import distributed as distrib
 
 from habitat_baselines.common.rollout_storage import RolloutStorage
@@ -44,6 +43,19 @@ def distributed_mean_and_var(
     return mean, var
 
 
+class _EvalActionsWrapper(torch.nn.Module):
+    r"""Wrapper on evaluate_actions that allows it to be called from forward.
+    This is needed to interface with DistributedDataParallel's forward call.
+    """
+
+    def __init__(self, actor_critic):
+        super().__init__()
+        self.actor_critic = actor_critic
+
+    def forward(self, *args, **kwargs):
+        return self.actor_critic.evaluate_actions(*args, **kwargs)
+
+
 class DecentralizedDistributedMixin:
     def _get_advantages_distributed(
         self, rollouts: RolloutStorage
@@ -79,24 +91,28 @@ class Guard:
            def __init__(self, model, device):
                if torch.cuda.is_available():
                    self.ddp = torch.nn.parallel.DistributedDataParallel(
-                        model, device_ids=[device], output_device=device
+                        model,
+                        device_ids=[device],
+                        output_device=device,
+                        find_unused_parameters=find_unused_params,
                    )
                else:
-                    self.ddp = torch.nn.parallel.DistributedDataParallel(model)
-
-        self._ddp_hooks = Guard(self.actor_critic, self.device)  # type: ignore
-        self.get_advantages = self._get_advantages_distributed
-
-        self.reducer = self._ddp_hooks.ddp.reducer
-        self.find_unused_params = find_unused_params
+                    self.ddp = torch.nn.parallel.DistributedDataParallel(
+                        model,
+                        find_unused_parameters=find_unused_params,
+                    )
 
-    def before_backward(self, loss: Tensor) -> None:
-        super().before_backward(loss)  # type: ignore
+        self._evaluate_actions_wrapper = Guard(_EvalActionsWrapper(self.actor_critic), self.device)  # type: ignore
 
-        if self.find_unused_params:
-            self.reducer.prepare_for_backward([loss])  # type: ignore
-        else:
-            self.reducer.prepare_for_backward([])  # type: ignore
+    def _evaluate_actions(
+        self, observations, rnn_hidden_states, prev_actions, masks, action
+    ):
+        r"""Internal method that calls Policy.evaluate_actions. This is used instead of
+        calling that directly so that the call can be overridden with inheritance.
+        """
+        return self._evaluate_actions_wrapper.ddp(
+            observations, rnn_hidden_states, prev_actions, masks, action
+        )
 
 
 class DDPPO(DecentralizedDistributedMixin, PPO):
diff --git a/habitat_baselines/rl/ppo/ppo.py b/habitat_baselines/rl/ppo/ppo.py
index 34917f9a02..9865efdcc6 100644
--- a/habitat_baselines/rl/ppo/ppo.py
+++ b/habitat_baselines/rl/ppo/ppo.py
@@ -88,7 +88,7 @@ def update(self, rollouts: RolloutStorage) -> Tuple[float, float, float]:
                 action_log_probs,
                 dist_entropy,
                 _,
-            ) = self.actor_critic.evaluate_actions(
+            ) = self._evaluate_actions(
                 batch["observations"],
                 batch["recurrent_hidden_states"],
                 batch["prev_actions"],
@@ -152,6 +152,16 @@ def update(self, rollouts: RolloutStorage) -> Tuple[float, float, float]:
 
         return value_loss_epoch, action_loss_epoch, dist_entropy_epoch
 
+    def _evaluate_actions(
+        self, observations, rnn_hidden_states, prev_actions, masks, action
+    ):
+        r"""Internal method that calls Policy.evaluate_actions. This is used instead of
+        calling that directly so that the call can be overridden with inheritance.
+        """
+        return self.actor_critic.evaluate_actions(
+            observations, rnn_hidden_states, prev_actions, masks, action
+        )
+
     def before_backward(self, loss: Tensor) -> None:
         pass
 
diff --git a/test/test_ddppo_reduce.py b/test/test_ddppo_reduce.py
new file mode 100644
index 0000000000..5eae6fff54
--- /dev/null
+++ b/test/test_ddppo_reduce.py
@@ -0,0 +1,126 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import pytest
+
+from habitat.core.spaces import ActionSpace, EmptySpace
+from habitat.tasks.nav.nav import IntegratedPointGoalGPSAndCompassSensor
+
+torch = pytest.importorskip("torch")
+habitat_baselines = pytest.importorskip("habitat_baselines")
+
+import gym
+from torch import distributed as distrib
+from torch import nn
+
+from habitat_baselines.common.rollout_storage import RolloutStorage
+from habitat_baselines.config.default import get_config
+from habitat_baselines.rl.ddppo.algo import DDPPO
+from habitat_baselines.rl.ppo.policy import PointNavBaselinePolicy
+
+
+def _worker_fn(
+    world_rank: int, world_size: int, port: int, unused_params: bool
+):
+    device = (
+        torch.device("cuda")
+        if torch.cuda.is_available()
+        else torch.device("cpu")
+    )
+    tcp_store = distrib.TCPStore(  # type: ignore
+        "127.0.0.1", port, world_size, world_rank == 0
+    )
+    distrib.init_process_group(
+        "gloo", store=tcp_store, rank=world_rank, world_size=world_size
+    )
+
+    config = get_config("habitat_baselines/config/test/ppo_pointnav_test.yaml")
+    obs_space = gym.spaces.Dict(
+        {
+            IntegratedPointGoalGPSAndCompassSensor.cls_uuid: gym.spaces.Box(
+                low=np.finfo(np.float32).min,
+                high=np.finfo(np.float32).max,
+                shape=(2,),
+                dtype=np.float32,
+            )
+        }
+    )
+    action_space = ActionSpace({"move": EmptySpace()})
+    actor_critic = PointNavBaselinePolicy.from_config(
+        config, obs_space, action_space
+    )
+    # This adds some arbitrary parameters that aren't part of the computation
+    # graph, so they will mess up DDP if they aren't correctly ignored by it
+    if unused_params:
+        actor_critic.unused = nn.Linear(64, 64)
+
+    actor_critic.to(device=device)
+    ppo_cfg = config.RL.PPO
+    agent = DDPPO(
+        actor_critic=actor_critic,
+        clip_param=ppo_cfg.clip_param,
+        ppo_epoch=ppo_cfg.ppo_epoch,
+        num_mini_batch=ppo_cfg.num_mini_batch,
+        value_loss_coef=ppo_cfg.value_loss_coef,
+        entropy_coef=ppo_cfg.entropy_coef,
+        lr=ppo_cfg.lr,
+        eps=ppo_cfg.eps,
+        max_grad_norm=ppo_cfg.max_grad_norm,
+        use_normalized_advantage=ppo_cfg.use_normalized_advantage,
+    )
+    agent.init_distributed()
+    rollouts = RolloutStorage(
+        ppo_cfg.num_steps,
+        2,
+        obs_space,
+        action_space,
+        ppo_cfg.hidden_size,
+        num_recurrent_layers=actor_critic.net.num_recurrent_layers,
+        is_double_buffered=False,
+    )
+    rollouts.to(device)
+
+    for k, v in rollouts.buffers["observations"].items():
+        rollouts.buffers["observations"][k] = torch.randn_like(v)
+
+    # Add two steps so batching works
+    rollouts.advance_rollout()
+    rollouts.advance_rollout()
+
+    # Get a single batch
+    batch = next(rollouts.recurrent_generator(rollouts.buffers["returns"], 1))
+
+    # Call eval actions through the internal wrapper that is used in
+    # agent.update
+    value, action_log_probs, dist_entropy, _ = agent._evaluate_actions(
+        batch["observations"],
+        batch["recurrent_hidden_states"],
+        batch["prev_actions"],
+        batch["masks"],
+        batch["actions"],
+    )
+    # Backprop on things
+    (value.mean() + action_log_probs.mean() + dist_entropy.mean()).backward()
+
+    # Make sure all ranks have very similar gradients
+    for param in actor_critic.parameters():
+        if param.grad is not None:
+            grads = [param.grad.detach().clone() for _ in range(world_size)]
+            distrib.all_gather(grads, grads[world_rank])
+
+            for i in range(world_size):
+                assert torch.isclose(grads[i], grads[world_rank]).all()
+
+
+@pytest.mark.parametrize("unused_params", [True, False])
+def test_ddppo_reduce(unused_params: bool):
+    world_size = 2
+    torch.multiprocessing.spawn(
+        _worker_fn,
+        args=(world_size, 8748 + int(unused_params), unused_params),
+        nprocs=world_size,
+    )
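
Note (not part of the patch above): the fix works because DistributedDataParallel only prepares gradient reduction for computation that flows through its `forward`, so `evaluate_actions` is wrapped in a tiny `nn.Module` and invoked via the DDP-wrapped module instead of directly on the policy. The following is a minimal, self-contained sketch of that pattern, with a made-up `TinyPolicy` and a simplified `evaluate_actions` signature used purely for illustration; it is not the habitat-baselines code.

```python
import os

import torch
import torch.distributed as dist
from torch import nn


class TinyPolicy(nn.Module):
    """Stand-in policy with an evaluate_actions method (illustrative only)."""

    def __init__(self):
        super().__init__()
        self.body = nn.Linear(4, 8)
        self.value_head = nn.Linear(8, 1)
        self.unused = nn.Linear(8, 8)  # deliberately never used in the graph

    def evaluate_actions(self, obs):
        return self.value_head(torch.relu(self.body(obs)))


class _EvalActionsWrapper(nn.Module):
    """Expose evaluate_actions as forward so DDP can hook into the call."""

    def __init__(self, policy):
        super().__init__()
        self.policy = policy

    def forward(self, *args, **kwargs):
        return self.policy.evaluate_actions(*args, **kwargs)


if __name__ == "__main__":
    # Single-process gloo group just to make the sketch runnable on CPU.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29512")
    dist.init_process_group("gloo", rank=0, world_size=1)

    policy = TinyPolicy()
    ddp = nn.parallel.DistributedDataParallel(
        _EvalActionsWrapper(policy), find_unused_parameters=True
    )

    # Calling through ddp(...) (i.e. DDP.forward) lets DDP prepare gradient
    # reduction for this backward pass, including marking unused parameters.
    value = ddp(torch.randn(3, 4))
    value.mean().backward()

    dist.destroy_process_group()
```

Calling `policy.evaluate_actions(...)` directly would bypass DDP's forward hooks, so the reducer would never be told which parameters participate in the backward pass; routing the call through the wrapped module is what keeps reduction working when `find_unused_parameters=True`.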