This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Fully Sharded Data Parallel #3740

Merged: 33 commits, Jul 1, 2021
Commits
a5b4da4
Implement zero2 and zero3
stephenroller Jun 22, 2021
ea9390c
Implement overflow syncing.
stephenroller Jun 22, 2021
378eacc
Tweak log statements.
stephenroller Jun 22, 2021
0ea7d3a
Use free ports rather than random ports
stephenroller Jun 22, 2021
fc3e668
Refactor test_distributed
stephenroller Jun 22, 2021
65ad526
More refactor.
stephenroller Jun 22, 2021
3153dd8
Fixup checkpoints.
stephenroller Jun 22, 2021
82f8b01
Merge branch 'freeport' into fsdp
stephenroller Jun 22, 2021
44fcdfc
Get tests working.
stephenroller Jun 22, 2021
281efd1
GPU only
stephenroller Jun 22, 2021
4146d86
Sigh
stephenroller Jun 22, 2021
5c6755a
Moar.
stephenroller Jun 22, 2021
dc5edc3
Trying to sync grad norms
stephenroller Jun 23, 2021
7e12292
Correctly implement gnorm syncing.
stephenroller Jun 23, 2021
66d53d3
Update comment.
stephenroller Jun 23, 2021
0bb9995
Merge branch 'master' into fsdp
stephenroller Jun 24, 2021
1cb30d1
Try zero3.
stephenroller Jun 24, 2021
5cea3b2
Okay got zero3 working.
stephenroller Jun 24, 2021
490f5d8
Refactor.
stephenroller Jun 25, 2021
31dfeb5
Get FSDP Zero3 working, except during validation.
stephenroller Jun 25, 2021
ded3708
Merge branch 'master' into fsdp
stephenroller Jun 28, 2021
d095f51
Check in missing code. Carve out notimplemented.
stephenroller Jun 28, 2021
f17abb2
Lint.
stephenroller Jun 28, 2021
231e88d
Er.
stephenroller Jun 28, 2021
4a3ce86
Add a test to ensure we keep track of zero3 not working.
stephenroller Jun 28, 2021
98a90b7
Remove debugs, add docstrings, rename variable.
stephenroller Jun 28, 2021
a2f84c1
Silly
stephenroller Jun 28, 2021
e45b149
Merge branch 'master' into fsdp
stephenroller Jun 29, 2021
61b64dc
Reviewer comments.
stephenroller Jul 1, 2021
16374c9
Lint.
stephenroller Jul 1, 2021
074be0a
We disabled zero3 as an option, so don't need the test.
stephenroller Jul 1, 2021
0814c99
Bug caught by Kurt.
stephenroller Jul 1, 2021
c5a82aa
Rofl
stephenroller Jul 1, 2021
18 changes: 9 additions & 9 deletions parlai/agents/transformer/modules/decoder.py
@@ -25,6 +25,7 @@
from parlai.core.opt import Opt
from parlai.utils.misc import warn_once
from parlai.utils.torch import PipelineHelper
from parlai.utils.fsdp import fsdp_wrap


@swappable(
@@ -277,16 +278,15 @@ def _default(val, default):
def build_layers(self) -> nn.ModuleList:
layers = nn.ModuleList()
for _ in range(self.n_layers):
layers.append(
self.swappables.layer(
self.opt,
attention_dropout=self.opt.get('attention_dropout', 0.0),
relu_dropout=self.opt.get('relu_dropout', 0.0),
dropout=self.opt.get('dropout', 0.0),
activation=self.activation,
variant=self.variant,
) # type: ignore
layer = self.swappables.layer(
self.opt,
attention_dropout=self.opt.get('attention_dropout', 0.0),
relu_dropout=self.opt.get('relu_dropout', 0.0),
dropout=self.opt.get('dropout', 0.0),
activation=self.activation,
variant=self.variant,
)
layers.append(fsdp_wrap(layer)) # type: ignore
return layers

def forward_embedding(
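Both this file and the encoder below wrap each transformer layer individually, so FSDP can shard parameters layer by layer rather than holding one flat copy of the whole model. The helpers live in parlai/utils/fsdp.py, which is not part of this diff; the following is a minimal sketch of how fsdp_wrap and maybe_fsdp_wrap could be built on FairScale's wrap()/enable_wrap() utilities, with the import paths and keyword defaults shown here being assumptions.

```python
# Hedged sketch of parlai/utils/fsdp.py wrappers (not shown in this diff).
# Assumes FairScale's FullyShardedDataParallel plus its wrap()/enable_wrap() helpers.
from contextlib import contextmanager

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import enable_wrap, wrap


def fsdp_wrap(module):
    # wrap() is a no-op outside an enable_wrap() context, so build_layers()
    # can call this unconditionally, even when the backend is plain DDP.
    return wrap(module)


@contextmanager
def maybe_fsdp_wrap(opt):
    if opt.get('ddp_backend', 'ddp') not in ('zero2', 'zero3'):
        # plain DDP: yield without enabling wrapping; fsdp_wrap() becomes a no-op
        yield
        return
    with enable_wrap(
        wrapper_cls=FSDP,
        # zero2 keeps full params gathered after forward; zero3 would reshard
        reshard_after_forward=opt['ddp_backend'] == 'zero3',
        # FSDP only manages the fp16 cast in safe fp16 (see the discussion in
        # torch_generator_agent.py below)
        mixed_precision=opt.get('fp16', False) and opt.get('fp16_impl', 'safe') == 'safe',
    ):
        yield
```

The encoder diff that follows applies the same per-layer wrapping pattern.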
18 changes: 9 additions & 9 deletions parlai/agents/transformer/modules/encoder.py
@@ -25,6 +25,7 @@
from parlai.core.opt import Opt
from parlai.utils.misc import warn_once
from parlai.utils.torch import PipelineHelper
from parlai.utils.fsdp import fsdp_wrap


@swappable(self_attention=MultiHeadAttention, feedforward=TransformerFFN)
@@ -227,16 +228,15 @@ def _default(val, default):
def build_layers(self) -> nn.ModuleList:
layers = nn.ModuleList()
for _ in range(self.n_layers):
layers.append(
self.swappables.layer( # type: ignore
self.opt,
attention_dropout=self.opt.get('attention_dropout', 0.0),
relu_dropout=self.opt.get('relu_dropout', 0.0),
dropout=self.dropout_frac,
variant=self.variant,
activation=self.activation,
)
layer = self.swappables.layer( # type: ignore
self.opt,
attention_dropout=self.opt.get('attention_dropout', 0.0),
relu_dropout=self.opt.get('relu_dropout', 0.0),
dropout=self.dropout_frac,
variant=self.variant,
activation=self.activation,
)
layers.append(fsdp_wrap(layer))
return layers

def forward_embedding(
10 changes: 10 additions & 0 deletions parlai/core/params.py
@@ -772,6 +772,16 @@ def add_distributed_training_args(self):
grp.add_argument(
'--distributed-world-size', type=int, help='Number of workers.'
)
grp.add_argument(
'--ddp-backend',
# TODO: add in zero3. https://github.com/facebookresearch/ParlAI/issues/3753
choices=['ddp', 'zero2'],
default='ddp',
help=(
'Distributed backend. Zero2 can be faster but is more experimental. '
'DDP is the most tested.'
),
)
return grp

def add_model_args(self):
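The new flag is consumed by small helpers that other files in this PR import from parlai/utils/fsdp.py (DEFAULT_DDP_BACKEND, should_use_fsdp, should_sync_gradnorm, is_fsdp). Their implementations are not shown in this diff; a plausible sketch follows, with the exact conditions treated as assumptions.

```python
# Hedged sketch of the option-keyed helpers imported elsewhere in this PR.
# Only the names appear in the diff; the conditions below are assumptions.
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

from parlai.utils.distributed import is_distributed

DEFAULT_DDP_BACKEND = 'ddp'


def should_use_fsdp(opt) -> bool:
    # zero2 (and eventually zero3, see issue #3753) builds the model under FSDP
    return is_distributed() and opt.get(
        'ddp_backend', DEFAULT_DDP_BACKEND
    ) in ('zero2', 'zero3')


def should_sync_gradnorm(opt) -> bool:
    # With sharded gradients each worker only sees a slice of the full
    # gradient, so the fp16 optimizer wrappers must all-reduce the norm.
    return opt.get('fp16', False) and should_use_fsdp(opt)


def is_fsdp(module) -> bool:
    # Did torch_generator_agent wrap this module in FSDP?
    return isinstance(module, FSDP)
```

For example, passing --ddp-backend zero2 to multiprocessing_train routes model construction through maybe_fsdp_wrap in torch_generator_agent.py below.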
24 changes: 20 additions & 4 deletions parlai/core/torch_agent.py
@@ -36,6 +36,7 @@
from parlai.utils.distributed import is_distributed
from parlai.utils.misc import AttrDict, warn_once
from parlai.utils.io import PathManager
from parlai.utils.fsdp import should_sync_gradnorm, is_fsdp, DEFAULT_DDP_BACKEND
from parlai.utils.fp16 import (
SafeFP16Optimizer,
MemoryEfficientFP16Optimizer,
@@ -1052,7 +1053,9 @@ def init_optim(
self.optimizer = optim_class(params, **kwargs)
if self.fp16:
if self.fp16_impl == 'safe':
self.optimizer = SafeFP16Optimizer(self.optimizer)
self.optimizer = SafeFP16Optimizer(
self.optimizer, should_sync_gradnorm(opt)
)
else:
# Using memory efficient optimizer
opt_name = opt['optimizer']
@@ -1064,7 +1067,9 @@
'with Memory Efficient FP16. Please select from among this '
f'list:\n{compatible_list}'
)
self.optimizer = MemoryEfficientFP16Optimizer(self.optimizer)
self.optimizer = MemoryEfficientFP16Optimizer(
self.optimizer, should_sync_gradnorm(opt)
)

if is_finetune:
logging.warning('Detected a fine-tune run. Resetting the optimizer.')
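The new second argument to the fp16 optimizer wrappers matters because, under zero2, each worker holds only a shard of the gradients: a locally computed norm is too small, and gradient clipping would diverge across workers. A conceptual sketch of the synced norm (illustrative only, not the wrappers' actual code):

```python
# Conceptual sketch of a globally synced gradient norm over sharded gradients.
# Illustrative only; the real logic lives inside the fp16 optimizer wrappers.
import torch
import torch.distributed as dist


def global_grad_norm(params, sync: bool = False) -> float:
    grads = [p.grad for p in params if p.grad is not None]
    if not grads:
        return 0.0
    # sum of squares over the locally held (possibly sharded) gradients
    total_sq = torch.zeros(1, device=grads[0].device)
    for g in grads:
        total_sq += g.float().norm() ** 2
    if sync and dist.is_available() and dist.is_initialized():
        # every worker contributes its shard; the all-reduce yields the true norm
        dist.all_reduce(total_sq, op=dist.ReduceOp.SUM)
    return total_sq.sqrt().item()
```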
@@ -1969,10 +1974,11 @@ def state_dict(self):
"""
states = {}
if hasattr(self, 'model'): # save model params
if hasattr(self.model, 'module'):
# did we wrap in a DistributedDataParallel
if hasattr(self.model, 'module') and not is_fsdp(self.model):
Review comment (Contributor): nit: could make this a helper function too? like should_sync_gradnorm (not necessary of course)

# did we wrap in a DistributedDataParallel or DataParallel
states['model'] = self.model.module.state_dict()
else:
# regular model or FSDP
states['model'] = self.model.state_dict()

if hasattr(self, 'optimizer'):
@@ -1992,6 +1998,16 @@ def state_dict(self):

return states

def save_nonprimary(self, path=None):
"""
Save model parameters, when you are working on the non-primary worker.

For models or optimizers that shard parameters, this ensures we sync.
"""
if self.opt.get('ddp_backend', DEFAULT_DDP_BACKEND) in ('zero2', 'zero3'):
# make sure we call the state dict
self.state_dict()

def save(self, path=None):
"""
Save model parameters to path (or default to model_file arg).
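save_nonprimary exists because FSDP's state_dict() is a collective operation: every rank has to participate in gathering the full, unsharded parameters even though only the primary worker writes them to disk. A sketch of the calling pattern this enables (hypothetical helper, not ParlAI's exact control flow, which runs through TrainLoop.save_model):

```python
# Illustrative pattern: all ranks enter state_dict(), only rank 0 writes.
# Hypothetical helper name; the real flow is TrainLoop.save_model -> agent.save / save_nonprimary.
import torch
import torch.distributed as dist


def save_checkpoint(agent, path: str) -> None:
    if dist.is_initialized() and dist.get_rank() != 0:
        # non-primary workers must still join the collective gather, or the
        # primary worker's state_dict() call would hang waiting for them
        agent.save_nonprimary(path)
        return
    states = agent.state_dict()  # primary gathers the full params...
    torch.save(states, path)     # ...and is the only one that touches disk
```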
20 changes: 13 additions & 7 deletions parlai/core/torch_generator_agent.py
@@ -35,6 +35,7 @@
import parlai.utils.logging as logging
from parlai.core.metrics import SumMetric, AverageMetric, FairseqBleuMetric
from parlai.utils.fp16 import FP16SafeCrossEntropy
import parlai.utils.fsdp as fsdp_utils
from parlai.utils.torch import (
neginf,
total_parameters,
@@ -479,8 +480,10 @@ def __init__(self, opt: Opt, shared=None):
else:
# this is not a shared instance of this class, so do full init
self.criterion = self.build_criterion()
# ensure all distributed copies will always be in sync
self.model = self.build_model()
with fsdp_utils.maybe_fsdp_wrap(opt):
self.model = fsdp_utils.fsdp_wrap(self.build_model())
if self.fp16 and not fsdp_utils.should_use_fsdp(opt):
Review comment (Contributor): remember that bug with the instability stuff? is this not re-introducing it?

Review comment (Contributor): (because we moved the model.half() call?)

Review comment (Contributor, author): Okay, I think this needs to use my utility should_delay_halving. Forgot this.

We haven't really moved the moment of halving. The operations between these two points don't do much, and the original code path should be about the same.

  • We now halve on CPU instead of GPU, and then transfer. That's probably a small speedup in initialization, with maybe some small numerical differences.
  • We apply model parallel after halving. Probably a small speedup at initialization.
  • We synchronize parameters after halving. Again, a small initialization speedup.

The catch is that FSDP expects the model before halving (fp32) if we're doing safe optimization, and after halving (fp16) if we're doing memory-efficient. (Similar to the optimizer wrappers, it looks at the parameter dtypes to decide what dtype the gradients are.)

This is the desired pattern (sketched after this file's diff):

  • If we're in Safe and using DDP, we SHOULD still halve, just as before.
  • If we're in MemEff and using DDP, we SHOULD still halve, just as before.
  • If we're in Safe and Zero2, we should NOT halve here.
  • If we're in MemEff and Zero2, we SHOULD halve here.

self.model = self.model.half()

# load the block_list for beam search
self.beam_block_list = self._load_beam_block_list()
@@ -498,16 +501,15 @@ def __init__(self, opt: Opt, shared=None):
self.model.cuda()
self.criterion.cuda()

sync_parameters(self.model)
if not fsdp_utils.is_fsdp(self.model):
sync_parameters(self.model)

train_params = trainable_parameters(self.model)
total_params = total_parameters(self.model)
logging.info(
f"Total parameters: {total_params:,d} ({train_params:,d} trainable)"
)

if self.fp16:
self.model = self.model.half()

if init_model is not None:
# load model parameters if available
logging.info(f'Loading existing model params from {init_model}')
@@ -530,7 +532,11 @@ def __init__(self, opt: Opt, shared=None):
logging.warning("Optimizer was reset. Also resetting LR scheduler.")
self.build_lr_scheduler(states, hard_reset=is_finetune or was_reset)

if shared is None and is_distributed():
if (
shared is None
and is_distributed()
and opt.get('ddp_backend', fsdp_utils.DEFAULT_DDP_BACKEND) == 'ddp'
):
device_ids = None if self.model_parallel else [self.opt['gpu']]
self.model = torch.nn.parallel.DistributedDataParallel(
self.model, device_ids=device_ids, broadcast_buffers=False
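Following the review thread above, the guard around model.half() is meant to use a helper like should_delay_halving rather than should_use_fsdp: halving is skipped only in the Safe-fp16 plus Zero2 combination, where FSDP is expected to manage the fp16 cast itself. The helper name comes from the discussion; the implementation below is a hedged sketch of those four cases, not code from this diff.

```python
# Hedged sketch of the should_delay_halving helper named in the review thread.
# Encodes the four cases listed above; the exact implementation is an assumption.
def should_delay_halving(opt) -> bool:
    """
    Return True when model.half() should be skipped at build time so FSDP can
    handle the fp16 cast (Safe fp16 + zero2). In every other case (either fp16
    impl with DDP, or MemoryEfficient fp16 with zero2), halve exactly as before.
    """
    return (
        opt.get('fp16', False)
        and opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3')
        and opt.get('fp16_impl', 'safe') == 'safe'
    )
```

If that helper is adopted, the guard above would presumably become `if self.fp16 and not fsdp_utils.should_delay_halving(opt):` before the half() call.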
3 changes: 1 addition & 2 deletions parlai/scripts/multiprocessing_eval.py
@@ -23,7 +23,6 @@
"""

import torch
import random
import os
import signal
import parlai.utils.distributed as distributed_utils
@@ -88,7 +87,7 @@ def setup_args(cls):
return setup_args()

def run(self):
port = random.randint(32000, 48000)
port = distributed_utils.find_free_port()
return launch_and_eval(self.opt, port)


7 changes: 4 additions & 3 deletions parlai/scripts/multiprocessing_train.py
@@ -24,7 +24,6 @@
"""

import torch
import random
import os
import signal
import traceback
@@ -55,10 +54,12 @@ def multiprocess_train(
raise


def launch_and_train(opt, port):
def launch_and_train(opt, port=None):
Review comment (Contributor): will we ever specify a port here?

"""
Perform a fork() to many processes.
"""
if port is None:
port = distributed_utils.find_free_port()
# Launch multiple subprocesses
spawncontext = torch.multiprocessing.start_processes(
multiprocess_train,
@@ -99,7 +100,7 @@ def setup_args(cls):

def run(self):
if self.opt['port'] is None:
port = random.randint(32000, 48000)
port = None
else:
port = self.opt['port']
return launch_and_train(self.opt, port)
37 changes: 17 additions & 20 deletions parlai/scripts/train_model.py
@@ -442,17 +442,20 @@ def save_model(self, suffix=None):
"""
Save the model to disk, possibly with a suffix.
"""
if not is_primary_worker():
# never do IO as a non-primary worker
return

if not self.opt.get('model_file'):
# nothing to save to, just exit
return

fn = self.opt['model_file']
if suffix:
fn += suffix

if not is_primary_worker():
# never do IO as a non-primary worker
if hasattr(self.agent, 'save_nonprimary'):
self.agent.save_nonprimary(fn)
return

while True:
# don't ever let a ctrl-c interrupt saving
try:
@@ -543,7 +546,7 @@ def validate(self):
)
self.best_valid = new_valid
self.impatience = 0
if opt.get('model_file') and is_primary_worker():
if opt.get('model_file'):
Review comment (Contributor): just making sure I understand - we can get rid of this check because it's handled in save_model, right?

Review comment (Contributor, author): We need to be able to do save_on_nonprimary_worker actually.

logging.info(f"saving best valid model: {opt['model_file']}")
self.save_model()
self.saved = True
@@ -566,11 +569,7 @@ def validate(self):
self.validate_time.reset()

# saving
if (
opt.get('model_file')
and opt.get('save_after_valid')
and is_primary_worker()
):
if opt.get('model_file') and opt.get('save_after_valid'):
logging.info(f"saving model checkpoint: {opt['model_file']}.checkpoint")
self.save_model('.checkpoint')

@@ -720,24 +719,26 @@ def _get_time(self, world: World) -> Tuple[float, float, float]:
self._total_epochs = self._preempted_epochs + sum(
all_gather_list(world.get_total_epochs())
)
train_time, log_time, validate_time = sync_object(
train_time, log_time, validate_time, save_time = sync_object(
(
self.train_time.time(),
self.log_time.time(),
self.validate_time.time(),
self.save_time.time(),
)
)
else:
train_time, log_time, validate_time = (
train_time, log_time, validate_time, save_time = (
self.train_time.time(),
self.log_time.time(),
self.validate_time.time(),
self.save_time.time(),
)
self._total_epochs = self._preempted_epochs + (
num_workers() * world.get_total_epochs()
)

return train_time, log_time, validate_time
return train_time, log_time, validate_time, save_time

def log(self):
"""
@@ -810,7 +811,7 @@ def train_steps(self):
self._last_log_steps += 1 / self.update_freq

# the following additionally updates self._total_epochs
train_time, log_time, validate_time = self._get_time(world)
train_time, log_time, validate_time, save_time = self._get_time(world)
# get the total training examples done, compute epochs
exs_per_epoch = world.num_examples()
self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))
@@ -859,11 +860,7 @@ def train_steps(self):
break
# make sure metrics are clean before we log
world.reset_metrics()
if (
self.save_time.time() > self.save_every_n_secs
and opt.get('model_file')
and is_primary_worker()
):
if save_time > self.save_every_n_secs and opt.get('model_file'):
logging.info(
f"saving model checkpoint: {opt['model_file']}.checkpoint"
)
@@ -872,7 +869,7 @@ def train_steps(self):
self.save_model('.checkpoint')
self.save_time.reset()

if not self.saved and is_primary_worker():
if not sync_object(self.saved):
# save agent
self.save_model()

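Reading save_time through sync_object, and checking sync_object(self.saved), keeps every worker on the same branch: once saving can involve all ranks (the FSDP state_dict gather), a worker that skips a save while another enters it would deadlock the collective. A conceptual illustration of that kind of agreement, not ParlAI's sync_object implementation:

```python
# Conceptual illustration: all ranks must agree on whether to checkpoint,
# otherwise the collective state_dict() gather hangs. Not ParlAI's sync_object.
import torch.distributed as dist


def everyone_agrees_to_save(local_elapsed: float, every_n_secs: float) -> bool:
    decision = [local_elapsed > every_n_secs]
    if dist.is_available() and dist.is_initialized():
        # use rank 0's clock as the source of truth so all ranks branch identically
        dist.broadcast_object_list(decision, src=0)
    return decision[0]
```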
25 changes: 25 additions & 0 deletions parlai/utils/distributed.py
@@ -296,6 +296,19 @@ def distributed_context(
dist.destroy_process_group()


def get_dist_group():
"""
Find the default pytorch distributed group.

Used within FSDP to mark which workers are participating. Important to manually call
this because FSDP will cache old groups, but our test suite will instantiate new
groups per test.
"""
from torch.distributed.distributed_c10d import _get_default_group

return _get_default_group()


@contextlib.contextmanager
def slurm_distributed_context(opt):
"""
@@ -346,3 +359,15 @@ def slurm_distributed_context(opt):
except FileNotFoundError:
# Slurm is not installed
raise RuntimeError('SLURM does not appear to be installed.')


def find_free_port() -> int:
"""
Find a free port we can bind to locally.

Credit: https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
"""
with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(('', 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
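get_dist_group pairs with the FSDP wrapping above: the wrapper should be constructed against the process group of the current run rather than a group FSDP cached from an earlier test. A hedged usage sketch (FairScale's process_group argument; not code from this diff):

```python
# Hedged usage sketch: hand the freshly resolved default group to FSDP so a
# test suite that creates a new process group per test never reuses a stale one.
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

from parlai.utils.distributed import get_dist_group


def wrap_with_current_group(module):
    return FSDP(module, process_group=get_dist_group())
```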