[Inductor][CPP] Fix int8 cvt half #136353
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136353
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures)
As of commit 3a43c94 (merge base: failed to retrieve merge base, please contact dev infra):
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot merge
Merge failed. Reason: Approvers from one of the following sets are needed:
Hi @jerryzh168, could you kindly take a look at this PR? It will fix the failure of pytorch/ao#884 with FP16 input.
stamping
please make sure the tests pass before landing
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fix the correctness issue of pytorch/ao#884. The current implementation for converting between `Half/BFloat16` and `int8/uint8` incorrectly assumes that 1/4 of the int8/uint8 vector lanes map to 1/2 of the Half/BFloat16 vector lanes. This assumption leads to accuracy issues after the full-bit-width vectorization of the Half data type was introduced. When converting between int8 weights and the Half data type, the generated code is as follows:

```
#include "/tmp/torchinductor_leslie/xw/cxww3s7wxrujoyxna7mlcjktid2uu6nntixqwm542xfkd756gl3x.h"
extern "C" void kernel(const int8_t* in_ptr0, half* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2048L); x0+=static_cast<int64_t>(32L))
        {
            auto tmp0 = at::vec::Vectorized<int8_t>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
            auto tmp1 = at::vec::convert<half>(tmp0);
            tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(32));
        }
    }
}
```

In this PR, we address the issue by changing the implementation to convert 1/2 of the int8/uint8 vector lanes into a full vector lane of Half/BFloat16.

**Test Plan**
* AO: `python test/integration/test_integration.py -k test_int8_weight_only_quant_subclass_api`
* `python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_convert_int8_to_half_vec`
* Due to the CPP backend legalization pass, we are unable to create a unit test that simulates the conversion from `Half` to `int8`. Instead, we rely on the C++ test cases:
  * `./build/bin/vec_test_all_types_AVX512 --gtest_filter="VecConvertTestsReducedFloat/*.ConvertReduced"`
  * `./build/bin/vec_test_all_types_AVX2 --gtest_filter="VecConvertTestsReducedFloat/*.ConvertReduced"`

Pull Request resolved: pytorch#136353
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
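To make the lane mapping concrete, below is a minimal standalone sketch of the intended conversion on AVX-512, written with raw intrinsics rather than the actual `at::vec::convert` implementation this PR changes. The helper name `cvt_int8_to_half_32` and the widening path through fp32 are illustrative assumptions, not the landed code; the point is that half of a 64-lane int8 vector (32 values) fills exactly one full 32-lane Half vector.

```cpp
// Standalone illustration only; not the ATen at::vec implementation.
// Requires AVX-512F (plus AVX2 for the 128-bit extract).
#include <immintrin.h>
#include <cstdint>

// Convert 32 int8 values (1/2 of a 512-bit int8 vector lane) into 32 fp16
// values (one full 512-bit Half vector), returned as raw 16-bit lanes.
static inline __m512i cvt_int8_to_half_32(const int8_t* src) {
    // Load 32 int8 lanes.
    __m256i i8 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src));
    // Sign-extend each 16-element half to int32 and convert to fp32.
    __m512 lo_f32 = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm256_castsi256_si128(i8)));
    __m512 hi_f32 = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm256_extracti128_si256(i8, 1)));
    // Round each 16-lane fp32 vector down to 16 fp16 lanes.
    __m256i lo_f16 = _mm512_cvtps_ph(lo_f32, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    __m256i hi_f16 = _mm512_cvtps_ph(hi_f32, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    // Concatenate the halves: 32 int8 inputs -> 32 fp16 outputs, a full Half vector lane.
    return _mm512_inserti64x4(_mm512_castsi256_si512(lo_f16), hi_f16, 1);
}
```

With this mapping, a 32-element int8 load followed by `at::vec::convert<half>` is expected to populate all 32 Half lanes, matching the 32-element store in the generated kernel shown above.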
…136630)

**Summary**
Optimize the WOQ int8 AMX performance by changing the int8 -> bf16 conversion. Earlier, 16 int8 elements were being loaded at a time and converted to 16 BF16 elements. With this change, 32 int8 elements will be loaded at a time, and converted to a cache line of 32 BF16 elements more efficiently.

Performance before:
```
AUTOTUNE _weight_int8pack_mm(4096x4096, 4096x4096, 4096)
  cpp_packed_gemm_0 38.0439 ms 100.0%
  _weight_int8pack_mm 50.2524 ms 75.7%
SingleProcess AUTOTUNE benchmarking takes 1.1087 seconds and 1.9791 seconds precompiling
AUTOTUNE _weight_int8pack_mm(4096x4096, 11008x4096, 11008)
  cpp_packed_gemm_4 78.2038 ms 100.0%
  _weight_int8pack_mm 119.1962 ms 65.6%
SingleProcess AUTOTUNE benchmarking takes 1.9274 seconds and 1.9949 seconds precompiling
AUTOTUNE _weight_int8pack_mm(4096x11008, 4096x11008, 4096)
  cpp_packed_gemm_6 79.2368 ms 100.0%
  _weight_int8pack_mm 118.3212 ms 67.0%
SingleProcess AUTOTUNE benchmarking takes 1.9200 seconds and 2.0015 seconds precompiling
AUTOTUNE _weight_int8pack_mm(4096x4096, 32000x4096, 32000)
  cpp_packed_gemm_224 225.7201 ms 100.0%
  _weight_int8pack_mm 388.5588 ms 58.1%
```

Performance after this PR:
```
AUTOTUNE _weight_int8pack_mm(4096x4096, 4096x4096, 4096)
  cpp_packed_gemm_0 11.0086 ms 100.0%
  _weight_int8pack_mm 50.2918 ms 21.9%
SingleProcess AUTOTUNE benchmarking takes 1.0837 seconds and 2.0301 seconds precompiling
AUTOTUNE _weight_int8pack_mm(4096x4096, 11008x4096, 11008)
  cpp_packed_gemm_4 24.3528 ms 100.0%
  _weight_int8pack_mm 119.8492 ms 20.3%
SingleProcess AUTOTUNE benchmarking takes 1.8303 seconds and 1.8195 seconds precompiling
AUTOTUNE _weight_int8pack_mm(4096x11008, 4096x11008, 4096)
  cpp_packed_gemm_6 24.6148 ms 100.0%
  _weight_int8pack_mm 119.1908 ms 20.7%
SingleProcess AUTOTUNE benchmarking takes 1.8315 seconds and 1.8352 seconds precompiling
AUTOTUNE _weight_int8pack_mm(4096x4096, 32000x4096, 32000)
  cpp_packed_gemm_224 78.1369 ms 100.0%
  _weight_int8pack_mm 387.6289 ms 20.2%
SingleProcess AUTOTUNE benchmarking takes 4.5059 seconds and 1.8010 seconds precompiling
```

Pull Request resolved: #136630
Approved by: https://github.com/jgong5
ghstack dependencies: #136353
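For reference, the shape of the widened load described in the summary could look like the following standalone sketch (assuming AVX-512 with the AVX512_BF16 extension; the helper name `cvt_int8_to_bf16_32` is illustrative and is not the ATen/AMX template code): 32 int8 weights are loaded at once and emitted as one 64-byte cache line of 32 BF16 values.

```cpp
// Standalone illustration only; not the ATen AMX GEMM template code.
// Requires AVX-512F + AVX512_BF16 (for _mm512_cvtne2ps_pbh).
#include <immintrin.h>
#include <cstdint>

// Convert 32 int8 weights into 32 bf16 values (one 64-byte cache line).
static inline __m512bh cvt_int8_to_bf16_32(const int8_t* src) {
    // Load 32 int8 lanes in a single 256-bit load instead of two 16-element steps.
    __m256i i8 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src));
    // Widen both 16-element halves to int32, then convert to fp32.
    __m512 lo = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm256_castsi256_si128(i8)));
    __m512 hi = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm256_extracti128_si256(i8, 1)));
    // Pack both fp32 vectors into one 32-lane bf16 vector; the first argument
    // lands in the upper 256 bits, so element order is preserved.
    return _mm512_cvtne2ps_pbh(hi, lo);
}
```

Producing a full 64-byte line of BF16 per step lines up with the row width that AMX tiles consume, which is consistent with the speedups in the numbers above.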
Stack from ghstack (oldest at bottom):
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang