
Conversation

@MatthewBonanni (Contributor) commented Oct 23, 2025

Purpose

Per a suggestion by @LucasWilkinson, fixes #27414 by choosing between the short and long RoPE factors at init time, based on max_model_len, rather than switching dynamically during generation.
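
A minimal sketch of the idea (hypothetical names, not vLLM's actual implementation): the factor set is selected once at construction time, so the rotary embedding never switches factor sets mid-sequence and cached keys/values stay consistent.

# Hypothetical sketch; class and attribute names are illustrative,
# not vLLM's actual API.
class LongRoPESketch:
    def __init__(
        self,
        original_max_position_embeddings: int,
        max_model_len: int,
        short_factor: list[float],
        long_factor: list[float],
    ):
        # Decide once at init: if the server may ever run past the
        # original context window, use the long factors for every
        # position, so the KV cache never mixes two factor sets.
        if max_model_len > original_max_position_embeddings:
            self.rescale_factors = long_factor
        else:
            self.rescale_factors = short_factor

The previous behavior picked the factor set based on the current sequence length, which is what corrupted earlier cache entries once a sequence crossed the boundary.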

Test Plan

Functionality

vllm serve microsoft/Phi-3.5-mini-instruct

guidellm benchmark \
  --target "http://localhost:8080" \
  --rate-type concurrent \
  --rate 16 \
  --data "prompt_tokens=3900,output_tokens=500" \
  --max-requests 100

Accuracy

vllm serve microsoft/Phi-3.5-mini-instruct --max-model-len <len>

lm_eval \
  --model local-chat-completions \
  --model_args base_url=http://localhost:8000/v1/chat/completions,model=microsoft/Phi-3.5-mini-instruct,num_concurrent=32,max_retries=3,tokenized_requests=False \
  --tasks gsm8k \
  --apply_chat_template

Test Result

Functionality

No longer produces random tokens once generation crosses the original_max_position_embeddings threshold mid-sequence

Accuracy

with max-model-len=4096 (new default):

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7559|±  |0.0118|
|     |       |strict-match    |     5|exact_match|↑  |0.5679|±  |0.0136|

with max-model-len=4097 (to trigger the long RoPE factors):

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6702|±  |0.0129|
|     |       |strict-match    |     5|exact_match|↑  |0.3071|±  |0.0127|


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
@gemini-code-assist (bot) left a comment

Code Review

This pull request addresses a bug in the LongRoPE implementation for Phi-3 models. Previously, the selection between short and long RoPE factors was dynamic, which could lead to KV cache invalidation during generation. The proposed fix makes this decision static at model initialization time, based on max_model_len, ensuring consistency. My review includes a suggestion to improve the implementation by removing a dependency on global configuration, which enhances the component's modularity and testability.

@LucasWilkinson (Collaborator) commented:

Can you please provide gsm8k results with max_model_len > original_max_position_embeddings and with max_model_len <= original_max_position_embeddings?

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
@MatthewBonanni (Contributor, Author) commented:
@LucasWilkinson done!

@LucasWilkinson added the ready label (ONLY add when PR is ready to merge / full CI is needed) Oct 23, 2025
@LucasWilkinson enabled auto-merge (squash) October 23, 2025 20:09
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
auto-merge was automatically disabled October 23, 2025 20:15

Head branch was pushed to by a user without write access

@robertgshaw2-redhat enabled auto-merge (squash) October 23, 2025 21:39
Comment on lines +2123 to +2130
# For LongRoPE, default to original_max_position_embeddings to avoid
# performance degradation for shorter sequences
if rope_scaling is not None and rope_scaling["rope_type"] == "longrope":
    max_model_len = int(
        getattr(
            hf_config, "original_max_position_embeddings", derived_max_model_len
        )
    )
Member commented:
Why does this not belong above, where derived_max_model_len is calculated? Could it come after the first rope_scaling check?

@MatthewBonanni (Contributor, Author) replied:

We want to override the default max_model_len to original_max_position_embeddings while still allowing the user to manually specify a max_model_len larger than that value. I had originally written the code to modify derived_max_model_len, but (as I read it) derived_max_model_len is really intended to be max_position_embeddings, because of the later check in line 2143:

https://github.com/vllm-project/vllm/pull/27431/files/684e475735f9594a784dd62aa14e0845884872db#diff-998c640befaf137b9af825f29f4e6e47d273caab1fd04093c97df24b18f5c417L2134

If max_model_len is manually set to a value exceeding derived_max_model_len, that check throws an error. Hopefully I'm understanding your question correctly?
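
A hedged sketch of the resolution logic described above (function and parameter names are illustrative, not vLLM's verbatim code): the LongRoPE default is capped at original_max_position_embeddings, a manual override is honored, and anything beyond derived_max_model_len raises.

# Illustrative only; approximates the config logic discussed above.
def resolve_max_model_len(
    user_max_model_len: int | None,
    derived_max_model_len: int,  # effectively max_position_embeddings
    original_max_position_embeddings: int,
    is_longrope: bool,
) -> int:
    if user_max_model_len is None:
        if is_longrope:
            # New default: stay within the original window so the
            # short factors are used and short-sequence accuracy is
            # preserved.
            return min(
                original_max_position_embeddings, derived_max_model_len
            )
        return derived_max_model_len
    if user_max_model_len > derived_max_model_len:
        # The later check mentioned above: a manual value beyond the
        # model's position limit is rejected.
        raise ValueError(
            f"max_model_len ({user_max_model_len}) is greater than the "
            f"derived max ({derived_max_model_len})."
        )
    return user_max_model_len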

Member replied:

Thank you for explaining, I think I understand now!

auto-merge was automatically disabled October 27, 2025 14:46

Head branch was pushed to by a user without write access

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
@DarkLight1337 enabled auto-merge (squash) October 28, 2025 04:35
@DarkLight1337 merged commit 44b5ce9 into vllm-project:main Oct 28, 2025
50 checks passed
@MatthewBonanni deleted the fix_longrope branch October 28, 2025 14:02
bhagyashrigai pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Oct 29, 2025
ilmarkov pushed a commit to neuralmagic/vllm that referenced this pull request Nov 7, 2025
ZhengHongming888 pushed a commit to ZhengHongming888/vllm that referenced this pull request Nov 8, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025

Labels

ready (ONLY add when PR is ready to merge / full CI is needed)


Development

Successfully merging this pull request may close these issues.

[Bug]: LongRoPE transition causes random outputs mid-sequence in Phi-3.5

5 participants