Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add auto_wrap option in fairscale integration #10673

Merged
merged 2 commits into from
Mar 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/main_classes/trainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,8 @@ Known caveats:

- This feature is incompatible with :obj:`--predict_with_generate` in the `run_seq2seq.py` script.
- Using :obj:`--sharded_ddp zero_dp_3` requires wrapping each layer of the model in the special container
:obj:`FullyShardedDataParallelism` of fairscale. This is not done automatically by any of the example scripts of the
:class:`~transformers.Trainer`.
:obj:`FullyShardedDataParallelism` of fairscale. It should be used with the option :obj:`auto_wrap` if you are not
doing this yourself: :obj:`--sharded_ddp "zero_dp_3 auto_wrap"`.


DeepSpeed
Expand Down
8 changes: 7 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@

if version.parse(fairscale.__version__) >= version.parse("0.3"):
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
from fairscale.nn.wrap import auto_wrap
else:
FullyShardedDDP = None

Expand Down Expand Up @@ -775,8 +776,13 @@ def _wrap_model(self, model, training=True):
cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
# XXX: Breaking the self.model convention but I see no way around it for now.
if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
model = auto_wrap(model)
self.model = model = FullyShardedDDP(
model, mixed_precision=mixed_precision, reshard_after_forward=zero_3, cpu_offload=cpu_offload
model,
mixed_precision=mixed_precision,
reshard_after_forward=zero_3,
cpu_offload=cpu_offload,
).to(self.args.device)

elif is_sagemaker_distributed_available():
Expand Down
1 change: 1 addition & 0 deletions src/transformers/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,3 +446,4 @@ class ShardedDDPOption(ExplicitEnum):
ZERO_DP_2 = "zero_dp_2"
ZERO_DP_3 = "zero_dp_3"
OFFLOAD = "offload"
AUTO_WRAP = "auto_wrap"
6 changes: 3 additions & 3 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,10 +470,10 @@ class TrainingArguments:
sharded_ddp: str = field(
default="",
metadata={
"choices": ["simple", "zero_dp_2", "zero_dp_3", "zero_dp_2 offload", "zero_dp_3 offload"],
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Choices started to grow longer as there are 4 more possibilities of options now. Can put it back if you really want to but I think the doc is enough.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but why do all combinations? why not just list the components

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"choices": ["simple", "zero_dp_2", "zero_dp_3", "offload", "auto_wrap"],

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This instruction will fail if we pass anything not in the list of choices, like `"zero_dp_2 offload".

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, I understand! It doesn't know of composite values and would have failed anyway if one were to pass: "offload zero_dp_2 offload" (reverse).

So it's either subclassing this argparse behavior to support compose values or removing it as you did.

I think as long as the help entry for that option lists all the choices, which it does, then it's good enough.

"help": "Whether or not to use sharded DDP training (in distributed training only). The base option "
"should be `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` "
"like this: zero_dp_2 offload` or `zero_dp_3 offload`",
"like this: zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or "
"with the same syntax: zero_dp_2 auto_wrap` or `zero_dp_3 auto_wrap`.",
},
)
deepspeed: Optional[str] = field(
Expand Down Expand Up @@ -570,7 +570,7 @@ def __post_init__(self):
"`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or "
'`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.'
)
elif len(self.sharded_ddp) > 1 and ShardedDDPOption.Simple in self.sharded_ddp:
elif len(self.sharded_ddp) > 1 and ShardedDDPOption.SIMPLE in self.sharded_ddp:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo fix

raise ValueError("`--sharded_ddp simple` is not compatible with any other option.")
elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
Expand Down