Fix DDP Reduction in PyTorch 1.7 when the model has unused parameters (#586)

* Hacky thing to interface with ddp.forward

* State guard

* Add unit test and comment in readme about it.
erikwijmans authored Feb 8, 2021
1 parent aa24551 commit 1e057da
Showing 4 changed files with 176 additions and 16 deletions.
habitat_baselines/rl/ddppo/README.md: 8 additions & 0 deletions
@@ -15,6 +15,14 @@ The two recommended backends are GLOO and NCCL. Use NCCL if your system has it,
See [pytorch's distributed docs](https://pytorch.org/docs/stable/distributed.html#backends-that-come-with-pytorch)
and [pytorch's distributed tutorial](https://pytorch.org/tutorials/intermediate/dist_tuto.html) for more information.

### Verifying gradient reduction

Because RL differs from supervised learning, the way DD-PPO interfaces with PyTorch's DistributedDataParallel is slightly off the beaten path, and while it is reasonably robust, new versions of PyTorch have broken it in the past. Our CI does not test against every version of PyTorch, so if there is ever concern that gradient reduction may not be working, run the unit test locally:

```
pytest test/test_ddppo_reduce.py
```

## Pretrained Models (PointGoal Navigation with GPS+Compass)


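For intuition, the check performed by the unit test added in this commit (see test/test_ddppo_reduce.py below) boils down to this: after a backward pass on a DDP-wrapped model, every rank should hold the same gradient for every parameter, because DDP averages gradients across ranks during reduction. A minimal sketch of that comparison, assuming a process group is already initialized; the function name and arguments here are illustrative, not part of habitat_baselines:

```
import torch
import torch.distributed as distrib


def assert_grads_synced(model: torch.nn.Module, world_size: int, rank: int) -> None:
    # After loss.backward() on a DDP-wrapped model, gradients should be
    # (nearly) identical on every rank if reduction worked.
    for param in model.parameters():
        if param.grad is None:
            continue
        # Gather this parameter's gradient from every rank.
        grads = [torch.empty_like(param.grad) for _ in range(world_size)]
        distrib.all_gather(grads, param.grad.detach())
        for other in grads:
            assert torch.allclose(grads[rank], other, atol=1e-6)
```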
habitat_baselines/rl/ddppo/algo/ddppo.py: 31 additions & 15 deletions
@@ -7,7 +7,6 @@
from typing import Tuple

import torch
-from torch import Tensor
from torch import distributed as distrib

from habitat_baselines.common.rollout_storage import RolloutStorage
@@ -44,6 +43,19 @@ def distributed_mean_and_var(
return mean, var


class _EvalActionsWrapper(torch.nn.Module):
r"""Wrapper on evaluate_actions that allows that to be called from forward.
This is needed to interface with DistributedDataParallel's forward call
"""

def __init__(self, actor_critic):
super().__init__()
self.actor_critic = actor_critic

def forward(self, *args, **kwargs):
return self.actor_critic.evaluate_actions(*args, **kwargs)


class DecentralizedDistributedMixin:
def _get_advantages_distributed(
self, rollouts: RolloutStorage
@@ -79,24 +91,28 @@ class Guard:
            def __init__(self, model, device):
                if torch.cuda.is_available():
                    self.ddp = torch.nn.parallel.DistributedDataParallel(
-                       model, device_ids=[device], output_device=device
+                       model,
+                       device_ids=[device],
+                       output_device=device,
+                       find_unused_parameters=find_unused_params,
                    )
                else:
-                   self.ddp = torch.nn.parallel.DistributedDataParallel(model)
+                   self.ddp = torch.nn.parallel.DistributedDataParallel(
+                       model,
+                       find_unused_parameters=find_unused_params,
+                   )

-       self._ddp_hooks = Guard(self.actor_critic, self.device)  # type: ignore
+       self._evaluate_actions_wrapper = Guard(_EvalActionsWrapper(self.actor_critic), self.device)  # type: ignore
        self.get_advantages = self._get_advantages_distributed

-       self.reducer = self._ddp_hooks.ddp.reducer
-       self.find_unused_params = find_unused_params
-
-   def before_backward(self, loss: Tensor) -> None:
-       super().before_backward(loss)  # type: ignore
-
-       if self.find_unused_params:
-           self.reducer.prepare_for_backward([loss])  # type: ignore
-       else:
-           self.reducer.prepare_for_backward([])  # type: ignore
+   def _evaluate_actions(
+       self, observations, rnn_hidden_states, prev_actions, masks, action
+   ):
+       r"""Internal method that calls Policy.evaluate_actions. This is used instead of
+       calling that directly so that the call can be overridden via inheritance.
+       """
+       return self._evaluate_actions_wrapper.ddp(
+           observations, rnn_hidden_states, prev_actions, masks, action
+       )


class DDPPO(DecentralizedDistributedMixin, PPO):
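Why the indirection: DistributedDataParallel only sets up gradient reduction (including unused-parameter detection when find_unused_parameters=True) inside its own forward call, so invoking evaluate_actions directly on the wrapped policy bypasses those hooks, and the old workaround of manually calling reducer.prepare_for_backward stopped working correctly with unused parameters in PyTorch 1.7. The commit instead exposes evaluate_actions as the forward of a small wrapper module and always calls through the DDP object. A generic sketch of the same pattern, with placeholder names and assuming a process group is already initialized:

```
import torch
from torch import nn


class MethodAsForward(nn.Module):
    # Illustrative: expose an arbitrary method of a wrapped module as forward()
    # so that DistributedDataParallel's reduction machinery runs when it is called.
    def __init__(self, module: nn.Module, method_name: str):
        super().__init__()
        self.module = module
        self._method_name = method_name

    def forward(self, *args, **kwargs):
        return getattr(self.module, self._method_name)(*args, **kwargs)


# Usage sketch (placeholder policy and inputs):
#   ddp = nn.parallel.DistributedDataParallel(
#       MethodAsForward(policy, "evaluate_actions"),
#       find_unused_parameters=True,  # handle parameters absent from the graph
#   )
#   value, log_probs, entropy, _ = ddp(obs, hidden, prev_actions, masks, actions)
```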
habitat_baselines/rl/ppo/ppo.py: 11 additions & 1 deletion
@@ -88,7 +88,7 @@ def update(self, rollouts: RolloutStorage) -> Tuple[float, float, float]:
action_log_probs,
dist_entropy,
_,
-        ) = self.actor_critic.evaluate_actions(
+        ) = self._evaluate_actions(
batch["observations"],
batch["recurrent_hidden_states"],
batch["prev_actions"],
@@ -152,6 +152,16 @@ def update(self, rollouts: RolloutStorage) -> Tuple[float, float, float]:

return value_loss_epoch, action_loss_epoch, dist_entropy_epoch

def _evaluate_actions(
self, observations, rnn_hidden_states, prev_actions, masks, action
):
r"""Internal method that calls Policy.evaluate_actions. This is used instead of calling
that directly so that that call can be overrided with inheritence
"""
return self.actor_critic.evaluate_actions(
observations, rnn_hidden_states, prev_actions, masks, action
)

def before_backward(self, loss: Tensor) -> None:
pass

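The ppo.py change adds the seam that the mixin above hooks into: PPO.update now calls self._evaluate_actions, whose base implementation simply forwards to the policy, while DecentralizedDistributedMixin overrides it to route through the DDP wrapper. Because DDPPO lists the mixin before PPO, Python's method resolution order picks the distributed version. A stripped-down sketch of that override pattern; the class names and return values here are illustrative only:

```
class BasePPO:
    def _evaluate_actions(self, *args):
        # Single-process path: call the policy directly.
        return "direct"

    def update(self):
        # The update loop always goes through the overridable hook.
        return self._evaluate_actions()


class DistributedMixin:
    def _evaluate_actions(self, *args):
        # Distributed path: route the call through the DDP-wrapped module.
        return "via DDP forward"


class DistributedPPO(DistributedMixin, BasePPO):
    pass


assert BasePPO().update() == "direct"
assert DistributedPPO().update() == "via DDP forward"
```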
test/test_ddppo_reduce.py: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import pytest

from habitat.core.spaces import ActionSpace, EmptySpace
from habitat.tasks.nav.nav import IntegratedPointGoalGPSAndCompassSensor

torch = pytest.importorskip("torch")
habitat_baselines = pytest.importorskip("habitat_baselines")

import gym
from torch import distributed as distrib
from torch import nn

from habitat_baselines.common.rollout_storage import RolloutStorage
from habitat_baselines.config.default import get_config
from habitat_baselines.rl.ddppo.algo import DDPPO
from habitat_baselines.rl.ppo.policy import PointNavBaselinePolicy


def _worker_fn(
world_rank: int, world_size: int, port: int, unused_params: bool
):
device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("cpu")
)
tcp_store = distrib.TCPStore( # type: ignore
"127.0.0.1", port, world_size, world_rank == 0
)
distrib.init_process_group(
"gloo", store=tcp_store, rank=world_rank, world_size=world_size
)

config = get_config("habitat_baselines/config/test/ppo_pointnav_test.yaml")
obs_space = gym.spaces.Dict(
{
IntegratedPointGoalGPSAndCompassSensor.cls_uuid: gym.spaces.Box(
low=np.finfo(np.float32).min,
high=np.finfo(np.float32).max,
shape=(2,),
dtype=np.float32,
)
}
)
action_space = ActionSpace({"move": EmptySpace()})
actor_critic = PointNavBaselinePolicy.from_config(
config, obs_space, action_space
)
    # This adds some arbitrary parameters that aren't part of the computation
    # graph, so they will break gradient reduction if DDP doesn't correctly ignore them
if unused_params:
actor_critic.unused = nn.Linear(64, 64)

actor_critic.to(device=device)
ppo_cfg = config.RL.PPO
agent = DDPPO(
actor_critic=actor_critic,
clip_param=ppo_cfg.clip_param,
ppo_epoch=ppo_cfg.ppo_epoch,
num_mini_batch=ppo_cfg.num_mini_batch,
value_loss_coef=ppo_cfg.value_loss_coef,
entropy_coef=ppo_cfg.entropy_coef,
lr=ppo_cfg.lr,
eps=ppo_cfg.eps,
max_grad_norm=ppo_cfg.max_grad_norm,
use_normalized_advantage=ppo_cfg.use_normalized_advantage,
)
agent.init_distributed()
rollouts = RolloutStorage(
ppo_cfg.num_steps,
2,
obs_space,
action_space,
ppo_cfg.hidden_size,
num_recurrent_layers=actor_critic.net.num_recurrent_layers,
is_double_buffered=False,
)
rollouts.to(device)

for k, v in rollouts.buffers["observations"].items():
rollouts.buffers["observations"][k] = torch.randn_like(v)

# Add two steps so batching works
rollouts.advance_rollout()
rollouts.advance_rollout()

# Get a single batch
batch = next(rollouts.recurrent_generator(rollouts.buffers["returns"], 1))

# Call eval actions through the internal wrapper that is used in
# agent.update
value, action_log_probs, dist_entropy, _ = agent._evaluate_actions(
batch["observations"],
batch["recurrent_hidden_states"],
batch["prev_actions"],
batch["masks"],
batch["actions"],
)
# Backprop on things
(value.mean() + action_log_probs.mean() + dist_entropy.mean()).backward()

    # Make sure all ranks computed (nearly) identical gradients
for param in actor_critic.parameters():
if param.grad is not None:
grads = [param.grad.detach().clone() for _ in range(world_size)]
distrib.all_gather(grads, grads[world_rank])

for i in range(world_size):
assert torch.isclose(grads[i], grads[world_rank]).all()


@pytest.mark.parametrize("unused_params", [True, False])
def test_ddppo_reduce(unused_params: bool):
world_size = 2
torch.multiprocessing.spawn(
_worker_fn,
args=(world_size, 8748 + int(unused_params), unused_params),
nprocs=world_size,
)
