
Conversation

@ProExpertProg (Collaborator) commented Dec 4, 2024

This PR adds support for RMSNorm + (fp8) quant fusion. It also refactors the fusion pass to make it easier to add patterns, including support for multiple values of epsilon.
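As a rough illustration of what the fusion buys (a NumPy sketch, not vLLM's actual ops or kernel names): the unfused graph computes RMSNorm, materializes the intermediate, then quantizes it to fp8 in a second pass; the fused op computes the same result in a single pass over the data.

```python
import numpy as np

def rms_norm(x, weight, eps):
    # RMSNorm: scale each row by the reciprocal of its root-mean-square
    variance = np.mean(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * weight

def static_fp8_quant(x, scale, fp8_max=448.0):
    # Static-scale quantization, clamped to the e4m3 representable range
    return np.clip(x / scale, -fp8_max, fp8_max)

def fused_rms_norm_fp8_quant(x, weight, eps, scale, fp8_max=448.0):
    # What a fused kernel computes: same math, no intermediate tensor round-trip
    inv_rms = 1.0 / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return np.clip(x * inv_rms * weight / scale, -fp8_max, fp8_max)

x = np.random.rand(4, 8).astype(np.float32)
w = np.ones(8, dtype=np.float32)
unfused = static_fp8_quant(rms_norm(x, w, 1e-6), scale=0.5)
fused = fused_rms_norm_fp8_quant(x, w, 1e-6, scale=0.5)
assert np.allclose(unfused, fused)
```

The fusion pass pattern-matches the unfused pair in the compiled graph and rewrites it to the fused custom op; the win is saved memory traffic, not different math.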

@github-actions bot commented Dec 4, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run the remaining CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Dec 4, 2024
@mgoin (Member) commented Dec 5, 2024

@ProExpertProg It looks like there is a real failure in the TPU test: https://buildkite.com/vllm/fastcheck/builds/9325#01939413-9e42-4fb8-898e-dfd2a3ec7828/6-306

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 5, 2024
@SageMoore (Contributor) left a comment

Nice work. I mostly just have nits.

ProExpertProg and others added 18 commits December 11, 2024 21:43
Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: luka <luka@neuralmagic.com>
…ops to constants

Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: luka <luka@neuralmagic.com>
- extracted MultiOutputMatch to own file
- extracted utils to fx_utils
- added named tuples for op keys

Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: luka <luka@neuralmagic.com>
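One of the refactoring commits above introduces named tuples for op keys. A hypothetical sketch of the idea (field and op names here are made up for illustration): a structured key lets the pass register one pattern per configuration, e.g. per epsilon value, instead of keying on ad-hoc strings.

```python
from typing import NamedTuple

class RMSNormQuantKey(NamedTuple):
    """Illustrative key identifying one fusion pattern configuration."""
    quant_dtype: str     # e.g. "fp8"
    static_scale: bool   # static vs. dynamic quantization scale
    epsilon: float       # RMSNorm epsilon baked into the matched pattern

# One registered replacement per key; values are hypothetical fused-op names.
patterns = {
    RMSNormQuantKey("fp8", True, 1e-5): "rms_norm_static_fp8_quant",
    RMSNormQuantKey("fp8", True, 1e-6): "rms_norm_static_fp8_quant",
}

key = RMSNormQuantKey("fp8", True, 1e-6)
assert patterns[key] == "rms_norm_static_fp8_quant"
```

Because named tuples are hashable and self-describing, lookups stay cheap and pattern registrations read as data rather than positional magic.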
@tlrmchlsmth (Member) left a comment

I found some potential integer overflows. I don't think those will automatically promote to 64-bit, even though the output is int64_t.
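The reviewer's point, demonstrated with NumPy's fixed-width integers (C++ `int` arithmetic behaves the same way): the multiplication is performed at the operands' 32-bit width, so it overflows before the result is widened to 64 bits. Widening one operand first makes the multiply itself 64-bit.

```python
import numpy as np

num_tokens = np.int32(100_000)
hidden_size = np.int32(65_536)   # product exceeds INT32_MAX

with np.errstate(over="ignore"):
    # Bug pattern: 32-bit multiply wraps around (to a negative value here),
    # and only then is the wrong result widened to 64 bits.
    wrong = np.int64(num_tokens * hidden_size)

# Fix: cast an operand to 64-bit *before* multiplying.
right = np.int64(num_tokens) * np.int64(hidden_size)

assert right == 6_553_600_000
assert wrong != right
```

In the CUDA/C++ kernels this means casting an index or size operand to `int64_t` before the multiply, not just assigning the product to an `int64_t`.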

@ProExpertProg ProExpertProg force-pushed the luka/rms-norm-fusion-refactor branch from 17ff1b9 to 720d537 Compare December 12, 2024 20:16
@ProExpertProg (Collaborator, Author) commented

Future PR: use compilation counters to track the number of patterns replaced.

@ProExpertProg (Collaborator, Author) commented

(Use a context manager to check that the counter increased by a given number.)
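A hypothetical sketch of that testing idea (names are illustrative, not vLLM's API): a global compilation counter that the fusion pass bumps once per rewritten pattern, plus a context manager that asserts the counter increased by the expected amount.

```python
from contextlib import contextmanager

class CompilationCounters:
    """Illustrative counter container; the fusion pass would bump these."""
    fusion_patterns_replaced: int = 0

counters = CompilationCounters()

@contextmanager
def assert_replacements(expected: int):
    # Snapshot the counter, run the body, then check the delta.
    before = counters.fusion_patterns_replaced
    yield
    got = counters.fusion_patterns_replaced - before
    assert got == expected, f"expected {expected} replacements, got {got}"

# Usage: inside the block the compiled model would run and the pass would
# increment the counter per pattern; here we simulate two replacements.
with assert_replacements(2):
    counters.fusion_patterns_replaced += 2
```

This makes fusion tests assert on "how many patterns fired" directly rather than inspecting the rewritten graph.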

@youkaichao (Member) left a comment

LGTM, excited to see the perf numbers!

@youkaichao (Member) commented

Please also work with @tlrmchlsmth to address his comments.

Signed-off-by: luka <luka@neuralmagic.com>
@ProExpertProg ProExpertProg force-pushed the luka/rms-norm-fusion-refactor branch from 21abcff to a70d496 Compare December 12, 2024 22:19
- add kFp8Type constant for cuda/hip agnostic torch type checking
- check contiguous
- overflow
- reduce number of tests

Signed-off-by: luka <luka@neuralmagic.com>
@tlrmchlsmth (Member) left a comment

Looks good to me, thank you for the great work @ProExpertProg!

@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) December 12, 2024 23:38
@ProExpertProg (Collaborator, Author) commented

Seeing a 1-2% improvement in TPOT and 2-5% in TTFT.

Fused:


===== RUNNING neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 FOR 120 PROMPTS WITH 1 QPS =====

Namespace(backend='vllm', base_url=None, host='localhost', port=8000, endpoint='/v1/completions', dataset=None, dataset_name='sonnet', dataset_path='sonnet.txt', max_concurrency=None, model='neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8', tokenizer=None, best_of=1, use_beam_search=False, num_prompts=120, logprobs=None, request_rate=1.0, burstiness=1.0, seed=0, trust_remote_code=False, disable_tqdm=False, profile=False, save_result=False, metadata=None, result_dir=None, result_filename=None, ignore_eos=False, percentile_metrics='ttft,tpot,itl', metric_percentiles='99', goodput=None, sonnet_input_len=2048, sonnet_output_len=128, sonnet_prefix_len=200, sharegpt_output_len=None, random_input_len=1024, random_output_len=128, random_range_ratio=1.0, random_prefix_len=0, hf_subset=None, hf_split=None, hf_output_len=None)
Starting initial single prompt test run...
Initial test run completed. Starting main benchmark run...
Traffic request rate: 1.0
Burstiness factor: 1.0 (Poisson process)
Maximum request concurrency: None
============ Serving Benchmark Result ============
Successful requests:                     120       
Benchmark duration (s):                  116.77    
Total input tokens:                      224758    
Total generated tokens:                  15360     
Request throughput (req/s):              1.03      
Output token throughput (tok/s):         131.55    
Total Token throughput (tok/s):          2056.42   
---------------Time to First Token----------------
Mean TTFT (ms):                          49.27     
Median TTFT (ms):                        47.41     
P99 TTFT (ms):                           76.35     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          5.38      
Median TPOT (ms):                        5.27      
P99 TPOT (ms):                           6.20      
---------------Inter-token Latency----------------
Mean ITL (ms):                           5.38      
Median ITL (ms):                         5.18      
P99 ITL (ms):                            12.97     
==================================================

===== RUNNING neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 FOR 1200 PROMPTS WITH 10 QPS =====

Namespace(backend='vllm', base_url=None, host='localhost', port=8000, endpoint='/v1/completions', dataset=None, dataset_name='sonnet', dataset_path='sonnet.txt', max_concurrency=None, model='neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8', tokenizer=None, best_of=1, use_beam_search=False, num_prompts=1200, logprobs=None, request_rate=10.0, burstiness=1.0, seed=0, trust_remote_code=False, disable_tqdm=False, profile=False, save_result=False, metadata=None, result_dir=None, result_filename=None, ignore_eos=False, percentile_metrics='ttft,tpot,itl', metric_percentiles='99', goodput=None, sonnet_input_len=2048, sonnet_output_len=128, sonnet_prefix_len=200, sharegpt_output_len=None, random_input_len=1024, random_output_len=128, random_range_ratio=1.0, random_prefix_len=0, hf_subset=None, hf_split=None, hf_output_len=None)
Starting initial single prompt test run...
Initial test run completed. Starting main benchmark run...
Traffic request rate: 10.0
Burstiness factor: 1.0 (Poisson process)
Maximum request concurrency: None
============ Serving Benchmark Result ============
Successful requests:                     1200      
Benchmark duration (s):                  122.87    
Total input tokens:                      2249049   
Total generated tokens:                  153592    
Request throughput (req/s):              9.77      
Output token throughput (tok/s):         1250.00   
Total Token throughput (tok/s):          19553.81  
---------------Time to First Token----------------
Mean TTFT (ms):                          73.50     
Median TTFT (ms):                        61.19     
P99 TTFT (ms):                           218.09    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          10.78     
Median TPOT (ms):                        10.55     
P99 TPOT (ms):                           17.48     
---------------Inter-token Latency----------------
Mean ITL (ms):                           10.78     
Median ITL (ms):                         7.20      
P99 ITL (ms):                            61.43     
==================================================

===== RUNNING neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 FOR 2400 PROMPTS WITH 20 QPS =====

Namespace(backend='vllm', base_url=None, host='localhost', port=8000, endpoint='/v1/completions', dataset=None, dataset_name='sonnet', dataset_path='sonnet.txt', max_concurrency=None, model='neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8', tokenizer=None, best_of=1, use_beam_search=False, num_prompts=2400, logprobs=None, request_rate=20.0, burstiness=1.0, seed=0, trust_remote_code=False, disable_tqdm=False, profile=False, save_result=False, metadata=None, result_dir=None, result_filename=None, ignore_eos=False, percentile_metrics='ttft,tpot,itl', metric_percentiles='99', goodput=None, sonnet_input_len=2048, sonnet_output_len=128, sonnet_prefix_len=200, sharegpt_output_len=None, random_input_len=1024, random_output_len=128, random_range_ratio=1.0, random_prefix_len=0, hf_subset=None, hf_split=None, hf_output_len=None)
Starting initial single prompt test run...
Initial test run completed. Starting main benchmark run...
Traffic request rate: 20.0
Burstiness factor: 1.0 (Poisson process)
Maximum request concurrency: None
============ Serving Benchmark Result ============
Successful requests:                     2400      
Benchmark duration (s):                  131.29    
Total input tokens:                      4498061   
Total generated tokens:                  307200    
Request throughput (req/s):              18.28     
Output token throughput (tok/s):         2339.90   
Total Token throughput (tok/s):          36601.09  
---------------Time to First Token----------------
Mean TTFT (ms):                          1462.86   
Median TTFT (ms):                        1547.10   
P99 TTFT (ms):                           3293.15   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          92.29     
Median TPOT (ms):                        96.82     
P99 TPOT (ms):                           104.09    
---------------Inter-token Latency----------------
Mean ITL (ms):                           92.23     
Median ITL (ms):                         38.79     
P99 ITL (ms):                            444.05    
==================================================

Unfused:


===== RUNNING neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 FOR 120 PROMPTS WITH 1 QPS =====

Namespace(backend='vllm', base_url=None, host='localhost', port=8000, endpoint='/v1/completions', dataset=None, dataset_name='sonnet', dataset_path='sonnet.txt', max_concurrency=None, model='neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8', tokenizer=None, best_of=1, use_beam_search=False, num_prompts=120, logprobs=None, request_rate=1.0, burstiness=1.0, seed=0, trust_remote_code=False, disable_tqdm=False, profile=False, save_result=False, metadata=None, result_dir=None, result_filename=None, ignore_eos=False, percentile_metrics='ttft,tpot,itl', metric_percentiles='99', goodput=None, sonnet_input_len=2048, sonnet_output_len=128, sonnet_prefix_len=200, sharegpt_output_len=None, random_input_len=1024, random_output_len=128, random_range_ratio=1.0, random_prefix_len=0, hf_subset=None, hf_split=None, hf_output_len=None)
Starting initial single prompt test run...
Initial test run completed. Starting main benchmark run...
Traffic request rate: 1.0
Burstiness factor: 1.0 (Poisson process)
Maximum request concurrency: None
============ Serving Benchmark Result ============
Successful requests:                     120       
Benchmark duration (s):                  116.77    
Total input tokens:                      224758    
Total generated tokens:                  15360     
Request throughput (req/s):              1.03      
Output token throughput (tok/s):         131.54    
Total Token throughput (tok/s):          2056.34   
---------------Time to First Token----------------
Mean TTFT (ms):                          49.56     
Median TTFT (ms):                        47.19     
P99 TTFT (ms):                           81.86     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          5.47      
Median TPOT (ms):                        5.35      
P99 TPOT (ms):                           6.31      
---------------Inter-token Latency----------------
Mean ITL (ms):                           5.47      
Median ITL (ms):                         5.28      
P99 ITL (ms):                            13.13     
==================================================

===== RUNNING neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 FOR 1200 PROMPTS WITH 10 QPS =====

Namespace(backend='vllm', base_url=None, host='localhost', port=8000, endpoint='/v1/completions', dataset=None, dataset_name='sonnet', dataset_path='sonnet.txt', max_concurrency=None, model='neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8', tokenizer=None, best_of=1, use_beam_search=False, num_prompts=1200, logprobs=None, request_rate=10.0, burstiness=1.0, seed=0, trust_remote_code=False, disable_tqdm=False, profile=False, save_result=False, metadata=None, result_dir=None, result_filename=None, ignore_eos=False, percentile_metrics='ttft,tpot,itl', metric_percentiles='99', goodput=None, sonnet_input_len=2048, sonnet_output_len=128, sonnet_prefix_len=200, sharegpt_output_len=None, random_input_len=1024, random_output_len=128, random_range_ratio=1.0, random_prefix_len=0, hf_subset=None, hf_split=None, hf_output_len=None)
Starting initial single prompt test run...
Initial test run completed. Starting main benchmark run...
Traffic request rate: 10.0
Burstiness factor: 1.0 (Poisson process)
Maximum request concurrency: None
============ Serving Benchmark Result ============
Successful requests:                     1200      
Benchmark duration (s):                  122.93    
Total input tokens:                      2249049   
Total generated tokens:                  153595    
Request throughput (req/s):              9.76      
Output token throughput (tok/s):         1249.48   
Total Token throughput (tok/s):          19545.23  
---------------Time to First Token----------------
Mean TTFT (ms):                          77.81     
Median TTFT (ms):                        61.64     
P99 TTFT (ms):                           273.52    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          11.30     
Median TPOT (ms):                        10.79     
P99 TPOT (ms):                           23.10     
---------------Inter-token Latency----------------
Mean ITL (ms):                           11.30     
Median ITL (ms):                         7.40      
P99 ITL (ms):                            64.54     
==================================================

===== RUNNING neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 FOR 2400 PROMPTS WITH 20 QPS =====

Namespace(backend='vllm', base_url=None, host='localhost', port=8000, endpoint='/v1/completions', dataset=None, dataset_name='sonnet', dataset_path='sonnet.txt', max_concurrency=None, model='neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8', tokenizer=None, best_of=1, use_beam_search=False, num_prompts=2400, logprobs=None, request_rate=20.0, burstiness=1.0, seed=0, trust_remote_code=False, disable_tqdm=False, profile=False, save_result=False, metadata=None, result_dir=None, result_filename=None, ignore_eos=False, percentile_metrics='ttft,tpot,itl', metric_percentiles='99', goodput=None, sonnet_input_len=2048, sonnet_output_len=128, sonnet_prefix_len=200, sharegpt_output_len=None, random_input_len=1024, random_output_len=128, random_range_ratio=1.0, random_prefix_len=0, hf_subset=None, hf_split=None, hf_output_len=None)
Starting initial single prompt test run...
Initial test run completed. Starting main benchmark run...
Traffic request rate: 20.0
Burstiness factor: 1.0 (Poisson process)
Maximum request concurrency: None
============ Serving Benchmark Result ============
Successful requests:                     2400      
Benchmark duration (s):                  132.63    
Total input tokens:                      4498061   
Total generated tokens:                  307200    
Request throughput (req/s):              18.10     
Output token throughput (tok/s):         2316.27   
Total Token throughput (tok/s):          36231.34  
---------------Time to First Token----------------
Mean TTFT (ms):                          1921.30   
Median TTFT (ms):                        2041.44   
P99 TTFT (ms):                           4459.93   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          93.22     
Median TPOT (ms):                        98.07     
P99 TPOT (ms):                           104.94    
---------------Inter-token Latency----------------
Mean ITL (ms):                           93.13     
Median ITL (ms):                         39.18     
P99 ITL (ms):                            431.80    
==================================================
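For reference, the relative gains implied by the mean TTFT/TPOT rows above, computed directly from the reported numbers (values copied verbatim from the two runs):

```python
# (mean TTFT ms, mean TPOT ms) per QPS level, from the tables above
fused   = {1: (49.27, 5.38), 10: (73.50, 10.78), 20: (1462.86, 92.29)}
unfused = {1: (49.56, 5.47), 10: (77.81, 11.30), 20: (1921.30, 93.22)}

for qps in sorted(fused):
    (ttft_f, tpot_f), (ttft_u, tpot_u) = fused[qps], unfused[qps]
    ttft_gain = (ttft_u - ttft_f) / ttft_u * 100
    tpot_gain = (tpot_u - tpot_f) / tpot_u * 100
    print(f"{qps:2d} QPS: TTFT -{ttft_gain:.1f}%, TPOT -{tpot_gain:.1f}%")
    # The fused run is never slower on these two metrics in this benchmark.
    assert ttft_f <= ttft_u and tpot_f <= tpot_u
```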

@tlrmchlsmth tlrmchlsmth merged commit 30870b4 into vllm-project:main Dec 13, 2024
77 checks passed
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
Signed-off-by: luka <luka@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>