diff --git a/parlai/agents/transformer/modules/decoder.py b/parlai/agents/transformer/modules/decoder.py
index 7ba4968195a..52ee3a80cc1 100644
--- a/parlai/agents/transformer/modules/decoder.py
+++ b/parlai/agents/transformer/modules/decoder.py
@@ -25,6 +25,7 @@
 from parlai.core.opt import Opt
 from parlai.utils.misc import warn_once
 from parlai.utils.torch import PipelineHelper
+from parlai.utils.fsdp import fsdp_wrap


 @swappable(
@@ -277,16 +278,15 @@ def _default(val, default):
     def build_layers(self) -> nn.ModuleList:
         layers = nn.ModuleList()
         for _ in range(self.n_layers):
-            layers.append(
-                self.swappables.layer(
-                    self.opt,
-                    attention_dropout=self.opt.get('attention_dropout', 0.0),
-                    relu_dropout=self.opt.get('relu_dropout', 0.0),
-                    dropout=self.opt.get('dropout', 0.0),
-                    activation=self.activation,
-                    variant=self.variant,
-                )  # type: ignore
+            layer = self.swappables.layer(
+                self.opt,
+                attention_dropout=self.opt.get('attention_dropout', 0.0),
+                relu_dropout=self.opt.get('relu_dropout', 0.0),
+                dropout=self.opt.get('dropout', 0.0),
+                activation=self.activation,
+                variant=self.variant,
             )
+            layers.append(fsdp_wrap(layer))  # type: ignore
         return layers

     def forward_embedding(
diff --git a/parlai/agents/transformer/modules/encoder.py b/parlai/agents/transformer/modules/encoder.py
index b79981fd9eb..441d13112f9 100644
--- a/parlai/agents/transformer/modules/encoder.py
+++ b/parlai/agents/transformer/modules/encoder.py
@@ -25,6 +25,7 @@
 from parlai.core.opt import Opt
 from parlai.utils.misc import warn_once
 from parlai.utils.torch import PipelineHelper
+from parlai.utils.fsdp import fsdp_wrap


 @swappable(self_attention=MultiHeadAttention, feedforward=TransformerFFN)
@@ -227,16 +228,15 @@ def _default(val, default):
     def build_layers(self) -> nn.ModuleList:
         layers = nn.ModuleList()
         for _ in range(self.n_layers):
-            layers.append(
-                self.swappables.layer(  # type: ignore
-                    self.opt,
-                    attention_dropout=self.opt.get('attention_dropout', 0.0),
-                    relu_dropout=self.opt.get('relu_dropout', 0.0),
-                    dropout=self.dropout_frac,
-                    variant=self.variant,
-                    activation=self.activation,
-                )
+            layer = self.swappables.layer(  # type: ignore
+                self.opt,
+                attention_dropout=self.opt.get('attention_dropout', 0.0),
+                relu_dropout=self.opt.get('relu_dropout', 0.0),
+                dropout=self.dropout_frac,
+                variant=self.variant,
+                activation=self.activation,
             )
+            layers.append(fsdp_wrap(layer))
         return layers

     def forward_embedding(
diff --git a/parlai/core/params.py b/parlai/core/params.py
index 6465df14f6b..281439d185d 100644
--- a/parlai/core/params.py
+++ b/parlai/core/params.py
@@ -772,6 +772,16 @@ def add_distributed_training_args(self):
         grp.add_argument(
             '--distributed-world-size', type=int, help='Number of workers.'
         )
+        grp.add_argument(
+            '--ddp-backend',
+            # TODO: add in zero3. https://github.com/facebookresearch/ParlAI/issues/3753
+            choices=['ddp', 'zero2'],
+            default='ddp',
+            help=(
+                'Distributed backend. Zero2 can be faster but is more experimental. '
+                'DDP is the most tested.'
+            ),
+        )
         return grp

     def add_model_args(self):
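The new flag is exercised through the multiprocessing scripts. A minimal sketch of selecting the zero2 backend programmatically (not part of the patch; it mirrors the test harness at the bottom of this diff, and the task/model/model_file values are only illustrative):

    import parlai.scripts.multiprocessing_train as mp_train

    parser = mp_train.setup_args()
    # parse_kwargs turns keyword overrides into a full opt dict, as the tests below do
    popt = parser.parse_kwargs(
        task='integration_tests:overfit',
        model='transformer/generator',
        model_file='/tmp/zero2_example',
        ddp_backend='zero2',  # the new option; 'ddp' remains the default
        fp16=True,
    )
    # port is now optional and defaults to a free port (see multiprocessing_train.py below)
    valid, test = mp_train.launch_and_train(popt)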
diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py
index 5994b344754..af9c3a93ab3 100644
--- a/parlai/core/torch_agent.py
+++ b/parlai/core/torch_agent.py
@@ -36,6 +36,7 @@
 from parlai.utils.distributed import is_distributed
 from parlai.utils.misc import AttrDict, warn_once
 from parlai.utils.io import PathManager
+from parlai.utils.fsdp import should_sync_gradnorm, is_fsdp, DEFAULT_DDP_BACKEND
 from parlai.utils.fp16 import (
     SafeFP16Optimizer,
     MemoryEfficientFP16Optimizer,
@@ -1052,7 +1053,9 @@ def init_optim(
         self.optimizer = optim_class(params, **kwargs)
         if self.fp16:
             if self.fp16_impl == 'safe':
-                self.optimizer = SafeFP16Optimizer(self.optimizer)
+                self.optimizer = SafeFP16Optimizer(
+                    self.optimizer, should_sync_gradnorm(opt)
+                )
             else:
                 # Using memory efficient optimizer
                 opt_name = opt['optimizer']
@@ -1064,7 +1067,9 @@ def init_optim(
                         'with Memory Efficient FP16. Please select from among this '
                         f'list:\n{compatible_list}'
                     )
-                self.optimizer = MemoryEfficientFP16Optimizer(self.optimizer)
+                self.optimizer = MemoryEfficientFP16Optimizer(
+                    self.optimizer, should_sync_gradnorm(opt)
+                )

         if is_finetune:
             logging.warning('Detected a fine-tune run. Resetting the optimizer.')
@@ -1969,10 +1974,11 @@ def state_dict(self):
         """
         states = {}
         if hasattr(self, 'model'):  # save model params
-            if hasattr(self.model, 'module'):
-                # did we wrap in a DistributedDataParallel
+            if hasattr(self.model, 'module') and not is_fsdp(self.model):
+                # did we wrap in a DistributedDataParallel or DataParallel
                 states['model'] = self.model.module.state_dict()
             else:
+                # regular model or FSDP
                 states['model'] = self.model.state_dict()

         if hasattr(self, 'optimizer'):
@@ -1992,6 +1998,16 @@ def state_dict(self):

         return states

+    def save_nonprimary(self, path=None):
+        """
+        Save model parameters when running on a non-primary worker.
+
+        For models or optimizers that shard parameters, this ensures we sync.
+        """
+        if self.opt.get('ddp_backend', DEFAULT_DDP_BACKEND) in ('zero2', 'zero3'):
+            # make sure we call the state dict
+            self.state_dict()
+
     def save(self, path=None):
         """
         Save model parameters to path (or default to model_file arg).
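With zero2/zero3, every rank has to enter state_dict() so the sharded parameters can be gathered, but only the primary worker writes to disk. A rough sketch of the calling convention this enables (the checkpoint helper and path are hypothetical; the train loop change in parlai/scripts/train_model.py below follows the same pattern):

    from parlai.utils.distributed import is_primary_worker

    def checkpoint(agent, path):
        if is_primary_worker():
            agent.save(path)  # gathers shards (under FSDP) and writes the file
        elif hasattr(agent, 'save_nonprimary'):
            agent.save_nonprimary(path)  # participates in the gather, writes nothing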
diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py
index 88c40ac121e..eda8538e0e9 100644
--- a/parlai/core/torch_generator_agent.py
+++ b/parlai/core/torch_generator_agent.py
@@ -35,6 +35,7 @@
 import parlai.utils.logging as logging
 from parlai.core.metrics import SumMetric, AverageMetric, FairseqBleuMetric
 from parlai.utils.fp16 import FP16SafeCrossEntropy
+import parlai.utils.fsdp as fsdp_utils
 from parlai.utils.torch import (
     neginf,
     total_parameters,
@@ -479,8 +480,10 @@ def __init__(self, opt: Opt, shared=None):
         else:
             # this is not a shared instance of this class, so do full init
             self.criterion = self.build_criterion()
-            # ensure all distributed copies will always be in sync
-            self.model = self.build_model()
+            with fsdp_utils.maybe_fsdp_wrap(opt):
+                self.model = fsdp_utils.fsdp_wrap(self.build_model())
+                if self.fp16 and not fsdp_utils.delay_halving(opt):
+                    self.model = self.model.half()

             # load the block_list for beam search
             self.beam_block_list = self._load_beam_block_list()
@@ -498,16 +501,15 @@ def __init__(self, opt: Opt, shared=None):
                     self.model.cuda()
                 self.criterion.cuda()

-            sync_parameters(self.model)
+            if not fsdp_utils.is_fsdp(self.model):
+                sync_parameters(self.model)
+
             train_params = trainable_parameters(self.model)
             total_params = total_parameters(self.model)
             logging.info(
                 f"Total parameters: {total_params:,d} ({train_params:,d} trainable)"
             )

-            if self.fp16:
-                self.model = self.model.half()
-
             if init_model is not None:
                 # load model parameters if available
                 logging.info(f'Loading existing model params from {init_model}')
@@ -530,7 +532,11 @@ def __init__(self, opt: Opt, shared=None):
                     logging.warning("Optimizer was reset. Also resetting LR scheduler.")
                 self.build_lr_scheduler(states, hard_reset=is_finetune or was_reset)

-        if shared is None and is_distributed():
+        if (
+            shared is None
+            and is_distributed()
+            and opt.get('ddp_backend', fsdp_utils.DEFAULT_DDP_BACKEND) == 'ddp'
+        ):
             device_ids = None if self.model_parallel else [self.opt['gpu']]
             self.model = torch.nn.parallel.DistributedDataParallel(
                 self.model, device_ids=device_ids, broadcast_buffers=False
diff --git a/parlai/scripts/multiprocessing_eval.py b/parlai/scripts/multiprocessing_eval.py
index 2b220fc1a24..bfc7bdc34b7 100644
--- a/parlai/scripts/multiprocessing_eval.py
+++ b/parlai/scripts/multiprocessing_eval.py
@@ -23,7 +23,6 @@
 """

 import torch
-import random
 import os
 import signal
 import parlai.utils.distributed as distributed_utils
@@ -88,7 +87,7 @@ def setup_args(cls):
         return setup_args()

     def run(self):
-        port = random.randint(32000, 48000)
+        port = distributed_utils.find_free_port()
         return launch_and_eval(self.opt, port)


diff --git a/parlai/scripts/multiprocessing_train.py b/parlai/scripts/multiprocessing_train.py
index 14c2e305846..543d316b01a 100644
--- a/parlai/scripts/multiprocessing_train.py
+++ b/parlai/scripts/multiprocessing_train.py
@@ -24,7 +24,6 @@
 """

 import torch
-import random
 import os
 import signal
 import traceback
@@ -55,10 +54,12 @@ def multiprocess_train(
         raise


-def launch_and_train(opt, port):
+def launch_and_train(opt, port=None):
     """
     Perform a fork() to many processes.
""" + if port is None: + port = distributed_utils.find_free_port() # Launch multiple subprocesses spawncontext = torch.multiprocessing.start_processes( multiprocess_train, @@ -99,7 +100,7 @@ def setup_args(cls): def run(self): if self.opt['port'] is None: - port = random.randint(32000, 48000) + port = None else: port = self.opt['port'] return launch_and_train(self.opt, port) diff --git a/parlai/scripts/train_model.py b/parlai/scripts/train_model.py index f6e9c2d7321..64a3d8c92b3 100644 --- a/parlai/scripts/train_model.py +++ b/parlai/scripts/train_model.py @@ -442,10 +442,6 @@ def save_model(self, suffix=None): """ Save the model to disk, possibly with a suffix. """ - if not is_primary_worker(): - # never do IO as a non-primary worker - return - if not self.opt.get('model_file'): # nothing to save to, just exit return @@ -453,6 +449,13 @@ def save_model(self, suffix=None): fn = self.opt['model_file'] if suffix: fn += suffix + + if not is_primary_worker(): + # never do IO as a non-primary worker + if hasattr(self.agent, 'save_nonprimary'): + self.agent.save_nonprimary(fn) + return + while True: # don't ever let a ctrl-c interrupt saving try: @@ -543,7 +546,7 @@ def validate(self): ) self.best_valid = new_valid self.impatience = 0 - if opt.get('model_file') and is_primary_worker(): + if opt.get('model_file'): logging.info(f"saving best valid model: {opt['model_file']}") self.save_model() self.saved = True @@ -566,11 +569,7 @@ def validate(self): self.validate_time.reset() # saving - if ( - opt.get('model_file') - and opt.get('save_after_valid') - and is_primary_worker() - ): + if opt.get('model_file') and opt.get('save_after_valid'): logging.info(f"saving model checkpoint: {opt['model_file']}.checkpoint") self.save_model('.checkpoint') @@ -720,24 +719,26 @@ def _get_time(self, world: World) -> Tuple[float, float, float]: self._total_epochs = self._preempted_epochs + sum( all_gather_list(world.get_total_epochs()) ) - train_time, log_time, validate_time = sync_object( + train_time, log_time, validate_time, save_time = sync_object( ( self.train_time.time(), self.log_time.time(), self.validate_time.time(), + self.save_time.time(), ) ) else: - train_time, log_time, validate_time = ( + train_time, log_time, validate_time, save_time = ( self.train_time.time(), self.log_time.time(), self.validate_time.time(), + self.save_time.time(), ) self._total_epochs = self._preempted_epochs + ( num_workers() * world.get_total_epochs() ) - return train_time, log_time, validate_time + return train_time, log_time, validate_time, save_time def log(self): """ @@ -810,7 +811,7 @@ def train_steps(self): self._last_log_steps += 1 / self.update_freq # the following additionally updates self._total_epochs - train_time, log_time, validate_time = self._get_time(world) + train_time, log_time, validate_time, save_time = self._get_time(world) # get the total training examples done, compute epochs exs_per_epoch = world.num_examples() self._total_exs = int(np.round(self._total_epochs * exs_per_epoch)) @@ -859,11 +860,7 @@ def train_steps(self): break # make sure metrics are clean before we log world.reset_metrics() - if ( - self.save_time.time() > self.save_every_n_secs - and opt.get('model_file') - and is_primary_worker() - ): + if save_time > self.save_every_n_secs and opt.get('model_file'): logging.info( f"saving model checkpoint: {opt['model_file']}.checkpoint" ) @@ -872,7 +869,7 @@ def train_steps(self): self.save_model('.checkpoint') self.save_time.reset() - if not self.saved and is_primary_worker(): + if not 
+        if not sync_object(self.saved):
             # save agent
             self.save_model()

diff --git a/parlai/utils/distributed.py b/parlai/utils/distributed.py
index 27a9a240b55..2088b3451dc 100644
--- a/parlai/utils/distributed.py
+++ b/parlai/utils/distributed.py
@@ -296,6 +296,19 @@ def distributed_context(
         dist.destroy_process_group()


+def get_dist_group():
+    """
+    Find the default pytorch distributed group.
+
+    Used within FSDP to mark which workers are participating. It is important to call
+    this manually, because FSDP caches old groups while our test suite instantiates
+    new groups per test.
+    """
+    from torch.distributed.distributed_c10d import _get_default_group
+
+    return _get_default_group()
+
+
 @contextlib.contextmanager
 def slurm_distributed_context(opt):
     """
@@ -346,3 +359,15 @@ def slurm_distributed_context(opt):
     except FileNotFoundError:
         # Slurm is not installed
         raise RuntimeError('SLURM does not appear to be installed.')
+
+
+def find_free_port() -> int:
+    """
+    Find a free port we can bind to locally.
+
+    Credit: https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
+    """
+    with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+        s.bind(('', 0))
+        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        return s.getsockname()[1]
diff --git a/parlai/utils/fp16.py b/parlai/utils/fp16.py
index 073f7757fb4..00d03d36744 100644
--- a/parlai/utils/fp16.py
+++ b/parlai/utils/fp16.py
@@ -55,27 +55,48 @@ def forward(self, scores, targets):
         )


-def clip_grad_norm(params, max_norm):
+def clip_grad_norm(params, max_norm: float = 0, sync: bool = False):
     """
-    Clips grad norm.
+    Clips grad norms.
+
+    When combined with FSDP, this will also ensure that grad norms are aggregated
+    across all workers, since each worker only stores its shard of the
+    gradients.
+
+    :param params:
+        Parameters whose gradients we wish to clip
+    :param max_norm:
+        Maximum norm we wish the gradients to have. If non-positive, then
+        we will not perform clipping.
+    :param sync:
+        Boolean indicating whether we should aggregate across the distributed
+        group. Used only in combination with FSDP.
+
+    :returns:
+        The gradient norm across all parameters, before clipping.
""" if isinstance(params, torch.Tensor): params = [params] # make sure any generators are expanded params = list(params) - if len(params) == 1: - p = params[0].grad - grad_norm = torch.norm(p) - if grad_norm > max_norm > 0: - clip_coef = max_norm / (grad_norm + 1e-6) - p.mul_(clip_coef) - return grad_norm - elif max_norm > 0: + # if syncing we need to manually perform the clipping so that we aggregrate + # properly + if max_norm > 0 and not sync: return torch.nn.utils.clip_grad_norm_(params, max_norm) else: - return torch.sqrt( - sum(p.grad.data.norm() ** 2 for p in params if p.grad is not None) - ) + normsq = sum(p.grad.data.norm() ** 2 for p in params if p.grad is not None) + if sync: + # also need to get the norms from all the other sharded works in FSDP + import torch.distributed as dist + + dist.all_reduce(normsq) + grad_norm = normsq.sqrt() + if max_norm > 0: + clip_coef = max_norm / (grad_norm + 1e-6) + for p in params: + p.grad.detach().mul_(clip_coef) + + return grad_norm def has_overflow(grad_norm): @@ -88,7 +109,7 @@ def has_overflow(grad_norm): class SafeFP16Optimizer(torch.optim.Optimizer): - def __init__(self, optimizer): + def __init__(self, optimizer, aggregate_gnorms=False): self.fp16_params = self._get_parameters(optimizer) self.fp32_params = self._build_fp32_params(self.fp16_params, flatten=False) self.optimizer = optimizer @@ -103,6 +124,7 @@ def __init__(self, optimizer): self.scaler = DynamicLossScaler(2.0 ** 15) self.min_loss_scale = 2 ** -5 + self._aggregate_gnorms = aggregate_gnorms @classmethod def _get_parameters(cls, optimizer): @@ -210,7 +232,9 @@ def clip_master_grads(self, max_norm): Clips gradient norm and updates dynamic loss scaler. """ self._sync_fp16_grads_to_fp32() - grad_norm = clip_grad_norm(self.fp32_params, max_norm) + grad_norm = clip_grad_norm( + self.fp32_params, max_norm, sync=self._aggregate_gnorms + ) # detect overflow and adjust loss scale if self.scaler is not None: @@ -390,6 +414,7 @@ class MemoryEfficientFP16Optimizer(torch.optim.Optimizer): def __init__( self, init_optimizer: torch.optim.Optimizer, # type: ignore + aggregate_gnorms: bool = False, loss_initial_scale: float = 2.0 ** 17, min_loss_scale: float = 1e-4, ): @@ -398,6 +423,8 @@ def __init__( self.min_loss_scale = min_loss_scale self.scaler = DynamicLossScaler(init_scale=loss_initial_scale) + self._aggregate_gnorms = aggregate_gnorms + @staticmethod def compatible_optimizers(): """ @@ -446,7 +473,9 @@ def clip_master_grads(self, gradient_clip): Returns -1 if the most recently computed gradients overflowed. """ self._unscale_grads() - grad_norm = clip_grad_norm(self.params, gradient_clip) + grad_norm = clip_grad_norm( + self.params, gradient_clip, sync=self._aggregate_gnorms + ) # detect overflow and adjust loss scale overflow = has_overflow(grad_norm) self.scaler.update_scale(overflow) diff --git a/parlai/utils/fsdp.py b/parlai/utils/fsdp.py new file mode 100644 index 00000000000..e2fb305f372 --- /dev/null +++ b/parlai/utils/fsdp.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utility functions for FullyShardedDataParallel. 
+""" + +import contextlib +import torch.nn +from parlai.utils.distributed import is_distributed, get_dist_group + +try: + from fairscale.nn.wrap.auto_wrap import wrap + from fairscale.nn.wrap.auto_wrap import enable_wrap as fairscale_enable_wrap + from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP + + FSDP_AVAILABLE = True +except ImportError: + FSDP_AVAILABLE = False + + def wrap(module, **kwargs): + return module + + +DEFAULT_DDP_BACKEND = "ddp" + + +def is_fsdp(module: torch.nn.Module): + """ + Checks whether a module is fully sharded. + """ + return FSDP_AVAILABLE and isinstance(module, FSDP) + + +def should_use_fsdp(opt): + return ( + FSDP_AVAILABLE + and is_distributed() + and opt.get('ddp_backend', DEFAULT_DDP_BACKEND) in ('zero2', 'zero3') + ) + + +@contextlib.contextmanager +def maybe_fsdp_wrap(opt): + """ + Context manager for enabling wrapping in FullyShardedDataParallel. + """ + if not should_use_fsdp(opt): + # make a no-op + yield + return + + # zero3 not supported at this time. Throw an exception + if opt['ddp_backend'] == 'zero3': + raise NotImplementedError( + '--ddp-backend zero3 is not supported at this time. For details, see ' + 'https://github.com/facebookresearch/ParlAI/issues/3753.' + ) + + reshard_after_forward = opt['ddp_backend'] == 'zero3' + compute_dtype = torch.float16 if opt['fp16'] else torch.float32 + mixed_precision = opt['fp16'] and opt['fp16_impl'] == 'safe' + fsdp_args = dict( + reshard_after_forward=reshard_after_forward, + mixed_precision=mixed_precision, + compute_dtype=compute_dtype, + state_dict_device=torch.device('cpu'), + flatten_parameters=True, + process_group=get_dist_group(), + ) + with fairscale_enable_wrap(wrapper_cls=FSDP, **fsdp_args): + yield + + +def delay_halving(opt): + """ + Check whether we should keep the model in fp32 before other setup. + + When using Zero2 or Zero3 backends with mixed precision, we need to avoid converting + the model to fp16, as the FSDP module does this for us. + + If we are using just plain DDP or MemoryEfficient optimizers, then we want + to call half() early. + """ + + return opt['fp16'] and should_use_fsdp(opt) and opt['fp16_impl'] == 'safe' + + +def should_sync_gradnorm(opt): + """ + Indicates whether fp16 optimizer wrappers should accumulate over workers. + + FP16 overflow detection and gradient clipping both require accumulating gradients + across all workers when using FSDP, as workers only store a fraction of the + gradients. + """ + return ( + FSDP_AVAILABLE + and opt['fp16'] + and opt.get('ddp_backend', DEFAULT_DDP_BACKEND) in ('zero2', 'zero3') + ) + + +def fsdp_wrap(module): + """ + Helper function for wrapping the outermost root module. + """ + return wrap(module) diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 0a9372412f4..d1fda5d36d7 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import os -import copy import unittest import parlai.utils.testing as testing_utils import parlai.scripts.build_dict as build_dict @@ -15,21 +14,30 @@ BATCHSIZE = 4 -def _forced_parse(parser, opt): - parser.set_params(**opt) - parser.set_params(log_every_n_sec=10) - popt = parser.parse_args([]) - # in some rare cases, like for instance if the model class also - # overrides its default params, the params override will not - # be taken into account. 
diff --git a/tests/test_distributed.py b/tests/test_distributed.py
index 0a9372412f4..d1fda5d36d7 100644
--- a/tests/test_distributed.py
+++ b/tests/test_distributed.py
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.

 import os
-import copy
 import unittest
 import parlai.utils.testing as testing_utils
 import parlai.scripts.build_dict as build_dict
@@ -15,21 +14,30 @@
 BATCHSIZE = 4


-def _forced_parse(parser, opt):
-    parser.set_params(**opt)
-    parser.set_params(log_every_n_sec=10)
-    popt = parser.parse_args([])
-    # in some rare cases, like for instance if the model class also
-    # overrides its default params, the params override will not
-    # be taken into account.
-    for k, v in opt.items():
-        popt[k] = v
-    return popt
+class _AbstractTest(unittest.TestCase):
+    def _distributed_train_model(self, **overrides):
+        opt = {**self.base_config, **overrides}
+        with testing_utils.tempdir() as tmpdir:
+            if 'model_file' not in opt:
+                opt['model_file'] = os.path.join(tmpdir, 'model')
+            if 'dict_file' not in opt:
+                opt['dict_file'] = os.path.join(tmpdir, 'model.dict')
+
+            parser = mp_train.setup_args()
+            popt = parser.parse_kwargs(**opt)
+
+            # we need a prebuilt dictionary
+            parser = build_dict.setup_args()
+            build_dict.build_dict(popt)
+
+            valid, test = mp_train.launch_and_train(popt)
+
+        return (valid, test)


 @testing_utils.skipUnlessGPU
-class TestDistributed(unittest.TestCase):
-    _base_config = dict(
+class TestDistributed(_AbstractTest):
+    base_config = dict(
         task='integration_tests:overfit',
         model='transformer/generator',
         optimizer='adam',
@@ -46,30 +54,8 @@ class TestDistributed(unittest.TestCase):
         verbose=True,
     )

-    def setUp(self):
-        print(f'[Setting up test {self._testMethodName}]')
-
-    def _distributed_train_model(self, opt):
-        with testing_utils.tempdir() as tmpdir:
-            if 'model_file' not in opt:
-                opt['model_file'] = os.path.join(tmpdir, 'model')
-            if 'dict_file' not in opt:
-                opt['dict_file'] = os.path.join(tmpdir, 'model.dict')
-
-            parser = mp_train.setup_args()
-            popt = _forced_parse(parser, opt)
-
-            # we need a prebuilt dictionary
-            parser = build_dict.setup_args()
-            build_dict.build_dict(popt)
-
-            valid, test = mp_train.launch_and_train(popt, 31338)
-
-        return (valid, test)
-
     def test_generator_distributed(self):
-        config = copy.deepcopy(self._base_config)
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model()

         self.assertLessEqual(valid['ppl'], 1.60)
         self.assertLessEqual(test['ppl'], 1.60)
@@ -80,11 +66,11 @@ def test_generator_distributed(self):
         self.assertEqual(test['exs'].value(), BATCHSIZE)

     def test_multitask_distributed(self):
-        config = copy.deepcopy(self._base_config)
-        config['num_epochs'] = 50
-        config['task'] = 'integration_tests:overfit,integration_tests:overfit_multiturn'
-        config['dynb'] = 'full'
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model(
+            num_epochs=50,
+            task='integration_tests:overfit,integration_tests:overfit_multiturn',
+            truncate=16,
+        )

         self.assertLessEqual(valid['ppl'], 1.20)
         self.assertLessEqual(test['ppl'], 1.20)
@@ -100,12 +86,12 @@ def test_multitask_distributed(self):
         )

     def test_distributed_eval_max_exs(self):
-        config = copy.deepcopy(self._base_config)
-        config['task'] = 'integration_tests'
-        config['num_epochs'] = 0.01
-        config['validation_max_exs'] = 90
-        config['short_final_eval'] = True
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model(
+            task='integration_tests',
+            num_epochs=0.01,
+            validation_max_exs=90,
+            short_final_eval=True,
+        )

         # Tests that DialogData.get() is doing the right thing
         # Ensure no duplication of examples among workers
@@ -120,11 +106,9 @@ def test_distributed_eval_max_exs(self):
         self.assertEqual(test['exs'].value(), 96)

     def test_distributed_eval_stream_mode(self):
-        config = copy.deepcopy(self._base_config)
-        config['task'] = 'integration_tests'
-        config['num_epochs'] = 0.01
-        config['datatype'] = 'train:stream'
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model(
+            task='integration_tests', num_epochs=0.01, datatype='train:stream'
+        )

         # Tests that StreamDialogData.get() is doing the right thing
         # Ensure no duplication of examples among workers
@@ -133,14 +117,13 @@ def test_distributed_eval_stream_mode(self):
         self.assertEqual(test['exs'].value(), inttests.NUM_TEST)

     def test_distributed_eval_stream_mode_max_exs(self):
-        config = copy.deepcopy(self._base_config)
-        config['task'] = 'integration_tests'
-        config['num_epochs'] = 0.01
-        config['datatype'] = 'train:stream'
-        config['validation_max_exs'] = 90
-        config['short_final_eval'] = True
-
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model(
+            task='integration_tests',
+            num_epochs=0.01,
+            datatype='train:stream',
+            validation_max_exs=90,
+            short_final_eval=True,
+        )

         # Tests that StreamDialogData.get() is doing the right thing
         # Ensure no duplication of examples among workers
@@ -155,45 +138,68 @@ def test_distributed_eval_stream_mode_max_exs(self):
         self.assertEqual(test['exs'].value(), 96)

     def test_chunked_dynamic_teacher(self):
-        config = copy.deepcopy(self._base_config)
-        config['task'] = 'integration_tests'
-        config['num_epochs'] = 0.01
-        config['datatype'] = 'train:stream'
-        config['dynamic_batching'] = 'full'
-        config['truncate'] = 16
-
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model(
+            task='integration_tests',
+            num_epochs=0.01,
+            datatype='train:stream',
+            dynamic_batching='full',
+            truncate=16,
+        )
         assert valid['exs'].value() == inttests.NUM_TEST
         assert test['exs'].value() == inttests.NUM_TEST

     def test_chunked_teacher(self):
-        config = copy.deepcopy(self._base_config)
-        config['task'] = 'integration_tests'
-        config['num_epochs'] = 0.01
-        config['datatype'] = 'train:stream'
-        config['num_epochs'] = 5
-        config['dynamic_batching'] = None
-
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model(
+            task='integration_tests',
+            datatype='train:stream',
+            num_epochs=5,
+            dynamic_batching=None,
+        )
         assert valid['exs'].value() == inttests.NUM_TEST
         assert test['exs'].value() == inttests.NUM_TEST

+
+@testing_utils.skipUnlessGPU
+class TestZero2(TestDistributed):
+    """
+    Integration tests for zero2 FSDP.
+    """
+
+    base_config = {**TestDistributed.base_config, 'ddp_backend': 'zero2'}
+
+
+@unittest.skip
+@testing_utils.skipUnlessGPU
+class TestZero3(TestDistributed):
+    # Not supported at this time. See:
+    # https://github.com/facebookresearch/ParlAI/pull/3740
+    base_config = {**TestDistributed.base_config, 'ddp_backend': 'zero3'}
+
+
+@testing_utils.skipUnlessGPU
+class TestNoModelParallel(_AbstractTest):
+    base_config = dict(
+        task='integration_tests:overfit',
+        optimizer='sgd',
+        validation_metric='loss',
+        learningrate=1e-2,
+        batchsize=BATCHSIZE,
+        validation_every_n_epochs=1,
+        num_epochs=1,
+        n_layers=1,
+        n_heads=1,
+        ffn_size=32,
+        embedding_size=8,
+        verbose=True,
+    )
+
     def test_no_model_parallel(self):
         """
-        Checks that we throw an error when combining mp_train with.
-
-        --model-parallel true.
+        Checks that we throw an error when combining mp_train with --model-parallel.
         """
-        config = copy.deepcopy(self._base_config)
-        config['model_parallel'] = True
-        for m in [
-            'transformer/generator',
-            'transformer/ranker',
-            'transformer/classifier',
-        ]:
-            config['model'] = m
+        for m in ['transformer/generator', 'transformer/ranker']:
             try:
-                _ = self._distributed_train_model(config)
+                _ = self._distributed_train_model(model=m, model_parallel=True)
             except RuntimeError:
                 pass
             else: