[Reland][CPU] Support int8 scaled embedding bag #3060
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3060
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 02a1bc4 with merge base 0d3217d.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from ac75d21 to e846df7.
}

} // namespace torchao
unnecessary change?
CC @mingfeima for review. Thanks.
# Next step: support more out_dtype
out_dtype = torch.float32
should this arg be exposed to the op as well in the future?
> should this arg be exposed to the op as well in the future?
Yes
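For reference, a hedged sketch of what exposing out_dtype might look like at the op-registration level (the schema string, argument list, and default are assumptions for illustration, not this PR's actual code):

```cpp
#include <torch/library.h>

// Hypothetical future schema: out_dtype becomes a caller-visible argument
// (defaulting to None, i.e. float32 inside the kernel) instead of being
// hard-coded. All argument names here are illustrative.
TORCH_LIBRARY_FRAGMENT(torchao, m) {
  m.def(
      "_scaled_embedding_bag(Tensor qweight, Tensor indices, Tensor offsets, "
      "Tensor w_scales, float o_scale, int mode, bool include_last_offset, "
      "ScalarType? out_dtype=None) -> Tensor");
}
```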
stamping, please make sure CI passes
Hi @mingfeima, could you please review this PR? Thanks.
#endif

- template <typename index_t>
+ template <typename index_t, typename data_t>
This is not mandatory, but you could use block_dim as a template argument and make this function simpler:
template <typename index_t, typename scalar_t, int block_dim>
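A minimal sketch of that suggestion, assuming a hypothetical accumulation helper (names and signature are illustrative, not the PR's actual kernel):

```cpp
#include <cstdint>

namespace torchao {

// Hypothetical illustration: with block_dim as a compile-time template
// argument, the inner loop bound is a constant, so the compiler can unroll
// and vectorize it, and the runtime block-size parameter disappears.
template <typename index_t, typename scalar_t, int block_dim>
void accumulate_row(
    const scalar_t* __restrict__ src,
    float* __restrict__ acc,
    index_t row) {
  const scalar_t* p = src + static_cast<int64_t>(row) * block_dim;
  for (int d = 0; d < block_dim; ++d) {
    acc[d] += static_cast<float>(p[d]);
  }
}

} // namespace torchao
```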
One more thing: it would be better to refactor the fp8 conversion SIMD code together with https://github.com/pytorch/ao/blob/main/torchao/csrc/cpu/aten_kernels/float8_linear.cpp, maybe putting both into a shared vec util header.
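A hedged sketch of what that shared utility might look like (the header name, namespace placement, and function are assumptions; the real per-ISA SIMD code lives in float8_linear.cpp and would replace the scalar fallback shown here):

```cpp
// vec_util.h (hypothetical shared header)
#pragma once

#include <cstdint>
#include <c10/util/Float8_e4m3fn.h>

namespace torchao {

// Convert a contiguous run of fp8 (e4m3fn) values to float. Only a scalar
// fallback is sketched; the vectorized conversion paths currently in
// float8_linear.cpp could be specialized behind this one entry point so the
// embedding-bag and linear kernels share a single implementation.
inline void fp8_to_float(
    const c10::Float8_e4m3fn* src, float* dst, int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    dst[i] = static_cast<float>(src[i]);
  }
}

} // namespace torchao
```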
int8 scaled_embedding_bag was reverted by #2974.
On #2972, the reported failure was caused by an unused qtype variable.
This PR re-enables int8 scaled_embedding_bag and fixes the unused-variable issue.
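As a minimal sketch of the failure class being fixed (illustrative, not the actual diff): under builds that promote -Wunused-variable to an error, a local that is computed but never read fails to compile, and the fix is to either use the value or drop the declaration.

```cpp
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>

// Before (hypothetical): fails with -Werror=unused-variable.
//   auto qtype = weight.scalar_type();  // declared but never read
//
// After: fold the value into a check so it is actually used...
void check_weight_dtype(const at::Tensor& weight) {
  TORCH_CHECK(
      weight.scalar_type() == at::kChar,
      "scaled_embedding_bag: expected int8 (Char) weight");
}
// ...or simply delete the declaration if nothing needs it.
```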