[FSDP] Zero 3 Optimization Support #4903
Conversation
nice, pretty smooth implementation. i might recommend just axing support for fairscale but you're the boss
parlai/core/torch_generator_agent.py
Outdated
@@ -516,6 +516,16 @@ def __init__(self, opt: Opt, shared=None):
        else:
            # this is not a shared instance of this class, so do full init
            self.criterion = self.build_criterion()

    def load_init_model() -> Dict[str, Any]:
how is this used?
ahh it's not, that's an artifact. will delete
-        self.train_loop = single_train.TrainLoop(opt)
-        return self.train_loop.train()
+        self.train_loop = fsdp_utils.JoinableTrainLoop(opt)
+        with fsdp_utils.fsdp_join(self.train_loop):
do you need to do this for distributed_eval too?
yes, forgot about that. also multiprocessing eval
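For reference, the training-path pattern from the diff above, plus a sketch of how the same join wrapping could be mirrored for the distributed/multiprocessing eval scripts. JoinableEvalLoop and the eval() call are hypothetical names used only for illustration; the actual eval change may look different.

# training path, as in this PR's diff:
self.train_loop = fsdp_utils.JoinableTrainLoop(opt)
with fsdp_utils.fsdp_join(self.train_loop):
    return self.train_loop.train()

# hypothetical analogue for the eval scripts (illustrative only):
self.eval_loop = fsdp_utils.JoinableEvalLoop(opt)
with fsdp_utils.fsdp_join(self.eval_loop):
    return self.eval_loop.eval()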
parlai/utils/fsdp.py
Outdated
from fairscale.nn.wrap.auto_wrap import wrap, enable_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
why keep around fairscale support?
so getting rid of fairscale would force anyone who wants to use distributed training to be on pytorch >=1.12. Is that a reasonable ask, you think?
That's your call but pytorch 1.13 is out (and I'm using it successfully)
i've removed fairscale FSDP
Note: There is currently a bug in pytorch 1.13 that doesn't allow specifying your own pickle module for loading, which is breaking some tests. For now, requirements will specify the latest 1.12 version.
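For context, a minimal sketch of what swapping the fairscale imports in parlai/utils/fsdp.py for the native PyTorch implementation might look like; the exact imports and version guard in the final patch may differ.

# native FSDP, available in torch >= 1.12, replacing the fairscale imports above
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import enable_wrap, wrap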
Patch description
This PR adds support for the zero3 FSDP optimization, specified via --ddp-backend zero3. Zero3 shards not only the optimizer state and gradients but also the model weights. This reduces memory pressure and is especially useful for larger models. Note that it can increase latency due to the added communication cost, so for smaller models there may be a slight speed hit compared to zero2 (and to model parallel, of course).

Addresses #3753
NOTE: This requires pytorch >= 1.12 for use
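For illustration only, a rough sketch of how the two backends could map onto native FSDP sharding strategies; this mapping is an assumption about the implementation, not a quote from the patch.

from torch.distributed.fsdp import ShardingStrategy

def sharding_strategy_for(ddp_backend: str) -> ShardingStrategy:
    # hypothetical helper: zero2 shards optimizer state and gradients,
    # while zero3 additionally shards the model parameters
    if ddp_backend == 'zero3':
        return ShardingStrategy.FULL_SHARD
    return ShardingStrategy.SHARD_GRAD_OP

Example invocation (the task and model flags are placeholders):

parlai multiprocessing_train -t convai2 -m bart --ddp-backend zero3 --skip-generation False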
Testing steps
All runs used --skip-generation False during training. The main conclusions from each of the models below are:
BART-Large (400M)
T5-Large (770M)
Reddit 2.7B (base of BlenderBot)
GPT2-XL (1.5B)
R2C2 2.7B