[feature-fix-refactor][ShardedDDP] Make it possible to change trainability graph on the fly #369
Conversation
…anged. Make it optional so that you save some time
@@ -565,24 +564,6 @@ def _broadcast_params(self) -> None:
        if last_work_handle:
            last_work_handle.wait()

    def _consume_work_handles(self) -> None:
since the tensor views change (broadcast buckets use tensor views), this was not being used in OSS anymore, just in ShardedDDP; I figured it was cleaner to move it there
not handling multiple optimizers properly, fixing that
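For context, a hedged sketch of what a handle-consuming helper like the one discussed above typically looks like with `torch.distributed` async collectives (an illustration only, not the fairscale implementation):

```python
from typing import List, Optional

def consume_work_handles(work_handles: List[Optional[object]]) -> None:
    # Wait on every outstanding async collective (e.g. a handle returned by
    # dist.broadcast(..., async_op=True)), then clear the list so the same
    # handles are not waited on twice.
    for handle in work_handles:
        if handle is not None:
            handle.wait()
    work_handles.clear()
```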
…. which shows that this is not handled properly - the previous case is
@@ -145,55 +171,29 @@ def check_parity(amp: bool):
            module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True
        )

        next(model.parameters()).requires_grad = False  # Test non-trainable parameters
the whole point of this PR: check that you can change the training graph after instantiation and still get the correct results
FB only: tested with RegNet256 f249307456 f2578230350
Are the cost savings we get from making this assumption really worth the complexity we are introducing for users, who now have to think about calling refresh_trainable()?
What is the overhead of checking ourselves whether we need to refresh?
I just saw this, I actually just pinged you on another PR because of that! So to step back a little: the assumption was always there, it's not new to this PR. I never thought about it, so it was an oversight (it was also in OSS). The ShardedDDP code was not very flexible on that front, hence this rather big PR to make it easier to update.

Now about the cost: one issue is that the partitioning changes, because we only broadcast or optimize the trainable params (for instance HF had a fine-tuning job with a huge embedding table, I think, which was frozen; this should not count for OSS partitioning since there is no corresponding optimizer state). Now if this parameter becomes trainable, we need to repartition, which means changing all the flat buffers. Doing this for every step will have a very sizeable speed impact I think: for big models it means traversing the whole graph and checking that nothing changed compared to before. I can measure that indeed, but it would be model-size dependent (and a trivial implementation would be sequential - each rank checks - which would not scale so well). Would
What do @stas00 @mannatsingh @vreis @mrshenli or @SeanNaren think? (picking users)
Let me check if I understood the proposal correctly: the user doesn't need to do anything special, ever, unless they freeze/unfreeze some layers after the fairscale components were initialized. If I got that right, then the cost of getting optimized performance is an additional `refresh_trainable()` call.

The alternative solution is to somehow have a flag on the model level that, if "dirty", will automatically trigger a refresh. Such a flag would be turned off as soon as fairscale has initiated its machinery. And then fairscale or pytorch (since a lot of it will end up in pytorch) provides a method to freeze/unfreeze layers which, besides doing its normal work, will also mark the flag "dirty". So as long as the user uses this method to freeze/unfreeze layers they don't have to do anything else. If they choose to do it their own way, they should invalidate the flag, or it'd be the same as calling `refresh_trainable()` manually.
I think the correct solution here is an opt-out from auto-check followed by manual refresh; that is, auto-check should be the default. Because otherwise things will sometimes work, and other times sort of work and give bad results that could be missed. So it's better to be slow but correct by default.

The other, more risky approach is not to enable it by default but run the auto-check anyway, say every 1000 steps, and assert if a change is detected, explaining what needs to be done to make things right. I'm just not sure how to pick the right interval so that it's large enough not to be taxing, yet small enough to detect such changes.
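As an illustration of that second, riskier approach, a minimal sketch (the interval and all names are hypothetical, not from this PR):

```python
CHECK_INTERVAL = 1000  # arbitrary; picking the right value is exactly the open question above

def maybe_assert_trainability_unchanged(step, params, reference_mask) -> None:
    # Periodically verify that no parameter was frozen/unfrozen since setup,
    # failing loudly with a hint instead of silently training the wrong graph.
    if step % CHECK_INTERVAL != 0:
        return
    current_mask = [p.requires_grad for p in params]
    assert current_mask == reference_mask, (
        "Parameter trainability changed after ShardedDDP/OSS were set up; "
        "call refresh_trainable() right after freezing or unfreezing layers."
    )
```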
@stas00 brings up really good points, and I have a similar understanding! We'll exclude PL for now; from a PyTorch perspective it would definitely not be recommended to introduce an additional call like this, even for speed benefits, as users would rather not have to diverge from typical freezing/unfreezing specifically for ShardedDDP. From a PL perspective (and I think HF transformers, since you guys define freeze functions, right?) I think it's cool, since we have control over the freeze logic to some extent. On average I think
HF Trainer currently doesn't have freeze functions. Some example scripts like
@SeanNaren, I think @blefaudeux is asking how the default should be handled. I'm with you on what you said above, but I don't think this behavior should be the default, since it could lead to potentially undetectable problems, especially since users may not choose to use the framework trainer's freeze functions; for example, if they are porting from a different framework and they already have an existing way that works.

On the other hand, since the framework's trainer initializes ShardedDDP, it can be left up to the specific trainer to choose what the default behavior is for that framework. So perhaps there should be no default on the fairscale level, and the policy flag should be required, which will force the integrator to make a choice and stand behind it. Does it make sense?
or to closer mimic pytorch's HF Trainer has
So the intention is correctness out of the box over speed. Do note that pytorch Should pytorch introduce
Agreed, after thinking about it a bit more: the default should be correct no matter what, makes sense to me.
Thanks @stas00 and @SeanNaren for the comments, very much on point and appreciated. To make sure that this is clear: there's a link with "find unused parameters" (in that we're talking about the graph that we're training, and its correctness), but it's not the exact same problem that we're trying to solve; in that case a parallel could be between eager evaluation and static mode.

I agree with Stas' conclusion here, in that getting this wrong would be pretty subtle for a user (no crash or loss going NaN, it would "just" not optimize what's planned; I've recently seen an unrelated bug of the sort in a framework and it went unseen for a long time), so better to err on the side of caution I think (it was also Mandeep's opinion I think). I'll implement that for both OSS and ShardedDDP, since they have the same issue on that front.
Sorry, late to the party. Just wondering: why didn't pytorch DDP need this? Is it because they use `find_unused_parameters`? In vissl's case though, I don't think it uses find_unused_parameters, but I could be wrong.
Whatever solution you guys commit to, once it's documented would you kindly ping us so that we can implement this change and document for users how they can further optimize this behavior? Actually, this can probably only be done once a new release is made, so that we can set the deps correctly.

I highly recommend giving users a warning, like pytorch does, if a wasteful operation is done and it proves to be unneeded; this would be the easiest way for users to discover that they could do better. I guess the only tricky part would be how to tell the user to override the default when there is a framework trainer which hides the implementation. So if you decide to add this extra signal, I'd say the sharded init function should probably have an extra optional arg that, if supplied, will be used to instruct the user on how to overcome the default, e.g.:
and fairscale will provide the first part of the warning, so that part would remain the same. I hope I'm not over-complicating things.
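A purely hypothetical illustration of that idea, with all names invented here: fairscale emits the first part of the warning and appends whatever hint the integrating trainer supplied.

```python
import warnings

def warn_wasteful_refresh(framework_hint: str = "") -> None:
    # fairscale would own the first part of the message; a framework trainer that
    # hides the ShardedDDP construction can pass a hint telling *its* users how to
    # override the default.
    message = (
        "auto_refresh_trainable is enabled but the set of trainable parameters "
        "never changed; you could disable the check for a small speedup."
    )
    if framework_hint:
        message += " " + framework_hint
    warnings.warn(message)

# e.g. a trainer could call:
# warn_wasteful_refresh("Pass --no_auto_refresh_trainable to the Trainer to turn this off.")
```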
Update on the current status:
(the broken test seems to be related to an ssh misconfig or just a broken network for a sec, unrelated to this PR https://app.circleci.com/pipelines/github/facebookresearch/fairscale/1615/workflows/ce4480bb-289c-4a22-9632-e4cf416aaf5d/jobs/7784)
alright, now I need to handle differences between pytorch versions..
import io
from typing import Any, Callable, Dict, Optional

import torch
from torch._six import container_abcs
this breaks on fbcode, not compatible with all torch versions it seems, and not useful I presume given that it's in collections
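One common way to stay compatible across torch versions would be a guarded import that falls back to the standard library (a sketch, not necessarily what was committed here):

```python
# Older torch releases shipped container_abcs in torch._six; newer ones removed it.
# The standard library's collections.abc exposes the same ABCs (Mapping, Sequence, ...).
try:
    from torch._six import container_abcs
except ImportError:
    import collections.abc as container_abcs
```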
    def _setup_flat_buffers(self) -> None:
        """Make all params which are on the same device and tied to the same rank views of a single buffer.
        This is used at construction time, and anytime parameter trainability is changed (frozen or unfrozen) and
        `refresh_trainability` is called.
        """

        for device, per_rank_params in self.per_device_params.items():
            self.buckets[device] = []
            # Only wipe the existing buckets if there are none
this is a new part / significantly changed; the idea is that the buckets get re-deployed when trainability changed (since we only broadcast trainable params), and when re-deploying them you need to take care not to lose the previous state
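To illustrate the idea in isolation (a standalone sketch, not the ShardedDDP code): flatten the parameters of one group into a single buffer, copying their current values first so nothing is lost, then turn each parameter into a view of that buffer.

```python
import torch

def flatten_params_into_bucket(params, device="cpu"):
    # Allocate one flat buffer large enough for the whole group.
    total = sum(p.numel() for p in params)
    bucket = torch.empty(total, dtype=params[0].dtype, device=device)
    offset = 0
    for p in params:
        n = p.numel()
        # Copy the existing values first, so re-deploying the bucket keeps the state...
        bucket[offset : offset + n].copy_(p.data.view(-1))
        # ...then make the param a view into its slice of the bucket.
        p.data = bucket[offset : offset + n].view_as(p.data)
        offset += n
    return bucket

# Rebuilding the bucket after trainability changed keeps the parameter values intact.
model = torch.nn.Linear(4, 4)
_ = flatten_params_into_bucket(list(model.parameters()))
```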
        # Tensor cannot be really empty, even if its size is meaningless
        dummy_sync_tensor = torch.tensor([1], device=self._device)
        dummy_sync_tensor = torch.tensor([1], device=self._default_device)
`_device` -> `_default_device` was just a rename, no logic change; the guess was that `_device` did not make it clear what this referred to
INPUTS = 2
BATCH_SIZE = 32

def check_parity(amp: bool, accumulate: bool, change_train_graph: bool):
add a systematic check for parity with DDP, with AMP/accumulation/change train graph all flipped on and off
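For reference, such a sweep could be driven by iterating over all flag combinations; a sketch reusing the `check_parity` signature from the diff above (the real tests may parametrize differently):

```python
from itertools import product

# check_parity(amp, accumulate, change_train_graph) is the helper shown in the diff above
for amp, accumulate, change_train_graph in product([False, True], repeat=3):
    check_parity(amp=amp, accumulate=accumulate, change_train_graph=change_train_graph)
```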
@@ -406,3 +406,11 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
            return False
    else:
        return a == b


def check_same_model_params(model_a: torch.nn.Module, model_b: torch.nn.Module, message: str = "") -> None:
minor refactor, this was used in several places and copy-pasted
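A plausible body for such a helper, assuming elementwise `torch.allclose` comparisons over parameters and buffers (the actual fairscale implementation may differ):

```python
import torch

def check_same_model_params(model_a: torch.nn.Module, model_b: torch.nn.Module, message: str = "") -> None:
    # Compare learnable parameters...
    for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
        assert torch.allclose(p_a, p_b), f"Model parameters differ. {message}"
    # ...and buffers (e.g. BatchNorm running stats), which matter for broadcast_buffers.
    for b_a, b_b in zip(model_a.buffers(), model_b.buffers()):
        assert torch.allclose(b_a, b_b), f"Model buffers differ. {message}"
```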
        check_same_model_params()

        # Check that altering the trainable parameters does not cause DDP and OSS to diverge
        if change_train_graph:
new logic change in this test: check that changing trainability is properly taken into account. Commenting out this PR's additions in OSS/ShardedDDP breaks it, as expected
@@ -71,6 +78,14 @@ class ShardedDataParallel(nn.Module):
        handled. In that case ShardedDDP will raise an exception and suggest to either remove the unused parameters from your model
        (https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html?highlight=unused_parameters is helpful)
        or set `reduce_buffer_size` to 0

    .. warning:
tentatively explain the two options:
- auto detect (default)
- manual + explicit refresh() call
@@ -117,14 +134,19 @@ def __init__(
        # several optimizers can be present each working on seperate parameter set which is spread across multiple ranks

        # - we build an iterator which goes through all the parameters involved globally
        all_param_iterator = chain(
            *[sum([sum(p, []) for p in optim.per_device_params.values()], []) for optim in self.sharded_optimizers]
        self._all_params = list(
cache this list because it is reused for every step if `auto_refresh_trainable` is set
        # - keep track of the grads which have already been reduced
        self._reduced_grads: Dict[OSS, int] = {}
        self._reduced_grads_max = {o: len(o.param_to_rank.values()) for o in self.sharded_optimizers}
        self._reduced_grads = 0
prior to this PR, the trainability graph was kind of baked into these structures. One way to make the comm change more manageable, I think, is to have all of these be completely flat, and handle the partition in a single place
        # Optionally check whether the trainable parameters have changed
        if self.auto_refresh_trainable:
            trainable_mask = list(map(_trainable, self._all_params))
            if trainable_mask != self._reference_trainable_mask:
just compare trainability binary masks, any change (one parameter frozen/unfrozen) will trigger an update
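A self-contained illustration of that mask comparison (free-standing names; in the diff above the mask and the reference live on the ShardedDDP instance):

```python
import torch

def _trainable(p: torch.Tensor) -> bool:
    return p.requires_grad

model = torch.nn.Linear(4, 4)
all_params = list(model.parameters())
reference_trainable_mask = list(map(_trainable, all_params))

# Freeze one parameter: a single flip makes the masks differ and triggers a refresh.
all_params[0].requires_grad = False
trainable_mask = list(map(_trainable, all_params))
if trainable_mask != reference_trainable_mask:
    # here ShardedDDP would call refresh_trainable(); this sketch just stores the new mask
    reference_trainable_mask = trainable_mask
```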
        self._trainable_param_to_rank = {}
        for optim in self.sharded_optimizers:
            # OSS may need to change the communication pattern
            optim.refresh_trainable()
the OSS broadcast pattern needs to be updated on the fly
            for param in filter(lambda x: x.requires_grad, device_params):
                self._trainable_param_to_rank[param] = optim.param_to_rank[param]

        self._setup_bucket_strategy()
same for the buckets (we only reduce the trainable grads) and the hooks (new hooks could be required)
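A rough sketch of the hook side of this point (the hook body is a placeholder; the real ShardedDDP hooks reduce gradients into the buckets):

```python
import torch

def attach_grad_hooks(params):
    # Register a hook on every *trainable* parameter; after a refresh, newly
    # unfrozen parameters need hooks of their own.
    handles = []
    for p in filter(lambda x: x.requires_grad, params):
        handles.append(p.register_hook(lambda grad: grad))  # placeholder, would reduce the grad
    return handles

model = torch.nn.Linear(4, 4)
hook_handles = attach_grad_hooks(model.parameters())
```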
This is very nice. If not already, I'd suggest testing it on more than 2 GPUs since there might be corner cases there.
            this will impact the long term memory consumption, because these buckets correspond to parameters which will not be sharded.
            Set to 0 to remove all bucketing.
        auto_refresh_trainable (bool):
            Check whether the parameters trainability (`requires_grad`) has changed
add default value?
            automatically.
            If `auto_refresh_trainable` is set to `False`, ShardedDDP will not refresh its assumptions with respect to trainable parameters
            for every forward pass, in the hope of saving some time. If some parameters are frozen or unfrozen over time, please refresh
            ShardedDDP assumptions by calling `refresh_trainable()` just after said change (before the next forward pass).
very nice doc
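Based on that docstring, opting out could look roughly like this (the import paths are assumptions, and a `torch.distributed` process group is assumed to be initialized already):

```python
import torch
from fairscale.nn.data_parallel import ShardedDataParallel  # assumed import path
from fairscale.optim import OSS                              # assumed import path

model = torch.nn.Linear(8, 8)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)
ddp = ShardedDataParallel(model, optimizer, auto_refresh_trainable=False)

# ... train for a while, then freeze part of the model ...
model.weight.requires_grad = False

# With auto_refresh_trainable=False the refresh must be requested explicitly,
# before the next forward pass:
ddp.refresh_trainable()
```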
    def refresh_trainable(self) -> None:
        """ If the module trainability has changed, update all the assumptions """

        self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
anything to assert (assumptions) upon entering this function?
good point, I'll try to think about something, all this is tricky (to me at least) so I could do with more asserts indeed
definitely not "just to you"! 🤣
pulling my hair out now on issues with state dict loading & custom optimizers (turns out I should not write into their state), distributed training can certainly get a tad bit complex...
feel you!
I just added an assert, it was a very good call I think; for instance somebody could have had the `no_sync` context activated, then refreshed trainability while forgetting to send the gradients (which would be lost).
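A hedged, free-standing sketch of the kind of guard described (the actual assert checks ShardedDDP's internal bookkeeping):

```python
def refresh_trainable_guard(grads_pending_reduction: bool) -> None:
    # Refuse to refresh while gradients are still waiting to be reduced, e.g. when
    # accumulating inside a no_sync() context, rather than silently dropping them.
    assert not grads_pending_reduction, (
        "Gradients are pending reduction (no_sync accumulation?); reduce them before "
        "calling refresh_trainable(), otherwise they will be lost."
    )
```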
@@ -133,67 +133,66 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
    torch.cuda.set_device(rank)
@min-xu-ai checking with you: the ddp_parity test runs with cuda_count(), so it should be 4 GPUs on CI. Is that ok? (sanity checking that I'm not missing anything)
I don't know CI's GPU limit. If it is >2, then it is great. Also, I have seen bugs that only show up with more than 5 GPUs in other cases (not for OSS or ShardedDDP).
ouch for 5+, I'll test that on the FB cluster again. Seems like CI/unit tests run on 4 GPUs (https://app.circleci.com/pipelines/github/facebookresearch/fairscale/1644/workflows/9e1b0fc9-92dd-4d4e-96de-be775cf5634b/jobs/7982)
checking with f250620664 and f250623267 [FB only]
…anks @min for the review
Before submitting
What does this PR do?
Fixes #368 and #354, required by VISSL and possibly others
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃