This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[FSDP] Zero 3 Optimization Support #4903

Merged 10 commits on Dec 5, 2022
12 changes: 6 additions & 6 deletions .circleci/config.yml
@@ -219,26 +219,26 @@ commands:
- setupcuda
- fixgit
- restore_cache:
-key: deps-20221130-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
+key: deps-20221202-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
- setup
- installdeps
- << parameters.more_installs >>
- save_cache:
-key: deps-20221130-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
+key: deps-20221202-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
paths:
- "~/venv/bin"
- "~/venv/lib"
- findtests:
marker: << parameters.marker >>
- restore_cache:
-key: data-20221130-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
+key: data-20221202-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
- run:
name: Run tests
no_output_timeout: 60m
command: |
coverage run -m pytest -m << parameters.marker >> << parameters.pytest_flags >> --junitxml=test-results/junit.xml
- save_cache:
-key: data-20221130-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
+key: data-20221202-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
paths:
- "~/ParlAI/data"
- codecov
@@ -255,12 +255,12 @@ commands:
- checkout
- fixgit
- restore_cache:
-key: deps-20221130-bw-{{ checksum "requirements.txt" }}
+key: deps-20221202-bw-{{ checksum "requirements.txt" }}
- setup
- installdeps
- installtorchgpu
- save_cache:
-key: deps-20221130-bw-{{ checksum "requirements.txt" }}
+key: deps-20221202-bw-{{ checksum "requirements.txt" }}
paths:
- "~/venv/bin"
- "~/venv/lib"
8 changes: 5 additions & 3 deletions parlai/agents/hugging_face/t5.py
@@ -27,7 +27,7 @@
from parlai.core.params import ParlaiParser
from parlai.core.torch_agent import Batch, TorchAgent
from parlai.core.torch_generator_agent import TorchGeneratorAgent, TorchGeneratorModel
-from parlai.utils.fsdp import is_fsdp
+from parlai.utils.fsdp import is_fsdp, delay_halving


def check_hf_version(v: Tuple[int, int]) -> bool:
@@ -41,7 +41,9 @@ def check_hf_version(v: Tuple[int, int]) -> bool:
def build_t5(opt: Opt) -> T5ForConditionalGeneration:
if not check_hf_version(HF_VERSION):
raise RuntimeError('Must use transformers package >= 4.3 to use t5')
-torch_dtype = torch.float16 if opt['fp16'] else torch.float32
+torch_dtype = (
+    torch.float16 if (opt['fp16'] and not delay_halving(opt)) else torch.float32
+)
try:
return T5ForConditionalGeneration.from_pretrained(
opt['t5_model_arch'],
@@ -369,7 +371,7 @@ def output(self, tensor):
"""
# Taken directly from HuggingFace
# Rescale output before projecting on vocab
-# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+# See https://github.com/tensorflow/mesh/blob/fa19d69/mesh_tensorflow/transformer/transformer.py#L586
tensor = tensor * (self.t5.model_dim**-0.5)
lm_logits = self.t5.lm_head(tensor)
return lm_logits
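Note on the change above: `build_t5` now loads fp32 weights whenever halving is delayed, so that FSDP can manage the fp16 cast itself. A minimal sketch of what `delay_halving` checks, inferred from how this diff uses it (the real body lives in `parlai/utils/fsdp.py` and may differ in detail):

```python
def delay_halving(opt):
    """
    Sketch: return True when the fp16 cast should be left to FSDP.

    Under zero3 the parameters are flattened and sharded at wrap time,
    so halving them beforehand would fight FSDP's own mixed precision.
    """
    return opt['fp16'] and opt.get('ddp_backend') == 'zero3'
```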
4 changes: 2 additions & 2 deletions parlai/core/params.py
@@ -774,11 +774,11 @@ def add_distributed_training_args(self):
)
grp.add_argument(
'--ddp-backend',
-# TODO: add in zero3. https://github.com/facebookresearch/ParlAI/issues/3753
-choices=['ddp', 'zero2'],
+choices=['ddp', 'zero2', 'zero3'],
default='ddp',
help=(
'Distributed backend. Zero2 can be faster but is more experimental. '
+'Zero3 significantly reduces memory pressure. '
'DDP is the most tested.'
),
)
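With `zero3` registered as a choice, selecting the backend is a single flag. A hypothetical invocation (the model and task flags are illustrative, borrowed from the distributed_eval docstring in this PR):

```
parlai multiprocessing_train -m seq2seq -t convai2 --fp16 true --ddp-backend zero3
```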
13 changes: 11 additions & 2 deletions parlai/core/torch_agent.py
@@ -36,7 +36,13 @@
from parlai.utils.distributed import is_distributed
from parlai.utils.misc import AttrDict, warn_once
from parlai.utils.io import PathManager
-from parlai.utils.fsdp import should_sync_gradnorm, is_fsdp, DEFAULT_DDP_BACKEND
+from parlai.utils.fsdp import (
+    should_sync_gradnorm,
+    is_fsdp,
+    DEFAULT_DDP_BACKEND,
+    FSDP_AVAILABLE,
+    get_state_dict,
+)
from parlai.utils.fp16 import (
SafeFP16Optimizer,
MemoryEfficientFP16Optimizer,
@@ -1981,8 +1987,11 @@ def state_dict(self):
if hasattr(self.model, 'module') and not is_fsdp(self.model):
# did we wrap in a DistributedDataParallel or DataParallel
states['model'] = self.model.module.state_dict()
+elif is_fsdp(self.model) and FSDP_AVAILABLE:
+    # FSDP Model; use fancy saving
+    states['model'] = get_state_dict(self.model)
else:
-# regular model or FSDP
+# regular model
states['model'] = self.model.state_dict()

if hasattr(self, 'optimizer'):
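The `get_state_dict` helper is new in this PR: with zero3, each rank holds only a shard of every parameter, so checkpointing must first gather the full tensors. A sketch of such a helper using PyTorch's native FSDP API (an assumption about what `parlai/utils/fsdp.py` does; the option values are illustrative):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig

def get_state_dict(model):
    # Gather the full (unsharded) parameters, offloaded to CPU and
    # materialized only on rank 0 to avoid OOM on large models.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        return model.state_dict()
```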
8 changes: 5 additions & 3 deletions parlai/core/torch_generator_agent.py
@@ -34,7 +34,7 @@
from parlai.utils.misc import warn_once
from parlai.utils.io import PathManager
import parlai.utils.logging as logging
-from parlai.core.metrics import Metric, SumMetric, AverageMetric, FairseqBleuMetric
+from parlai.core.metrics import SumMetric, AverageMetric, FairseqBleuMetric
from parlai.utils.fp16 import FP16SafeCrossEntropy
import parlai.utils.fsdp as fsdp_utils
from parlai.utils.torch import (
@@ -516,8 +516,10 @@ def __init__(self, opt: Opt, shared=None):
else:
# this is not a shared instance of this class, so do full init
self.criterion = self.build_criterion()

+self.model = self.build_model()
 with fsdp_utils.maybe_fsdp_wrap(opt):
-    self.model = fsdp_utils.fsdp_wrap(self.build_model())
+    self.model = fsdp_utils.fsdp_wrap(self.model)
if self.fp16 and not fsdp_utils.delay_halving(opt):
self.model = self.model.half()

@@ -2054,7 +2056,7 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection

class FactualNucleusSampling(NucleusSampling):
"""
-Factual Nucleus Sampling
+Factual Nucleus Sampling.

See https://arxiv.org/pdf/2206.04624.pdf for more information
"""
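The constructor change above is the core of delayed halving: build the model in fp32, wrap it, and only call `.half()` when halving is not delayed. For reference, a hedged sketch of how `zero2`/`zero3` typically map onto FSDP sharding strategies (assumed names; ParlAI's actual `fsdp_wrap` also handles wrapping policy, device placement, and more):

```python
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)

def fsdp_wrap_sketch(model, opt):
    # zero2 shards gradients and optimizer state; zero3 also shards the
    # parameters themselves, which is why the fp16 cast is deferred.
    strategy = (
        ShardingStrategy.FULL_SHARD
        if opt['ddp_backend'] == 'zero3'
        else ShardingStrategy.SHARD_GRAD_OP
    )
    mixed_precision = (
        MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
        if opt['fp16']
        else None
    )
    return FSDP(model, sharding_strategy=strategy, mixed_precision=mixed_precision)
```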
8 changes: 5 additions & 3 deletions parlai/scripts/distributed_eval.py
@@ -31,10 +31,10 @@
-m seq2seq -t convai2 --dict-file /path/to/dict-file
```
"""

-import parlai.scripts.eval_model as eval_model
 from parlai.core.script import ParlaiScript
+import parlai.scripts.eval_model as eval_model
 import parlai.utils.distributed as distributed_utils
+import parlai.utils.fsdp as fsdp_utils


def setup_args():
@@ -51,7 +51,9 @@ def setup_args(cls):

def run(self):
with distributed_utils.slurm_distributed_context(self.opt) as opt:
-return eval_model.eval_model(opt)
+self.evaluator = fsdp_utils.JoinableEvaluator(opt)
+with fsdp_utils.fsdp_join(self.evaluator):
+    return self.evaluator.eval_model()


if __name__ == '__main__':
6 changes: 4 additions & 2 deletions parlai/scripts/distributed_train.py
@@ -35,6 +35,7 @@
import parlai.scripts.train_model as single_train
from parlai.core.script import ParlaiScript
import parlai.utils.distributed as distributed_utils
+import parlai.utils.fsdp as fsdp_utils


def setup_args():
@@ -51,8 +52,9 @@ def setup_args(cls):

def run(self):
with distributed_utils.slurm_distributed_context(self.opt) as opt:
-self.train_loop = single_train.TrainLoop(opt)
-return self.train_loop.train()
+self.train_loop = fsdp_utils.JoinableTrainLoop(opt)
+with fsdp_utils.fsdp_join(self.train_loop):

Review thread on the line above:
Contributor: do you need to do this for distributed_eval too?
Author: yes, forgot about that. also multiprocessing eval

+    return self.train_loop.train()


if __name__ == '__main__':
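`fsdp_join` exists because zero3 makes the forward pass itself a collective: every rank must participate in each parameter all-gather, so a rank that runs out of batches early would deadlock the rest. A plausible sketch built on PyTorch's generic `Join` context manager (the real helper lives in `parlai/utils/fsdp.py`):

```python
from contextlib import contextmanager
from torch.distributed.algorithms import Join

@contextmanager
def fsdp_join(*joinables):
    # Join keeps early-finished ranks answering the collectives
    # issued by ranks that still have data left.
    with Join(list(joinables)):
        yield
```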
17 changes: 15 additions & 2 deletions parlai/scripts/eval_model.py
@@ -24,6 +24,7 @@
aggregate_unnamed_reports,
Metric,
)
+from parlai.core.opt import Opt
from parlai.core.worlds import create_task
from parlai.utils.misc import TimeLogger, nice_report
from parlai.utils.world_logging import WorldLogger
@@ -77,7 +78,10 @@ def setup_args(parser=None):
'-auc',
type=int,
default=-1,
-help='a positive number indicates to calculate the area under the roc curve and it also determines how many decimal digits of the predictions to keep (higher numbers->more precise); also used to determine whether or not to calculate the AUC metric',
+help='a positive number indicates to calculate the area under the '
+'roc curve and it also determines how many decimal digits of the '
+'predictions to keep (higher numbers->more precise); also used '
+'to determine whether or not to calculate the AUC metric',
)
parser.add_argument(
'--area-under-curve-class',
@@ -291,14 +295,23 @@ def eval_model(opt):
return report


+class Evaluator:
+    def __init__(self, opt: Opt):
+        self.opt = opt
+
+    def eval_model(self):
+        return eval_model(self.opt)


@register_script('eval_model', aliases=['em', 'eval'])
class EvalModel(ParlaiScript):
@classmethod
def setup_args(cls):
return setup_args()

def run(self):
-return eval_model(self.opt)
+self.evaluator = Evaluator(self.opt)
+return self.evaluator.eval_model()


if __name__ == '__main__':
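The point of factoring `eval_model` into an `Evaluator` class is that distributed entry points can substitute a Join-aware subclass. A sketch of what `JoinableEvaluator` might look like; the `Joinable` interface is PyTorch's, but the wiring here is an assumption:

```python
import torch
import torch.distributed as dist
from torch.distributed.algorithms import Joinable, JoinHook

class JoinableEvaluator(Evaluator, Joinable):
    """Evaluator that can sit inside a Join context for uneven data."""

    def __init__(self, opt):
        Evaluator.__init__(self, opt)
        Joinable.__init__(self)

    def join_hook(self, **kwargs) -> JoinHook:
        # The default no-op main_hook/post_hook suffice; the evaluator
        # only needs to be counted among the join participants.
        return JoinHook()

    @property
    def join_device(self) -> torch.device:
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @property
    def join_process_group(self):
        return dist.group.WORLD
```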
7 changes: 5 additions & 2 deletions parlai/scripts/multiprocessing_eval.py
@@ -25,8 +25,9 @@
import torch
import os
import signal
-import parlai.utils.distributed as distributed_utils
 import parlai.scripts.eval_model as eval_model
+import parlai.utils.distributed as distributed_utils
+import parlai.utils.fsdp as fsdp_utils
from parlai.core.script import ParlaiScript, register_script


@@ -43,7 +44,9 @@ def multiprocess_eval(
rank, opt, rank_offset, gpu, init_method=init_method
) as opt:
opt['multiprocessing'] = True
-return eval_model.eval_model(opt)
+evaluator = fsdp_utils.JoinableEvaluator(opt)
+with fsdp_utils.fsdp_join(evaluator):
+    return evaluator.eval_model()


def launch_and_eval(opt, port):
5 changes: 4 additions & 1 deletion parlai/scripts/multiprocessing_train.py
@@ -29,6 +29,7 @@
import traceback
import parlai.scripts.train_model as single_train
import parlai.utils.distributed as distributed_utils
+import parlai.utils.fsdp as fsdp_utils
from parlai.core.script import ParlaiScript, register_script


@@ -41,8 +42,10 @@ def multiprocess_train(
) as opt:
# Run the actual training
opt['multiprocessing'] = True
+loop = fsdp_utils.JoinableTrainLoop(opt)
try:
-return single_train.TrainLoop(opt).train()
+with fsdp_utils.fsdp_join(loop):
+    return loop.train()
except Exception:
import parlai.utils.logging as logging

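Taken together with the review thread above, all four long-running entry points (distributed_train, distributed_eval, multiprocessing_train, multiprocessing_eval) now run inside `fsdp_join`: under zero3, any worker that exhausts its data early must keep servicing the parameter all-gathers of slower workers, or training and evaluation would hang at the end of an epoch.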