Fix DDP Reduction in PyTorch 1.7 when the model has unused parameters (#586)

* Hacky thing to interface with ddp.forward
* State guard
* Add unit test and comment in readme about it.
1 parent aa24551 · commit 1e057da · 4 changed files with 176 additions and 16 deletions.
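The commit message above mentions routing evaluation through ddp.forward behind a state guard. Below is a rough sketch of that general pattern only; the names _PolicyWrapperForDDP and wrap_policy_distributed are illustrative, not taken from this commit. The idea: DistributedDataParallel prepares its gradient reduction from the outputs of its own forward call, so calls that bypass forward can break reduction in PyTorch 1.7, especially when some parameters never enter the autograd graph. Funneling evaluate_actions through a thin wrapper module's forward, together with find_unused_parameters=True, avoids that.

# Illustrative sketch only; _PolicyWrapperForDDP and wrap_policy_distributed
# are hypothetical names, not from this commit.
import torch
import torch.nn as nn


class _PolicyWrapperForDDP(nn.Module):
    # Routes evaluate_actions through nn.Module.forward so that
    # DistributedDataParallel sees the call and can set up gradient
    # reduction, even when some parameters are never used.
    def __init__(self, actor_critic: nn.Module) -> None:
        super().__init__()
        self.actor_critic = actor_critic

    def forward(self, *args, **kwargs):
        return self.actor_critic.evaluate_actions(*args, **kwargs)


def wrap_policy_distributed(actor_critic: nn.Module, device: torch.device):
    wrapped = _PolicyWrapperForDDP(actor_critic)
    # find_unused_parameters=True asks the reducer to mark parameters that
    # never appear in the autograd graph as ready instead of stalling.
    if device.type == "cuda":
        return nn.parallel.DistributedDataParallel(
            wrapped,
            device_ids=[device],
            output_device=device,
            find_unused_parameters=True,
        )
    return nn.parallel.DistributedDataParallel(
        wrapped, find_unused_parameters=True
    )

Whether the guard lives inside the policy or the agent is an implementation detail of the actual commit; the test below only checks the observable behavior, namely that every rank ends up with matching gradients.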
@@ -0,0 +1,126 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import pytest

from habitat.core.spaces import ActionSpace, EmptySpace
from habitat.tasks.nav.nav import IntegratedPointGoalGPSAndCompassSensor

torch = pytest.importorskip("torch")
habitat_baselines = pytest.importorskip("habitat_baselines")

import gym
from torch import distributed as distrib
from torch import nn

from habitat_baselines.common.rollout_storage import RolloutStorage
from habitat_baselines.config.default import get_config
from habitat_baselines.rl.ddppo.algo import DDPPO
from habitat_baselines.rl.ppo.policy import PointNavBaselinePolicy


def _worker_fn(
    world_rank: int, world_size: int, port: int, unused_params: bool
):
    device = (
        torch.device("cuda")
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    tcp_store = distrib.TCPStore(  # type: ignore
        "127.0.0.1", port, world_size, world_rank == 0
    )
    distrib.init_process_group(
        "gloo", store=tcp_store, rank=world_rank, world_size=world_size
    )

    config = get_config("habitat_baselines/config/test/ppo_pointnav_test.yaml")
    obs_space = gym.spaces.Dict(
        {
            IntegratedPointGoalGPSAndCompassSensor.cls_uuid: gym.spaces.Box(
                low=np.finfo(np.float32).min,
                high=np.finfo(np.float32).max,
                shape=(2,),
                dtype=np.float32,
            )
        }
    )
    action_space = ActionSpace({"move": EmptySpace()})
    actor_critic = PointNavBaselinePolicy.from_config(
        config, obs_space, action_space
    )
    # This adds some arbitrary parameters that aren't part of the computation
    # graph, so they will mess up DDP if they aren't correctly ignored by it
    if unused_params:
        actor_critic.unused = nn.Linear(64, 64)

    actor_critic.to(device=device)
    ppo_cfg = config.RL.PPO
    agent = DDPPO(
        actor_critic=actor_critic,
        clip_param=ppo_cfg.clip_param,
        ppo_epoch=ppo_cfg.ppo_epoch,
        num_mini_batch=ppo_cfg.num_mini_batch,
        value_loss_coef=ppo_cfg.value_loss_coef,
        entropy_coef=ppo_cfg.entropy_coef,
        lr=ppo_cfg.lr,
        eps=ppo_cfg.eps,
        max_grad_norm=ppo_cfg.max_grad_norm,
        use_normalized_advantage=ppo_cfg.use_normalized_advantage,
    )
    agent.init_distributed()
    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        2,
        obs_space,
        action_space,
        ppo_cfg.hidden_size,
        num_recurrent_layers=actor_critic.net.num_recurrent_layers,
        is_double_buffered=False,
    )
    rollouts.to(device)

    for k, v in rollouts.buffers["observations"].items():
        rollouts.buffers["observations"][k] = torch.randn_like(v)

    # Add two steps so batching works
    rollouts.advance_rollout()
    rollouts.advance_rollout()

    # Get a single batch
    batch = next(rollouts.recurrent_generator(rollouts.buffers["returns"], 1))

    # Call evaluate_actions through the internal wrapper that is used in
    # agent.update
    value, action_log_probs, dist_entropy, _ = agent._evaluate_actions(
        batch["observations"],
        batch["recurrent_hidden_states"],
        batch["prev_actions"],
        batch["masks"],
        batch["actions"],
    )
    # Backprop on things
    (value.mean() + action_log_probs.mean() + dist_entropy.mean()).backward()

    # Make sure all ranks have very similar gradients
    for param in actor_critic.parameters():
        if param.grad is not None:
            grads = [param.grad.detach().clone() for _ in range(world_size)]
            distrib.all_gather(grads, grads[world_rank])

            for i in range(world_size):
                assert torch.isclose(grads[i], grads[world_rank]).all()


@pytest.mark.parametrize("unused_params", [True, False])
def test_ddppo_reduce(unused_params: bool):
    world_size = 2
    torch.multiprocessing.spawn(
        _worker_fn,
        args=(world_size, 8748 + int(unused_params), unused_params),
        nprocs=world_size,
    )
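The test spawns two gloo workers (each uses the GPU if one is available, otherwise the CPU), runs a single _evaluate_actions pass plus a backward step, and asserts that the resulting gradients match across ranks. It is parametrized over unused_params, and the TCP store port of 8748 + int(unused_params) gives each parametrization its own port so the two runs cannot collide. Assuming habitat-lab and habitat_baselines are installed, it can be selected on its own with a standard pytest filter such as: pytest -k test_ddppo_reduce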