Add FSDP2 support for low-bit optimizers #484
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/484
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 87ae147 with merge base cc871c5. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -48,10 +49,17 @@ def step(self, closure=None):

```python
if grad.is_sparse:
    raise RuntimeError("Sparse gradient is not supported")

# unwrap DTensor
```
cc: @wanchaol

For context: what is the advantage of not unwrapping the DTensor? The optimizer states ideally can be DTensors too, in which case they have their sharding info present on the tensor, making optimizer-state checkpointing simpler. However, low-bit optimizers may present challenges to this, which we should understand (from the DTensor side). It would be cool to understand the pain points and blockers (whether around functionality/UX in eager mode, compile support, etc.).
@awgu Thanks for bringing this up!
@gau-nernst wondering if there are any challenges to not unwrapping DTensors? As Andrew mentioned, the optimizer states should be DTensors so that they can be saved/loaded correctly when performing distributed training. We have implemented the sharding strategies for Adam in core, so if the low-bit optimizers are similar to the Adam implementation in core, we just need to append the operators to the corresponding list.
I have tried using DTensor directly (containing the low-bit tensor subclass as its local_tensor) but faced the following error:
[rank0]:E0708 20:40:01.356000 128135323341888 torch/testing/_internal/common_distributed.py:664] AssertionError: s10 (could be from ["L['grad']._base._local_tensor.size()[1]"]) not in {s2: ["L['exp_avg']._local_tensor.codes.size()[0]", "L['exp_avg']._local_tensor.codes.size()[0]", "L['exp_avg']._local_tensor.codes.size()[0]", "L['exp_avg']._local_tensor.codes.size()[0]", "L['exp_avg']._local_tensor.codes.size()[0]"], s3: ["L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]"], s4: ["L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]"], s13: ["L['grad']._local_tensor.size()[0]", "L['grad']._local_tensor.size()[0]"], s11: ["L['grad']._local_tensor.storage_offset()", "L['grad']._local_tensor.storage_offset()"], s16: ["L['exp_avg_sq']._local_tensor.codes.size()[0]", "L['exp_avg_sq']._local_tensor.codes.size()[0]", "L['exp_avg_sq']._local_tensor.codes.size()[0]", "L['exp_avg_sq']._local_tensor.codes.size()[0]", "L['exp_avg_sq']._local_tensor.codes.size()[0]"], s17: ["L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]"], s25: ["L['p']._local_tensor.size()[0]", "L['p']._local_tensor.size()[0]"]}. If this assert is failing, it could be due to the issue described in pytorch/pytorch#90665
I'm guessing it's due to me using .view(-1) on the DTensor. If I remove torch.compile, the test passes again. Also, I'm not very clear whether using .view(-1) is a proper way: how does DTensor handle .view(-1)? Does it keep the same local tensor, or does it gather the tensor and reshard? We probably don't want the latter.

Adding on about .view(-1): I added it because torch.compile will recompile and eventually hit the max cache size limit when facing different inputs, even though I tried to set dynamic=True. Perhaps we should try to fix this instead (i.e. so that we don't need .view(-1) anymore, since it seems like the source of many problems)?
I will make another branch using DTensor as optim state, so you can inspect and debug the error.
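For concreteness, here is a minimal sketch of the pattern under discussion: a per-parameter step (a simplified stand-in for the single_param_adam mentioned later in this thread) compiled with dynamic=True, with all tensors flattened via .view(-1) before the call. The Adam math is abbreviated and the quantize/dequantize steps are omitted, so treat this as an illustration rather than the actual torchao code.

```python
import torch

# Simplified stand-in for the per-parameter optim step; the real torchao code
# also dequantizes and re-quantizes the low-bit optimizer state.
@torch.compile(fullgraph=True, dynamic=True)
def single_param_adam(p, grad, exp_avg, exp_avg_sq, step, lr, beta1, beta2, eps):
    exp_avg.lerp_(grad, 1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = (exp_avg_sq / (1 - beta2 ** step)).sqrt().add_(eps)
    p.addcdiv_(exp_avg / (1 - beta1 ** step), denom, value=-lr)

def step_one_param(p, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # Flatten everything to 1D so params of different rank/shape hit the same
    # dynamic-shape kernel instead of triggering a recompile for each new rank.
    with torch.no_grad():
        single_param_adam(
            p.view(-1), p.grad.view(-1),
            state["exp_avg"].view(-1), state["exp_avg_sq"].view(-1),
            state["step"], lr, beta1, beta2, eps,
        )
```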
Another option is to init the optim state as DTensor, but when calling the Adam step, unwrap everything to local tensors. Kind of a workaround to get distributed saving/loading while keeping torch.compile happy.
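A rough sketch of that workaround, reusing the hypothetical single_param_adam from the sketch above: the state stays a DTensor for distributed checkpointing, but only local shards ever reach the compiled kernel.

```python
import torch
from torch.distributed._tensor import DTensor  # public path: torch.distributed.tensor in newer releases

def _local(t):
    # Unwrap a DTensor to its local shard; plain tensors pass through unchanged.
    return t.to_local() if isinstance(t, DTensor) else t

def step_one_param_dtensor(p, state, lr, beta1, beta2, eps):
    # Optimizer state is stored as DTensor (easy distributed save/load),
    # but the compiled Adam step only ever sees local tensors.
    with torch.no_grad():
        single_param_adam(
            _local(p).view(-1), _local(p.grad).view(-1),
            _local(state["exp_avg"]).view(-1), _local(state["exp_avg_sq"]).view(-1),
            state["step"], lr, beta1, beta2, eps,
        )
```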
Sounds good. I think on a Shard(0)-placement DTensor (which is what FSDP2 params look like), view(-1) shouldn't do any communication/reshard and should return a Shard(0) DTensor directly, so on the eager side it should behave exactly as we want; otherwise it's a bug in DTensor.

I'm not sure why we hit recompilation when turning on torch.compile. It would be great if you could create a branch for a repro, and we can take a look at why compile is not working as expected.
cc @bdhirsh
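A quick eager-mode check of that claim (a sketch I put together, not from the PR; it assumes a 1-D device mesh over all visible GPUs and a torchrun launch):

```python
import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard

# Launch with: torchrun --nproc-per-node 2 check_view.py
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

# FSDP2-style parameter: Shard(0) over the mesh.
x = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])
y = x.view(-1)

# Expected: still Shard(0), local shard merely flattened, no collective issued.
print(y.placements, y.to_local().shape)
```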
Whether the compile cache size limit is hit depends on the model. To be more specific, it depends on the variety of param shapes. There are 2 main sources of recompilation for my single_param_adam() function:
- Input shape changes (if dynamic=True, then it is when ndim, i.e. tensor rank, changes)
- Class object: we need to compile one version for normal-tensor optim state and another version for quantized-tensor-subclass optim state. This is because not all params will qualify for quantization (e.g. too small, size not divisible by block size, or the user might want to exclude certain params like the embedding layer).

Just the 2 sources above can result in many combinations. For a ViT model from timm that I tested here, whose params have a lot of different shapes, not doing dynamic=True and .view(-1) will hit the compile cache size limit easily. Maybe for LLMs the params don't have such a variety of shapes, but it's not very nice if we don't handle the case where there is a variety of param shapes.
A simple solution (without doing .view(-1) and/or dynamic=True) is to just increase the cache size limit. Not sure if it is encouraged?
> A simple solution .... is to just increase cache size limit. Not sure if it is encouraged?

It probably depends on your use case. One failure mode when increasing the cache limit is if e.g. your inputs have an unbounded amount of dynamism and you end up recompiling on every new input that comes in. But if you know that you have a fixed number of parameters with distinct shapes (e.g. ~50) and you know that you will only need to recompile once per param shape, bumping the cache size limit to what you need is probably ok.
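For reference, bumping that limit is a one-line config change (the exact default differs across PyTorch versions):

```python
import torch._dynamo

# Allow more recompiles per compiled function before Dynamo gives up and
# falls back to eager; pick a number that covers the distinct param shapes.
torch._dynamo.config.cache_size_limit = 64
```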
> There are 2 main sources of recompilation for my single_param_adam()

Mostly a curiosity question: is there anything stopping you from compiling the entire step function? Or maybe a subset of it that performs single_param_adam() on every parameter? That way you get one giant graph with all of the params (allowing Inductor to potentially fuse more things), instead of many small graphs.
Regardless, we want to fix the dynamic shape issues and this is probably a good reason to keep pushing on pytorch/pytorch#124619
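A sketch of what compiling the whole step could look like: one function that loops over every (param, grad, state) tuple under a single torch.compile call, so Dynamo unrolls the loop into one big static graph. param_groups_adam is a hypothetical name here (it shows up later in this PR), and the Adam math is simplified.

```python
import torch

@torch.compile(fullgraph=True)
def param_groups_adam(param_tuples, lr, beta1, beta2, eps):
    # The Python loop is unrolled at compile time: one graph containing the
    # updates for every parameter, which Inductor can fuse horizontally.
    for p, grad, step, exp_avg, exp_avg_sq in param_tuples:
        exp_avg.lerp_(grad, 1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        denom = (exp_avg_sq / (1 - beta2 ** step)).sqrt().add_(eps)
        p.addcdiv_(exp_avg / (1 - beta1 ** step), denom, value=-lr)

def step_all(params, states, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    with torch.no_grad():
        tuples = [
            (p, p.grad, s["step"], s["exp_avg"], s["exp_avg_sq"])
            for p, s in zip(params, states) if p.grad is not None
        ]
        param_groups_adam(tuples, lr, beta1, beta2, eps)
```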
> It probably depends on your use case. One failure mode when increasing the cache limit is if e.g. your inputs have an unbounded amount of dynamism and you end up recompiling on every new input that comes in. But if you know that you have a fixed number of parameters with distinct shapes (e.g. ~50) and you know that you will only need to recompile once per param shape, bumping the cache size limit to what you need is probably ok

That's why I'm leaning towards the unwrapping-DTensor "hack", since dynamic=True + .view(-1) already work well (at least I think so) for the non-FSDP setup, so it's natural to "hack" it to support DTensor.
> is there anything stopping you from compiling the entire step function?

I thought of it but haven't tried it myself. I can give it a try.

So in the end, how should we decide which approach to use to close this PR? Maybe I will try the horizontal fusion that you mentioned first, to see if it can solve all of the problems.
Actually, re-reading @awgu's comment that was not answered:

> What is the advantage of not unwrapping the DTensor?

In other words, what are the problems that make unwrapping DTensor discouraged? The quantized optim state can be wrapped in DTensor and then unwrapped before the adam step too.
> In other words, what are the problems that make unwrapping DTensor discouraged? The quantized optim state can be wrapped in DTensor and then unwrapped before the adam step too.

Ah, you are saying you can keep the optimizer states as DTensors but just unwrap them for the actual step()? If so, I think that works. We would only not prefer it because it requires the optimizer code to be aware of DTensor. (Ideally, like in the native torch optimizers, we would not need to change the code to make it work with DTensor, but this is just the ideal state.)
test/prototype/test_low_bit_optim.py (outdated)

```python
if isinstance(m, TransformerBlock):
    fully_shard(m, **fsdp_kwargs)
fully_shard(fsdp_model, **fsdp_kwargs)
fsdp_optim = low_bit_optim.Adam8bit(fsdp_model.parameters(), lr=1e-2)
```
Is .grad converted from high precision (fp32 or bf16) to 8-bit inside Adam8bit? Asking because FSDP2 seems to be performing gradient reduction in high precision based on fsdp_kwargs.
Grad is in the original precision. To summarize:
- optim state is dequantized to FP32 (maybe I need to change it to the param or grad dtype? haven't tested full BF16 training)
- adam step as usual
- optim state is quantized back to 8-bit
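Roughly this pattern, as a self-contained sketch (the toy absmax int8 quantization below is a stand-in for torchao's block-wise scheme with a quantization map, not the actual implementation):

```python
import torch

def quantize_8bit(x):
    # Toy absmax int8 quantization (stand-in for torchao's block-wise scheme).
    scale = x.abs().amax().clamp(min=1e-12) / 127
    return (x / scale).round().to(torch.int8), scale

def dequantize(codes, scale):
    return codes.float() * scale

@torch.no_grad()
def adam_step_quantized(p, grad, state, lr, beta1, beta2, eps, step):
    # 1. dequantize optimizer state to FP32 (grad stays in its original precision)
    exp_avg = dequantize(*state["exp_avg"])
    exp_avg_sq = dequantize(*state["exp_avg_sq"])

    # 2. usual Adam math in FP32
    g = grad.float()
    exp_avg.lerp_(g, 1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)
    denom = (exp_avg_sq / (1 - beta2 ** step)).sqrt().add_(eps)
    update = (exp_avg / (1 - beta1 ** step)) / denom
    p.add_(update.to(p.dtype), alpha=-lr)

    # 3. quantize the updated state back to 8-bit for storage
    state["exp_avg"] = quantize_8bit(exp_avg)
    state["exp_avg_sq"] = quantize_8bit(exp_avg_sq)
```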
Thanks for explaining. That means the FSDP2 logic during fwd/bwd remains the same (doing high-precision gradient reduction) no matter whether the optimizer is in high precision or low precision. The above discussion around DTensor is the core issue then.
@bdhirsh I followed your suggestion (statically compile the optim step for all params in a single graph) and it worked out great! Everything works now. I needed to separate out the optim state init step (because it seems like torch.compile() can't produce fullgraph=True when there is …).

However, there is a small problem: compile time increases. It shouldn't be a big problem for real training, but the FSDP test in CI might time out. The test times out on my local machine; hope it doesn't in CI. (Even if CI doesn't time out, it is also slightly not good in the sense that it increases CI time.)

UPDATE: indeed the test also times out in CI. Will skip the CPUOffload test.

UPDATE 2: memory usage for the 4-bit optimizer explodes :(
@awgu That looks awesome! Regarding memory fragmentation, do you know how we can improve it? I saw some people using … Also, have you tried using …?
The AdamW8bit error is indeed strange. I think something fails, then torch.compile tries to print the tensor … Anyway, it's probably a good idea to make … (see torchao/prototype/low_bit_optim/subclass_8bit.py, lines 37 to 39 at 5787e9e).
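For illustration only (a hypothetical, heavily simplified sketch, not the actual torchao subclass), a __repr__ that describes the container without dequantizing the packed data could look like this:

```python
import torch

class OptimState8bitSketch(torch.Tensor):
    # Hypothetical stand-in for the real subclass, which stores codes/scale/qmap.
    def __repr__(self):
        # Cheap: report metadata only, never materialize dequantized values.
        return (f"{self.__class__.__name__}(shape={tuple(self.shape)}, "
                f"dtype={self.dtype}, device={self.device})")
```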
Thinking about it again, the memory fragmentation might be partly because I create a normal DTensor and then swap the local tensor with its quantized version. A better way would be to create a quantized local tensor directly and wrap it with DTensor (either by calling …; see torchao/prototype/low_bit_optim/adamw.py, lines 41 to 47 at 5787e9e).
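One way to do the "wrap the quantized local tensor directly" option is via DTensor.from_local; a sketch (to_quantized_subclass is a hypothetical placeholder for the subclass constructor, not an existing torchao function):

```python
import torch
from torch.distributed._tensor import DTensor

def make_quantized_state(p: DTensor):
    # Build the quantized subclass from the *local* shard first...
    local = torch.zeros_like(p.to_local())
    q_local = to_quantized_subclass(local)  # hypothetical: plain tensor -> 8-bit subclass
    # ...then wrap it, instead of creating a full-precision DTensor and
    # swapping out its local tensor afterwards.
    return DTensor.from_local(q_local, p.device_mesh, p.placements, run_check=False)
```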
I would need to step through the memory snapshot more systematically, but I am not sure if that is the issue, because the non-quantized local tensor should be freed immediately after constructing the quantized local tensor, and that memory can be reused for the next parameter's optimizer states (or the next optimizer state for a given parameter). As long as you are not constructing local tensors for all parameters at once and then quantizing all of them, this should not be an issue in my understanding.
Yup, I won't say it's the only cause of memory fragmentation in this case. The Python garbage collector and the PyTorch memory allocator can work in mysterious ways, I suppose; not an expert in this. But I think it's a good idea to create the quantized local tensor directly anyway. I did what I did above mainly because I couldn't get …
I think calling …
Thanks, I will try!
@gau-nernst cannot run low bit optim with FSDP2
This PR turns out to be a larger change than expected.

Background

In my previous implementation of low-bit optimizers, I only compile the optim step for 1 param weight. Let's call this single_param_adam(). Since weights in a model can have different shapes and sizes, to avoid re-compilation:
- Use torch.compile(dynamic=True) to generate a dynamic kernel.
- Call .view(-1) before passing tensors to the compiled optim step, since tensors with different rank/ndim will be re-compiled too, even for a dynamic kernel.

Support FSDP and DTensor
Everything was fine for normal tensors. To support FSDP, the following needs to be done:
- Optim state is a DTensor, with local_tensor being our quantized subclass.

I soon discovered that torch.compile cannot generate a dynamic kernel for the sharded dimension (see #490 for a small reproduction snippet), and using .view(-1) on DTensor before torch.compile gives strange errors. Brian has a more detailed analysis in this comment.

Following Brian's advice, I switched to compiling a static kernel for the all-params optim step instead. Briefly, …

Since we don't expect param_groups to change during the course of training (though technically possible), static compile is all good. Works great with DTensor. The added bonus is that 8-bit optim is now much faster, presumably thanks to static specialization + reduced overhead from repeated single-param optim step calls.

Problem with 4-bit optim
Wait, but 4-bit optim memory usage explodes 🙁. It's likely due to bit-packing, as only 4-bit optim does that, and torch.compile uses a different fusion pattern when compiling param_groups_adam(). Note that the memory usage is even higher than 8-bit optim and nearly equal to FP32 optim, indicating that some dequantized optim states are materialized in global memory.

To resolve this, I thought of 2 solutions:
1. Keep the old approach (compiling single_param_adam()) for 4-bit optim. To handle DTensor, I can unwrap it before calling single_param_adam().
2. Instead of packing consecutive elements of the same state ((state[::2] << 4) | state[1::2]), pack optim state 1 (exp_avg) and optim state 2 (exp_avg_sq) together ((exp_avg << 4) | exp_avg_sq); see the sketch after this list. This would be easier for torch.compile to fuse ops correctly, but at the expense of (a) requiring a separate implementation for 4-bit optim, and (b) not using the subclass design anymore, as unpacking and re-packing must be explicit to create exp_avg and exp_avg_sq. It also means that it's not so easy to access exp_avg and exp_avg_sq directly anymore.
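To make the two packing layouts concrete, here is a toy uint8 sketch (it ignores the actual quantization; codes are assumed to already be 4-bit values in 0..15):

```python
import torch

# Current layout: pack consecutive elements of the SAME state into one byte.
def pack_within_state(codes):          # codes: uint8 with values 0..15, even length
    return (codes[::2] << 4) | codes[1::2]

def unpack_within_state(packed):
    out = torch.empty(packed.numel() * 2, dtype=torch.uint8, device=packed.device)
    out[::2] = packed >> 4
    out[1::2] = packed & 0xF
    return out

# Solution 2 layout: pack exp_avg and exp_avg_sq codes element-wise into one byte.
def pack_across_states(exp_avg_codes, exp_avg_sq_codes):
    return (exp_avg_codes << 4) | exp_avg_sq_codes

def unpack_across_states(packed):
    return packed >> 4, packed & 0xF
```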
There is no clear speed advantage between Solution 1 and Solution 2. However, when running the FSDP test, Solution 2 is getting an error (…) about functionality which is required by bit-packing and unpacking. This should be an easy fix for PyTorch core (I have opened an issue here: pytorch/pytorch#130671).

So overall, there is no advantage for Solution 2. Using Solution 1, once the excessive memory usage bug is fixed, we can simply remove the single_param_adam() call and replace it with param_groups_adam(), while keeping the tensor subclass design.

Summary
New benchmark results