diff --git a/.circleci/config.yml b/.circleci/config.yml index 33eb0d6050e..07ebbd088f4 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -219,26 +219,26 @@ commands: - setupcuda - fixgit - restore_cache: - key: deps-20221130-<< parameters.cachename >>-{{ checksum "requirements.txt" }} + key: deps-20221202-<< parameters.cachename >>-{{ checksum "requirements.txt" }} - setup - installdeps - << parameters.more_installs >> - save_cache: - key: deps-20221130-<< parameters.cachename >>-{{ checksum "requirements.txt" }} + key: deps-20221202-<< parameters.cachename >>-{{ checksum "requirements.txt" }} paths: - "~/venv/bin" - "~/venv/lib" - findtests: marker: << parameters.marker >> - restore_cache: - key: data-20221130-<< parameters.cachename >>-{{ checksum "teststorun.txt" }} + key: data-20221202-<< parameters.cachename >>-{{ checksum "teststorun.txt" }} - run: name: Run tests no_output_timeout: 60m command: | coverage run -m pytest -m << parameters.marker >> << parameters.pytest_flags >> --junitxml=test-results/junit.xml - save_cache: - key: data-20221130-<< parameters.cachename >>-{{ checksum "teststorun.txt" }} + key: data-20221202-<< parameters.cachename >>-{{ checksum "teststorun.txt" }} paths: - "~/ParlAI/data" - codecov @@ -255,12 +255,12 @@ commands: - checkout - fixgit - restore_cache: - key: deps-20221130-bw-{{ checksum "requirements.txt" }} + key: deps-20221202-bw-{{ checksum "requirements.txt" }} - setup - installdeps - installtorchgpu - save_cache: - key: deps-20221130-bw-{{ checksum "requirements.txt" }} + key: deps-20221202-bw-{{ checksum "requirements.txt" }} paths: - "~/venv/bin" - "~/venv/lib" diff --git a/parlai/agents/hugging_face/t5.py b/parlai/agents/hugging_face/t5.py index 1c22f063a86..7c0d52bdb30 100644 --- a/parlai/agents/hugging_face/t5.py +++ b/parlai/agents/hugging_face/t5.py @@ -27,7 +27,7 @@ from parlai.core.params import ParlaiParser from parlai.core.torch_agent import Batch, TorchAgent from parlai.core.torch_generator_agent import TorchGeneratorAgent, TorchGeneratorModel -from parlai.utils.fsdp import is_fsdp +from parlai.utils.fsdp import is_fsdp, delay_halving def check_hf_version(v: Tuple[int, int]) -> bool: @@ -41,7 +41,9 @@ def check_hf_version(v: Tuple[int, int]) -> bool: def build_t5(opt: Opt) -> T5ForConditionalGeneration: if not check_hf_version(HF_VERSION): raise RuntimeError('Must use transformers package >= 4.3 to use t5') - torch_dtype = torch.float16 if opt['fp16'] else torch.float32 + torch_dtype = ( + torch.float16 if (opt['fp16'] and not delay_halving(opt)) else torch.float32 + ) try: return T5ForConditionalGeneration.from_pretrained( opt['t5_model_arch'], @@ -369,7 +371,7 @@ def output(self, tensor): """ # Taken directly from HuggingFace # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + # See https://github.com/tensorflow/mesh/blob/fa19d69/mesh_tensorflow/transformer/transformer.py#L586 tensor = tensor * (self.t5.model_dim**-0.5) lm_logits = self.t5.lm_head(tensor) return lm_logits diff --git a/parlai/core/params.py b/parlai/core/params.py index d2084c2f92d..c1f00248656 100644 --- a/parlai/core/params.py +++ b/parlai/core/params.py @@ -774,11 +774,11 @@ def add_distributed_training_args(self): ) grp.add_argument( '--ddp-backend', - # TODO: add in zero3. 
https://github.com/facebookresearch/ParlAI/issues/3753 - choices=['ddp', 'zero2'], + choices=['ddp', 'zero2', 'zero3'], default='ddp', help=( 'Distributed backend. Zero2 can be faster but is more experimental. ' + 'Zero3 significantly reduces memory pressure. ' 'DDP is the most tested.' ), ) diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index 5c3f55b7a0a..0a2f237908a 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -36,7 +36,13 @@ from parlai.utils.distributed import is_distributed from parlai.utils.misc import AttrDict, warn_once from parlai.utils.io import PathManager -from parlai.utils.fsdp import should_sync_gradnorm, is_fsdp, DEFAULT_DDP_BACKEND +from parlai.utils.fsdp import ( + should_sync_gradnorm, + is_fsdp, + DEFAULT_DDP_BACKEND, + FSDP_AVAILABLE, + get_state_dict, +) from parlai.utils.fp16 import ( SafeFP16Optimizer, MemoryEfficientFP16Optimizer, @@ -1981,8 +1987,11 @@ def state_dict(self): if hasattr(self.model, 'module') and not is_fsdp(self.model): # did we wrap in a DistributedDataParallel or DataParallel states['model'] = self.model.module.state_dict() + elif is_fsdp(self.model) and FSDP_AVAILABLE: + # FSDP Model; use fancy saving + states['model'] = get_state_dict(self.model) else: - # regular model or FSDP + # regular model states['model'] = self.model.state_dict() if hasattr(self, 'optimizer'): diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index b79895f338d..01df3a36740 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -34,7 +34,7 @@ from parlai.utils.misc import warn_once from parlai.utils.io import PathManager import parlai.utils.logging as logging -from parlai.core.metrics import Metric, SumMetric, AverageMetric, FairseqBleuMetric +from parlai.core.metrics import SumMetric, AverageMetric, FairseqBleuMetric from parlai.utils.fp16 import FP16SafeCrossEntropy import parlai.utils.fsdp as fsdp_utils from parlai.utils.torch import ( @@ -516,8 +516,10 @@ def __init__(self, opt: Opt, shared=None): else: # this is not a shared instance of this class, so do full init self.criterion = self.build_criterion() + + self.model = self.build_model() with fsdp_utils.maybe_fsdp_wrap(opt): - self.model = fsdp_utils.fsdp_wrap(self.build_model()) + self.model = fsdp_utils.fsdp_wrap(self.model) if self.fp16 and not fsdp_utils.delay_halving(opt): self.model = self.model.half() @@ -2054,7 +2056,7 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection class FactualNucleusSampling(NucleusSampling): """ - Factual Nucleus Sampling + Factual Nucleus Sampling. 
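Nucleus sampling in which the top-p value is decayed as a sentence is generated and reset at sentence boundaries, which the paper below reports reduces hallucination in open-ended generation.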
See https://arxiv.org/pdf/2206.04624.pdf for more information """ diff --git a/parlai/scripts/distributed_eval.py b/parlai/scripts/distributed_eval.py index f8a6f0f72ef..a0e218936b5 100644 --- a/parlai/scripts/distributed_eval.py +++ b/parlai/scripts/distributed_eval.py @@ -31,10 +31,10 @@ -m seq2seq -t convai2 --dict-file /path/to/dict-file ``` """ - -import parlai.scripts.eval_model as eval_model from parlai.core.script import ParlaiScript +import parlai.scripts.eval_model as eval_model import parlai.utils.distributed as distributed_utils +import parlai.utils.fsdp as fsdp_utils def setup_args(): @@ -51,7 +51,9 @@ def setup_args(cls): def run(self): with distributed_utils.slurm_distributed_context(self.opt) as opt: - return eval_model.eval_model(opt) + self.evaluator = fsdp_utils.JoinableEvaluator(opt) + with fsdp_utils.fsdp_join(self.evaluator): + return self.evaluator.eval_model() if __name__ == '__main__': diff --git a/parlai/scripts/distributed_train.py b/parlai/scripts/distributed_train.py index cc846dde395..72b18482490 100644 --- a/parlai/scripts/distributed_train.py +++ b/parlai/scripts/distributed_train.py @@ -35,6 +35,7 @@ import parlai.scripts.train_model as single_train from parlai.core.script import ParlaiScript import parlai.utils.distributed as distributed_utils +import parlai.utils.fsdp as fsdp_utils def setup_args(): @@ -51,8 +52,9 @@ def setup_args(cls): def run(self): with distributed_utils.slurm_distributed_context(self.opt) as opt: - self.train_loop = single_train.TrainLoop(opt) - return self.train_loop.train() + self.train_loop = fsdp_utils.JoinableTrainLoop(opt) + with fsdp_utils.fsdp_join(self.train_loop): + return self.train_loop.train() if __name__ == '__main__': diff --git a/parlai/scripts/eval_model.py b/parlai/scripts/eval_model.py index 34bca53e92c..96b880e71e9 100644 --- a/parlai/scripts/eval_model.py +++ b/parlai/scripts/eval_model.py @@ -24,6 +24,7 @@ aggregate_unnamed_reports, Metric, ) +from parlai.core.opt import Opt from parlai.core.worlds import create_task from parlai.utils.misc import TimeLogger, nice_report from parlai.utils.world_logging import WorldLogger @@ -77,7 +78,10 @@ def setup_args(parser=None): '-auc', type=int, default=-1, - help='a positive number indicates to calculate the area under the roc curve and it also determines how many decimal digits of the predictions to keep (higher numbers->more precise); also used to determine whether or not to calculate the AUC metric', + help='a positive number indicates to calculate the area under the ' + 'roc curve and it also determines how many decimal digits of the ' + 'predictions to keep (higher numbers->more precise); also used ' + 'to determine whether or not to calculate the AUC metric', ) parser.add_argument( '--area-under-curve-class', @@ -291,6 +295,14 @@ def eval_model(opt): return report +class Evaluator: + def __init__(self, opt: Opt): + self.opt = opt + + def eval_model(self): + return eval_model(self.opt) + + @register_script('eval_model', aliases=['em', 'eval']) class EvalModel(ParlaiScript): @classmethod @@ -298,7 +310,8 @@ def setup_args(cls): return setup_args() def run(self): - return eval_model(self.opt) + self.evaluator = Evaluator(self.opt) + return self.evaluator.eval_model() if __name__ == '__main__': diff --git a/parlai/scripts/multiprocessing_eval.py b/parlai/scripts/multiprocessing_eval.py index bfc7bdc34b7..b71af97da80 100644 --- a/parlai/scripts/multiprocessing_eval.py +++ b/parlai/scripts/multiprocessing_eval.py @@ -25,8 +25,9 @@ import torch import os import signal 
-import parlai.utils.distributed as distributed_utils import parlai.scripts.eval_model as eval_model +import parlai.utils.distributed as distributed_utils +import parlai.utils.fsdp as fsdp_utils from parlai.core.script import ParlaiScript, register_script @@ -43,7 +44,9 @@ def multiprocess_eval( rank, opt, rank_offset, gpu, init_method=init_method ) as opt: opt['multiprocessing'] = True - return eval_model.eval_model(opt) + evaluator = fsdp_utils.JoinableEvaluator(opt) + with fsdp_utils.fsdp_join(evaluator): + return evaluator.eval_model() def launch_and_eval(opt, port): diff --git a/parlai/scripts/multiprocessing_train.py b/parlai/scripts/multiprocessing_train.py index 543d316b01a..029114285d2 100644 --- a/parlai/scripts/multiprocessing_train.py +++ b/parlai/scripts/multiprocessing_train.py @@ -29,6 +29,7 @@ import traceback import parlai.scripts.train_model as single_train import parlai.utils.distributed as distributed_utils +import parlai.utils.fsdp as fsdp_utils from parlai.core.script import ParlaiScript, register_script @@ -41,8 +42,10 @@ def multiprocess_train( ) as opt: # Run the actual training opt['multiprocessing'] = True + loop = fsdp_utils.JoinableTrainLoop(opt) try: - return single_train.TrainLoop(opt).train() + with fsdp_utils.fsdp_join(loop): + return loop.train() except Exception: import parlai.utils.logging as logging diff --git a/parlai/utils/fsdp.py b/parlai/utils/fsdp.py index e2fb305f372..ed0f0029b17 100644 --- a/parlai/utils/fsdp.py +++ b/parlai/utils/fsdp.py @@ -7,15 +7,32 @@ """ Utility functions for FullyShardedDataParallel. """ - import contextlib +import functools +import torch +import torch.distributed +from torch.distributed.algorithms.join import Join, Joinable, JoinHook import torch.nn + +from parlai.scripts.eval_model import Evaluator +from parlai.scripts.train_model import TrainLoop from parlai.utils.distributed import is_distributed, get_dist_group try: - from fairscale.nn.wrap.auto_wrap import wrap - from fairscale.nn.wrap.auto_wrap import enable_wrap as fairscale_enable_wrap - from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP + import torch + import torch.distributed + import torch.distributed.fsdp + from torch.distributed.fsdp.wrap import ( + wrap, + enable_wrap, + transformer_auto_wrap_policy, + ) + from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullyShardedDataParallel as FSDP, + ShardingStrategy, + MixedPrecision, + BackwardPrefetch, + ) FSDP_AVAILABLE = True except ImportError: @@ -53,25 +70,73 @@ def maybe_fsdp_wrap(opt): yield return - # zero3 not supported at this time. Throw an exception - if opt['ddp_backend'] == 'zero3': - raise NotImplementedError( - '--ddp-backend zero3 is not supported at this time. For details, see ' - 'https://github.com/facebookresearch/ParlAI/issues/3753.' + mixed_precision = opt['fp16'] and opt['fp16_impl'] == 'safe' + + # settings as of pytorch 1.13 + # There is a warning in pytorch 1.13 for FSDP that is unavoidable; + # at the risk of suppressing valid warnings, just going to suppress that one. + import warnings + + warnings.filterwarnings("ignore") + + # sharding strategy determines zero2 or zero3 + sharding_strategy = ( + ShardingStrategy.FULL_SHARD + if opt['ddp_backend'] == 'zero3' + else ShardingStrategy.SHARD_GRAD_OP + ) + + # mp determines how to mix precision + if mixed_precision: + mp_strategy = MixedPrecision( + reduce_dtype=torch.float16, + param_dtype=torch.float16, + buffer_dtype=torch.float16, ) + else: + mp_strategy = None + + # autowrap policy. 
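+ # each TransformerEncoderLayer / TransformerDecoderLayer becomes its own FSDP unit, so parameters are sharded and all-gathered per layer rather than for the whole model at once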
+ auto_wrap_policy = None + ignored_modules = None + if opt['model'] in ['bart', 'transformer/generator']: + from parlai.agents.transformer.modules.encoder import ( + TransformerEncoderLayer, + ) + from parlai.agents.transformer.modules.decoder import ( + TransformerDecoderLayer, + ) + + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + TransformerEncoderLayer, + TransformerDecoderLayer, + }, + ) + + # backward prefetch; determines when to fetch the parameters during backward pass + # set to BACKWARD_PRE to increase throughput, at the cost of memory + backward_prefetch = BackwardPrefetch.BACKWARD_POST + + # CPU offloading; this can offload parameters to the CPU + cpu_offload = None - reshard_after_forward = opt['ddp_backend'] == 'zero3' - compute_dtype = torch.float16 if opt['fp16'] else torch.float32 - mixed_precision = opt['fp16'] and opt['fp16_impl'] == 'safe' fsdp_args = dict( - reshard_after_forward=reshard_after_forward, - mixed_precision=mixed_precision, - compute_dtype=compute_dtype, - state_dict_device=torch.device('cpu'), - flatten_parameters=True, process_group=get_dist_group(), + sharding_strategy=sharding_strategy, + cpu_offload=cpu_offload, + auto_wrap_policy=auto_wrap_policy, + backward_prefetch=backward_prefetch, + mixed_precision=mp_strategy, + ignored_modules=ignored_modules, + param_init_fn=None, + device_id=opt['gpu'], + sync_module_states=False, # need this for syncing the first call; specify False because we do it manually after cuda + forward_prefetch=False, # specify true for CPU-heavy workload + limit_all_gathers=False, # specifying the default here ) - with fairscale_enable_wrap(wrapper_cls=FSDP, **fsdp_args): + with enable_wrap(wrapper_cls=FSDP, **fsdp_args): yield @@ -109,3 +174,129 @@ def fsdp_wrap(module): Helper function for wrapping the outermost root module. """ return wrap(module) + + +def get_state_dict(model): + """ + Get the state dict from the model. + + When using Pytorch FSDP, we can offload to CPU. + """ + + if FSDP_AVAILABLE: + from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullStateDictConfig, + StateDictType, + ) + + save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy): + state = model.state_dict() + else: + state = model.state_dict() + + return state + + +@contextlib.contextmanager +def fsdp_join(*args): + with Join([*args]): + yield + + +class JoinableTrainLoop(TrainLoop, Joinable): + """ + Joinable train loop. + """ + + def __init__(self, opt): + import parlai.utils.distributed as dist_utils + + super().__init__(opt) + self.__device = opt['gpu'] + self.__group = dist_utils.get_dist_group() + + def __call__(self): + """ + Join caller. + + For now, don't do anything. + """ + Join.notify_join_context(self) + + def join_hook(self, **kwargs) -> JoinHook: + """ + Return our fake join hook. + """ + return TrainLoopJoinHook(self) + + @property + def join_device(self) -> torch.device: + return self.__device + + @property + def join_process_group(self): + return self.__group + + +class TrainLoopJoinHook(JoinHook): + """ + Join hook for train loop. + + Adapted from https://pytorch.org/tutorials/advanced/generic_join.html + """ + + def __init__(self, train_loop: JoinableTrainLoop): + self.train_loop = train_loop + + def main_hook(self): + pass + + def post_hook(self, is_last_joiner: bool): + pass + + +class JoinableEvaluator(Evaluator, Joinable): + """ + Joinable Evaluator. 
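Implements torch.distributed's Joinable interface so that evaluation can run inside a Join context (see fsdp_join) alongside FSDP-wrapped modules during distributed runs.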
+ """ + + def __init__(self, opt): + import parlai.utils.distributed as dist_utils + + super().__init__(opt) + self.__device = opt['gpu'] + self.__group = dist_utils.get_dist_group() + + def __call__(self): + """ + Join caller. + + For now, don't do anything. + """ + Join.notify_join_context(self) + + def join_hook(self, **kwargs) -> JoinHook: + """ + Return our fake join hook. + """ + return EvaluatorJoinHook(self) + + @property + def join_device(self) -> torch.device: + return self.__device + + @property + def join_process_group(self): + return self.__group + + +class EvaluatorJoinHook(JoinHook): + def __init__(self, evaluator: JoinableEvaluator): + self.evaluator = evaluator + + def main_hook(self): + pass + + def post_hook(self, is_last_joiner: bool): + pass diff --git a/requirements.txt b/requirements.txt index 5a3d38a0622..414941ac816 100644 --- a/requirements.txt +++ b/requirements.txt @@ -44,7 +44,7 @@ subword-nmt==0.3.7 tensorboardX<=2.5.0 tokenizers>=0.8.0 tomli>=2.0.0 -torchtext>=0.5.0,<=0.13.1 +torchtext>=0.5.0,<0.14.0 tornado==6.0.4 tqdm~=4.62.1 typing-extensions==3.7.4.3 diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 8a73de5ae9c..443dd28ef95 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -169,11 +169,8 @@ class TestZero2(TestDistributed): base_config = {**TestDistributed.base_config, 'ddp_backend': 'zero2'} -@unittest.skip @testing_utils.skipUnlessGPU class TestZero3(TestDistributed): - # Not supported at this time. See: - # https://github.com/facebookresearch/ParlAI/pull/3740 base_config = {**TestDistributed.base_config, 'ddp_backend': 'zero3'}