[Perf] Apply torch.compile for per_block_cast_to_fp8
#24611
Conversation
Code Review
This pull request introduces a high-performance Triton kernel for per_block_cast_to_fp8, delivering impressive speedups as shown in the benchmarks. The implementation is solid and the inclusion of a benchmark script is appreciated. I've identified a critical bug in the benchmark script that causes incorrect reporting and a minor inefficiency in the Triton kernel itself. Addressing these will improve the quality and correctness of this contribution.
Is this comparing to a torch compiled kernel or just eager torch?

Compared to the eager torch. Details at
What is the result of making this function faster? It seems to only be used in tests and in vllm/model_executor/layers/quantization/utils/fp8_utils.py (lines 789 to 791 at commit cc99baf).

If there isn't a measurable difference, then I'd prefer to stick to torch and at most torch.compile it.
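For context, per_block_cast_to_fp8 quantizes a 2D tensor to fp8 with one scale per 128x128 block. Below is a minimal sketch of the eager PyTorch version, modeled on the DeepGEMM-style reference implementation; the actual code in fp8_utils.py may differ, and the UE8M0 path (which rounds scales to powers of two) is omitted:

```python
import torch

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def per_block_cast_to_fp8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2D tensor to fp8 e4m3 with one scale per 128x128 block."""
    assert x.dim() == 2
    m, n = x.shape
    # Zero-pad both dims up to multiples of the 128x128 block size.
    x_padded = torch.zeros(
        (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128),
        dtype=x.dtype, device=x.device,
    )
    x_padded[:m, :n] = x
    # View as (row_blocks, 128, col_blocks, 128) and reduce per block.
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    # Scale each block so its amax maps to the e4m3 max representable (448).
    x_fp8 = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
    return x_fp8.view_as(x_padded)[:m, :n].contiguous(), scales
```

Everything here is pointwise work plus a simple reduction, which is exactly the kind of chain torch.compile fuses well.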
@mgoin The main benefit of doing this is decreasing the model loading time.

This branch:
```
(EngineCore_DP5 pid=4102557) INFO 09-11 11:46:47 [gpu_model_runner.py:2286] Model loading took 95.9683 GiB and 28.218114 seconds
(EngineCore_DP2 pid=4102554) INFO 09-11 11:46:47 [gpu_model_runner.py:2286] Model loading took 95.9683 GiB and 29.514149 seconds
(EngineCore_DP6 pid=4102558) INFO 09-11 11:46:47 [gpu_model_runner.py:2286] Model loading took 95.9683 GiB and 28.131574 seconds
(EngineCore_DP4 pid=4102556) INFO 09-11 11:46:47 [gpu_model_runner.py:2286] Model loading took 95.9683 GiB and 28.061149 seconds
(EngineCore_DP7 pid=4102559) INFO 09-11 11:46:47 [gpu_model_runner.py:2286] Model loading took 95.9683 GiB and 28.636173 seconds
(EngineCore_DP1 pid=4102553) INFO 09-11 11:46:47 [gpu_model_runner.py:2286] Model loading took 95.9683 GiB and 28.506170 seconds
(EngineCore_DP3 pid=4102555) INFO 09-11 11:46:47 [gpu_model_runner.py:2286] Model loading took 95.9683 GiB and 29.540192 seconds
(EngineCore_DP0 pid=4102552) INFO 09-11 11:46:47 [gpu_model_runner.py:2286] Model loading took 95.9683 GiB and 28.072259 seconds
```

Main:
```
(EngineCore_DP1 pid=3243946) INFO 09-11 07:48:19 [gpu_model_runner.py:2251] Model loading took 95.9683 GiB and 27.890547 seconds
(EngineCore_DP2 pid=3243947) INFO 09-11 07:48:20 [gpu_model_runner.py:2251] Model loading took 95.9683 GiB and 28.551232 seconds
(EngineCore_DP4 pid=3243949) INFO 09-11 07:48:20 [gpu_model_runner.py:2251] Model loading took 95.9683 GiB and 29.112749 seconds
(EngineCore_DP7 pid=3243952) INFO 09-11 07:48:20 [gpu_model_runner.py:2251] Model loading took 95.9683 GiB and 29.540662 seconds
(EngineCore_DP3 pid=3243948) INFO 09-11 07:48:21 [gpu_model_runner.py:2251] Model loading took 95.9683 GiB and 29.962358 seconds
(EngineCore_DP0 pid=3243945) INFO 09-11 07:48:21 [gpu_model_runner.py:2251] Model loading took 95.9683 GiB and 29.903493 seconds
(EngineCore_DP5 pid=3243950) INFO 09-11 07:48:21 [gpu_model_runner.py:2251] Model loading took 95.9683 GiB and 30.417099 seconds
(EngineCore_DP6 pid=3243951) INFO 09-11 07:48:22 [gpu_model_runner.py:2251] Model loading took 95.9683 GiB and 31.128308 seconds
```
Looks like there is not a noticeable benefit to model loading time, so my feeling is still to stick to torch.
We have too much complexity in vLLM, so we should take simplicity and portable logic when we can get it. Thank you for the work, but we should focus on more impactful problems.
@mgoin OK, let's keep it simple. I removed all of the Triton code and simply added @torch.compile to per_block_cast_to_fp8, which is now 6x faster.
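In sketch form (reusing the eager function sketched earlier; the exact decorator arguments in the merged change may differ), the fix is just to wrap the function with torch.compile so the abs/amax/scale/cast chain fuses into a handful of kernels:

```python
# Compile the eager implementation; dynamic=True is an assumption here,
# chosen because weight shapes differ from layer to layer.
per_block_cast_to_fp8_compiled = torch.compile(per_block_cast_to_fp8, dynamic=True)

x = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
y_fp8, scales = per_block_cast_to_fp8_compiled(x)  # first call triggers compilation
```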
Purpose
Developed for #24607
Test Plan
Unit test & benchmark
python benchmark_per_block_cast_to_fp8.py --use-ue8m0

GPU: NVIDIA B200, dtype=bf16, UE8M0=True, block=(128x128)
```
Shape       | Baseline (ms) | Triton (ms) | Speedup | Y equal | Y maxdiff | S equal | S maxdiff
------------------------------------------------------------------------------------------------
128x128     | 0.090         | 0.028       | 3.20    | True    | 0.00e+00  | True    | 0.00e+00
1024x1024   | 0.087         | 0.028       | 3.12    | True    | 0.00e+00  | True    | 0.00e+00
2048x4096   | 0.157         | 0.028       | 5.57    | True    | 0.00e+00  | True    | 0.00e+00
4096x4096   | 0.252         | 0.037       | 6.81    | True    | 0.00e+00  | True    | 0.00e+00
4096x8192   | 0.464         | 0.068       | 6.84    | True    | 0.00e+00  | True    | 0.00e+00
8192x4096   | 0.464         | 0.068       | 6.80    | True    | 0.00e+00  | True    | 0.00e+00
3000x4097   | 0.255         | 0.053       | 4.85    | True    | 0.00e+00  | True    | 0.00e+00
7168x7168   | 0.679         | 0.098       | 6.91    | True    | 0.00e+00  | True    | 0.00e+00
16384x32768 | 6.294         | 0.952       | 6.61    | True    | 0.00e+00  | True    | 0.00e+00
```

python benchmark_per_block_cast_to_fp8.py

GPU: NVIDIA B200, dtype=bf16, UE8M0=False, block=(128x128)
```
Shape       | Baseline (ms) | Triton (ms) | Speedup | Y equal | Y maxdiff | S equal | S maxdiff
------------------------------------------------------------------------------------------------
128x128     | 0.064         | 0.028       | 2.26    | True    | 0.00e+00  | True    | 0.00e+00
1024x1024   | 0.064         | 0.027       | 2.35    | True    | 0.00e+00  | True    | 0.00e+00
2048x4096   | 0.134         | 0.029       | 4.59    | True    | 0.00e+00  | True    | 0.00e+00
4096x4096   | 0.229         | 0.037       | 6.13    | False   | 1.95e-03  | True    | 0.00e+00
4096x8192   | 0.439         | 0.068       | 6.51    | False   | 1.95e-03  | True    | 0.00e+00
8192x4096   | 0.439         | 0.068       | 6.44    | False   | 1.95e-03  | True    | 0.00e+00
3000x4097   | 0.231         | 0.032       | 7.33    | True    | 0.00e+00  | True    | 0.00e+00
7168x7168   | 0.654         | 0.098       | 6.64    | False   | 1.95e-03  | True    | 0.00e+00
16384x32768 | 6.269         | 0.969       | 6.47    | False   | 1.95e-03  | True    | 0.00e+00
```
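The benchmark script itself is part of the PR and not reproduced in this thread; below is a minimal sketch of the kind of CUDA-event timing loop such a comparison could use, reusing the eager and compiled variants sketched above (shapes and iteration counts are illustrative, not the actual benchmark_per_block_cast_to_fp8.py):

```python
import torch

def bench_ms(fn, x: torch.Tensor, iters: int = 100) -> float:
    """Mean milliseconds per call, timed with CUDA events."""
    for _ in range(10):              # warmup; also triggers compilation
        fn(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

for m, n in [(128, 128), (4096, 4096), (16384, 32768)]:
    x = torch.randn(m, n, dtype=torch.bfloat16, device="cuda")
    base = bench_ms(per_block_cast_to_fp8, x)           # eager baseline
    fast = bench_ms(per_block_cast_to_fp8_compiled, x)  # compiled variant
    print(f"{m}x{n}: {base:.3f} ms -> {fast:.3f} ms ({base / fast:.2f}x)")
```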
Accuracy

```
VLLM_USE_DEEP_GEMM=1 lm_eval --model vllm --model_args "pretrained=Qwen/Qwen3-30B-A3B-FP8,max_model_len=32768,enforce_eager=True" --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
```