Delete torchtune's inplace copy definition for NF4 #1294
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1294
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 9ebfb42 with merge base f9f75bb.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```diff
@@ -6,7 +6,6 @@
 import torch
 from torchao.dtypes.nf4tensor import implements as nf4_tensor_impl, to_nf4
-from torchtune.modules.low_precision._utils import _get_torchao_version
```
Strictly speaking this util is now no longer used anywhere in our library, but I am inclined to keep it in for possible (likely?) future usage
Nope - get rid of it until we need it again. It's easy enough to find in git history.
@weifengpy seems we may need to add an override for `aten.to.dtype_layout`.
Yes, the error msg looks relevant, but I do not fully understand why torchtune's implementation dispatches `aten.to.dtype_layout` while torchao's implementation does not need it. FSDP2 needs to support around 9 tensor ops, but `aten.to.dtype_layout` is totally new (first time seeing this).
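For context, registering such an override would follow the same pattern as torchtune's other NF4 dispatch handlers. Below is a minimal sketch, assuming torchao's `implements` decorator and its `handler(func, args, kwargs)` calling convention; the handler name and body are placeholders, not the change that actually landed in this PR:

```python
import torch
from torchao.dtypes.nf4tensor import implements as nf4_tensor_impl


@nf4_tensor_impl([torch.ops.aten.to.dtype_layout])
def to_dtype_layout(func, *args, **kwargs):
    # torchao's NF4 dispatch table calls handlers as handler(func, args, kwargs),
    # so args[0] is the original op's argument tuple and args[1] its kwargs.
    nf4_tensor = args[0][0]
    # Placeholder behavior: if dtype/layout/device are unchanged, returning the
    # tensor as-is may be enough for the state dict path; a real handler would
    # also need to cover actual conversions.
    return nf4_tensor
```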
I see the PR is accepted. Curious, how did we resolve the local_tensor=nan issue?
very nice! glad the version check code is also gone 💀
Thanks to @msaroufim for root causing this tricky bug. This addresses pytorch/ao#642 on our end as well as #1246.
Update: Instead of re-enabling our inplace copy, @gau-nernst has found the upstream fix in torchao (described here). With this fix, we can fully delete the inplace copy from torchtune. Until we're on the ao version with the fix, our QLoRA will continue to be slow, but after that things should be back to normal.
### Stuff below this line is outdated but kept as background
We were defining our own override of the aten copy_ op in torchtune prior to torchao version 0.2, when torchao added its own. After that we version-gated our override to ao < 0.2. But it turns out that the version we have is faster. Tbh I don't know why that is yet, but it should be safe to re-enable this to get us back to our previous perf on QLoRA state dict load.
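For readers unfamiliar with the mechanism, here is a minimal sketch of how such an override is registered through torchao's `implements` decorator (imported in the diff above as `nf4_tensor_impl`). The handler body is illustrative only; the exact set of NF4Tensor fields copied by torchtune's (now deleted) implementation may differ:

```python
import torch
from torchao.dtypes.nf4tensor import implements as nf4_tensor_impl, to_nf4


@nf4_tensor_impl([torch.ops.aten.copy_.default])
def inplace_copy(func, *args, **kwargs):
    # torchao's NF4 dispatch table invokes handlers as handler(func, args, kwargs),
    # so args[0] is the original op's argument tuple.
    dest_tensor = args[0][0]  # NF4Tensor we are copying into
    src_tensor = args[0][1]   # incoming higher-precision tensor (e.g. from a state dict)
    # Quantize the source to NF4, then move the quantized fields onto the
    # destination so the copy is "in place" from the caller's perspective.
    quantized = to_nf4(src_tensor.to(dest_tensor.device))
    for attr in ("block_size", "scaler_block_size", "quantized_data",
                 "quantized_scalers", "quantization_factor", "scaler_mean", "nf4"):
        # Assumption: these are the NF4Tensor fields that need to move over;
        # the authoritative field list lives in torchao's NF4Tensor definition.
        setattr(dest_tensor, attr, getattr(quantized, attr))
    return dest_tensor
```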
Tested by logging time from recipe startup to the beginning of the train loop.
On main: init time: 101.48360734200105
On this PR: init time: 15.445241551031359
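For reference, the measurement can be reproduced with something as simple as the sketch below; `setup_recipe` is a placeholder for the actual recipe setup, not a torchtune API:

```python
import time


def setup_recipe():
    # Placeholder for the actual recipe setup: config parsing, model init,
    # QLoRA state dict load, etc.
    pass


start = time.perf_counter()
setup_recipe()
init_time = time.perf_counter() - start
print(f"init time: {init_time}")
# the train loop would begin here
```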