Support INT8 mixed-precision training from torchao? #578
cc: @weifengpy
@gau-nernst nice work! I took a look at the original torchao PR.
What do you mean by this?

Yeah, this ("row-wise became column-wise in the backward") is the main problem preventing me from implementing INT8 all-gather.

Some extra thoughts.
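For context, a small illustrative sketch (not from the PR) of why the scaling axis flips between forward and backward for a linear layer: the forward matmul contracts over `in_features`, so a per-row (per-output-channel) scale of the weight factors out cleanly, while the input-gradient matmul in backward contracts over `out_features`, which is exactly the axis a row-wise scale lives on.

```python
# Illustrative shapes only: the same weight W is contracted along different
# axes in forward vs. backward, which is why a row-wise scale (one per output
# row of W) no longer factors out cleanly in the backward.
import torch

x = torch.randn(4, 16)        # (batch, in_features)
W = torch.randn(8, 16)        # (out_features, in_features), as in nn.Linear
grad_out = torch.randn(4, 8)  # upstream gradient, (batch, out_features)

out = x @ W.t()               # forward: contracts over in_features -> row-wise scales of W factor out
grad_x = grad_out @ W         # backward: contracts over out_features -> needs column-wise scales of W
```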
Thanks for explaining everything in detail.

Yeah, INT8 all-gather might be the main justification to land this in torchtitan, since this repo is used to demonstrate composability with the distributed APIs. For row-wise, if the backward is too hard, are you comfortable with supporting INT8 all-gather with tensor-wise scaling? If the numerics do not become too bad with tensor-wise scaling, it's a great demonstration for INT8 all-gather.
Oh, right now I don't have any special logic for it. The state_dict will be a tensor subclass wrapper holding the original high-precision weight (NOT INT8). For INT8 mixed-precision training, I only inject custom matmul logic; the weights stay the same (same as FP8 training).
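To make that design concrete, here is a minimal sketch of the idea, not the actual torchao implementation: a tensor subclass that keeps the original high-precision values (so the state_dict is unchanged) and only swaps the `F.linear` call for a dynamically quantized INT8 matmul. The class name and the dequantize-then-matmul fallback are illustrative; a real implementation would call an INT8 GEMM kernel and route through a custom autograd function so the backward is quantized too. Assumes a recent PyTorch where `nn.Parameter` preserves tensor subclasses.

```python
# Minimal sketch (illustrative, not torchao code): the weight stays in high
# precision, only the matmul path is replaced with a dynamically quantized one.
import torch
import torch.nn.functional as F

class Int8MixedPrecisionWeight(torch.Tensor):
    """Plain subclass holding the original high-precision weight values."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1]
            bias = args[2] if len(args) > 2 else kwargs.get("bias")
            # Dynamic row-wise INT8 quantization of the weight, done per matmul.
            scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / 127.0
            w_i8 = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
            # Stand-in for a real INT8 kernel (and for the custom autograd.Function
            # that would handle the backward): dequantize and reuse a normal matmul.
            out = x @ (w_i8.to(x.dtype) * scale.to(x.dtype)).t()
            return out if bias is None else out + bias
        return super().__torch_function__(func, types, args, kwargs)

# Usage: swap the weight in place; checkpoints still contain bf16/fp32 values.
linear = torch.nn.Linear(16, 8)
linear.weight = torch.nn.Parameter(
    linear.weight.detach().as_subclass(Int8MixedPrecisionWeight)
)
y = linear(torch.randn(2, 16))
```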
Does that mean the INT8 weights from the all-gather remain in memory from forward until backward? If that's the case, we can just do what I suggested earlier?

More concretely: it differs in which version of the weight is used for the column-wise quantization in backward, whether to use the original weight, or to reuse the row-wise quantized weight from forward. Otherwise, to just demonstrate INT8 all-gather, I think it is easier (and saves effort) to do INT8 tensor-wise scaling 🤣.
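For readers following along, a purely illustrative helper (not a torchao API) showing the three granularities being discussed: tensor-wise uses a single scale, so the same quantized weight works for both the forward and backward matmuls, while row-wise and column-wise differ exactly in which axis the scale is attached to.

```python
# Illustrative symmetric INT8 quantization at the three granularities discussed
# above (not a torchao API): tensor-wise, row-wise, and column-wise.
import torch

def quantize_int8(w: torch.Tensor, dim=None):
    """dim=None -> tensor-wise, dim=1 -> row-wise, dim=0 -> column-wise scales."""
    amax = w.abs().amax() if dim is None else w.abs().amax(dim=dim, keepdim=True)
    scale = (amax / 127.0).clamp(min=1e-12)
    w_i8 = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return w_i8, scale

W = torch.randn(8, 16)                  # nn.Linear weight: (out_features, in_features)
w_row, s_row = quantize_int8(W, dim=1)  # row-wise: used by the forward x @ W.t()
w_col, s_col = quantize_int8(W, dim=0)  # column-wise: what the backward grad_out @ W wants
w_tw,  s_tw  = quantize_int8(W)         # tensor-wise: one scale serves both matmuls
```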
Agree, having tensor-wise scaling is already a good thing. I will bring this topic up for discussion with the team.
I think long term it's great to unify training APIs in torchao, to enable torchtitan to work with float8/int8/mx/etc. training in the same way. I'm working on this, no ETA yet. Short term, if someone wants to add int8 training to torchtitan as an experimental feature, SGTM personally, but I'll also defer to torchtitan folks on that.
We would love to have this feature after discussion. We can start with tensor-wise scaling; it's also consistent with our float8 offering.
Recently I worked on INT8 mixed-precision training in torchao. The relevant PR is here: pytorch/ao#748

Preliminary results show that with torchtitan, it improves speed by 20% on 8x A100 with no noticeable difference in the loss curve. See the PR for more details.

Would you be open to adding an experimental flag for this in torchtitan, similar to Float8 training? This could also help with profiling and improving INT8 training performance directly in torchtitan for future perf optimizations.
cc @msaroufim
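If it helps the discussion, here is a rough sketch of what such an experimental hook could look like. The function name and the `enabled` flag are hypothetical, not existing torchtitan options, and the torchao import paths are assumptions based on the prototype added in pytorch/ao#748, so treat this as a sketch of the integration point rather than a working patch.

```python
# Hypothetical integration sketch: mirrors how a Float8-style experimental flag
# could gate the model conversion. The torchao import paths below are assumed
# from the prototype in pytorch/ao#748; nothing here is an existing torchtitan option.
import torch.nn as nn

def maybe_apply_int8_mixed_precision(model: nn.Module, enabled: bool) -> nn.Module:
    """Swap eligible linear layers to INT8 mixed-precision training when the flag is set."""
    if not enabled:
        return model

    from torchao.quantization import quantize_                                   # assumed import path
    from torchao.prototype.quantized_training import int8_mixed_precision_training  # assumed import path

    # Weights stay in high precision, so checkpointing, optimizer states, and the
    # rest of the training loop are untouched (same spirit as the Float8 path).
    quantize_(model, int8_mixed_precision_training())
    return model
```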