
[FSDP] Zero 3 Optimization Support #4903

Merged
klshuster merged 10 commits into main from zero3 on Dec 5, 2022

Conversation

klshuster
Contributor

Patch description

This PR adds support for the zero3 FSDP optimization, enabled via --ddp-backend zero3. Zero3 shards not only the optimizer state and gradients but also the model weights, which reduces memory pressure and is especially useful for larger models. Note that it adds communication cost and can therefore increase latency, so for smaller models there may be a slight speed hit compared to zero2 (and to model parallel, of course).

Addresses #3753

NOTE: This requires pytorch >= 1.12 for use
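
For reference, here is a minimal sketch of how zero2 vs. zero3 maps onto the pytorch-native FSDP sharding strategies (illustrative only, not the actual ParlAI wrapping code; it assumes a distributed process group has already been initialized):

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy


def wrap_model(model: nn.Module, zero3: bool = True) -> FSDP:
    # zero2 shards optimizer state + gradients (SHARD_GRAD_OP);
    # zero3 additionally shards the model weights themselves (FULL_SHARD),
    # trading extra communication for lower per-GPU memory.
    strategy = (
        ShardingStrategy.FULL_SHARD if zero3 else ShardingStrategy.SHARD_GRAD_OP
    )
    return FSDP(model, sharding_strategy=strategy)
```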

Testing steps

  • Enabled the zero3 tests already in CI (thanks Stephen!)
  • Confirmed that zero3 works when the number of validation examples is not divisible by the number of GPUs
  • Confirmed that zero3 works when --skip-generation is False during training
  • Conducted a comprehensive analysis of a variety of "staple" models within ParlAI under different distributed settings; see screenshots below

The main conclusions from each of the models below are:

  • Performance: metrics remain roughly the same whether we use model parallel, zero2, or zero3
  • Speed: for BART, zero2 remains faster. For the other models, we see varying results; ultimately, everything is faster than model parallel
  • GPU Memory: zero3 uses the least GPU memory, as expected from additionally sharding the model weights

BART-Large (400M)

[screenshot: benchmark results across distributed settings]

T5-Large (770M)

[screenshot: benchmark results across distributed settings]

Reddit 2.7B (base of BlenderBot)

[screenshot: benchmark results across distributed settings]

GPT2-XL (1.5B)

[screenshot: benchmark results across distributed settings]

R2C2 2.7B

[screenshot: benchmark results across distributed settings]

Contributor

@stephenroller stephenroller left a comment

nice, pretty smooth implementation. i might recommend just axing support for fairscale but you're the boss

@@ -516,6 +516,16 @@ def __init__(self, opt: Opt, shared=None):
else:
# this is not a shared instance of this class, so do full init
self.criterion = self.build_criterion()

def load_init_model() -> Dict[str, Any]:
Contributor

how is this used?

Contributor Author

ahh it's not, that's an artifact. will delete

self.train_loop = single_train.TrainLoop(opt)
return self.train_loop.train()
self.train_loop = fsdp_utils.JoinableTrainLoop(opt)
with fsdp_utils.fsdp_join(self.train_loop):
Contributor

do you need to do this for distributed_eval too?

Contributor Author

yes, forgot about that. also multiprocessing eval
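
For context: the uneven-inputs problem this addresses is that, when one GPU runs out of batches before the others, collective communication can hang. fsdp_join here presumably wraps pytorch's Join context manager (torch.distributed.algorithms.join), which lets finished ranks shadow the collectives of ranks still working. A rough, generic illustration with DDP (not the ParlAI helper itself):

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.join import Join


def train_uneven(model: nn.Module, loader, optimizer):
    # assumes torch.distributed.init_process_group() has already been called
    ddp_model = DDP(model)
    with Join([ddp_model]):
        # each rank may see a different number of batches without deadlocking:
        # ranks that finish early keep shadowing the gradient collectives
        for batch, target in loader:
            optimizer.zero_grad()
            loss = nn.functional.mse_loss(ddp_model(batch), target)
            loss.backward()
            optimizer.step()
```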

Comment on lines 41 to 42
from fairscale.nn.wrap.auto_wrap import wrap, enable_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
Contributor

why keep around fairscale support?

Contributor Author

so getting rid of fairscale would force anyone who wants to use distributed training to be on pytorch >=1.12. Is that a reasonable ask, you think?

Contributor

That's your call but pytorch 1.13 is out (and I'm using it successfully)

Contributor Author

i've removed fairscale FSDP
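
For anyone following along, the pytorch-native equivalents of the fairscale imports flagged above look roughly like this (illustrative sketch, torch >= 1.12; requires an initialized process group):

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import enable_wrap, wrap

# Typical pattern (mirrors fairscale's auto_wrap API): construct submodules
# under enable_wrap() and wrap() each one so it becomes its own FSDP unit.
def build_sharded_layer() -> nn.Module:
    with enable_wrap(wrapper_cls=FSDP):
        return wrap(nn.Linear(512, 512))
```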

@klshuster
Contributor Author

Note: There is currently a bug in pytorch 1.13 that doesn't allow specifying your own pickle module for loading; this is breaking some tests (test_apex.py). The fix is merged in pytorch/pytorch#88570, but we'll need to wait for a pytorch patch release before we can pass tests with pytorch 1.13.

For now, requirements will pin to the latest 1.12 version.
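
Illustrative only: one way the temporary pin could be expressed in setup.py-style metadata (the actual ParlAI requirements file and exact bounds may differ):

```python
# hypothetical pin, not the actual ParlAI requirements entry
install_requires = [
    "torch>=1.12.0,<1.13.0",  # avoid the 1.13 pickle-module regression for now
]
```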

@klshuster klshuster merged commit 96aa1bb into main Dec 5, 2022
@klshuster klshuster deleted the zero3 branch December 5, 2022 19:24