Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Fix sharing of stateful logits processors #5329

Open
wants to merge 10 commits into
base: main
Choose a base branch
from

Conversation

maxdebayser
Copy link
Contributor

This is related to PR 4109

In vLLM users can submit more than one sequence at the same time with the completion request in the legacy OpenAI API. Since the sampling params are validated only once a single SamplingsParam object is shared between all sequences even when they belong to different sequence groups. The SamplingsParam is mostly a data object, but it has a list of logits processors that are executable. If the logits processors are stateless there is no problem. But there are implementations of logits processors that have internal state that depends on the sequence seen so far.

For example, the CFGLogitsProcessor would crash with the request below:

curl http://localhost:8000/v1/completions   -H "Content-Type: application/json"   -d '{
    "model": "meta-llama/Llama-2-7b-hf",
    "prompt": ["An example of a json document: ", "Another example of a json document: "],
    "max_tokens": 100,
    "temperature": 0,
    "guided_decoding_backend": "outlines",
    "response_format": {"type":"json_object"},
  }' | jq

resulting in the following error:

  File "/home/vllm/.local/lib/python3.11/site-packages/vllm/model_executor/layers/logits_processor.py", line 57, in forward
    logits = _apply_logits_processors(logits, sampling_metadata)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vllm/.local/lib/python3.11/site-packages/vllm/model_executor/layers/logits_processor.py", line 106, in _apply_logits_processors
    logits_row = logits_processor(token_ids, logits_row)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vllm/.local/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/outlines_logits_processors.py", line 50, in __call__
    self.fsm_state[seq_id] = self.fsm.next_state(
                             ^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm/lib/python3.11/site-packages/outlines/fsm/fsm.py", line 364, in next_state
    return self.regex_fsm_last.next_state(state, token_id)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm/lib/python3.11/site-packages/outlines/fsm/fsm.py", line 178, in next_state
    last_token_to_end_state = self.states_to_token_maps[state]
                              ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^
KeyError: 6

PR 4109 solves the crashing behavior, but the output is still affected by the sharing problem.

The additional solution proposed here is based on adding factories for stateful logits processors. The basic idea is:

  1. We add processors and factories to the same list so that they are in the correct order
  2. We add a logits_processors list to the SequenceGroupState object
  3. When the SequenceGroup is created, we iterate over the sampling_params.logits_processors
    and copy the logits_processors and call the factories to populate SequenceData.logits_processors
  4. The LogitsProcessor(nn.Module) will iterate over the SequenceData.logits_processors instead of
    the sampling_params.logits_processors

Here are some diagrams to illustrate the current code structure to better visualize the proposed changes:

vllm_sampling
vllm_seq_classes

There are two scenarios that need further investigation:

  • Sequences that are preempted with the recompute policy: possibly the state needs to be reset.
  • Forking of sequences during beam search: would a deep copy of the processor be feasible?

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

This allows vllm to support stateful LPs that must be
unique for each sequence.

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
@njhill
Copy link
Member

njhill commented Jun 7, 2024

Thanks @maxdebayser! You need to run format.sh to fix the linting, and it looks like there's some tests that need updating.

@maxdebayser
Copy link
Contributor Author

maxdebayser commented Jun 7, 2024 via email

@njhill
Copy link
Member

njhill commented Jun 7, 2024

@br3no I have a couple of questions:

  1. Is it possible / relatively cheap to clone a Guide that has been built based on particular parameters before it is used, so that each copy may be used independently by concurrent sequences?
  2. Irrespective of the answer to (1), can one of these Guide instances be reset/reused for a separate sequence (sharing the same params) once it has finished being used for a first one?

Unless (1) is possible and very cheap, the best option would be to extend this PR to support pooling and make use of (2), assuming that (2) is possible. This library looks like exactly what we'd need for that: https://pypi.org/project/pondpond/, in conjunction with finding the right place to hook the sequence completion event (possibly Sequence destructor).

@br3no
Copy link
Contributor

br3no commented Jun 9, 2024

@njhill

Is it possible / relatively cheap to clone a Guide that has been built based on particular parameters before it is used, so that each copy may be used independently by concurrent sequences?

I haven't benchmarked anything, but the code suggests that copying the CFGGuide is currently as expensive as creating it from scratch. This is in line with my impression while working on it.
Which doesn't mean it wouldn't be possible to improve this in Outlines in principle.

There is no support for resetting the CFGGuide in Outlines currently. So (2) is most probably also a negative.

I believe to improve this situation, we need to work with the Outlines community.

NOTE This problem only applies to the CFGGuide. The other Guides are stateless.

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
@njhill
Copy link
Member

njhill commented Jun 10, 2024

Thanks @br3no, so CFGGuide is used for json mode but these currently can't be reused. Yes I think this will be important to fix (in outlines it sounds like).

@njhill
Copy link
Member

njhill commented Jun 10, 2024

Ah I just remembered there is another PR #5006 related to this that I had been meaning to look at, hopefully that can help solve it!

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
@jon-chuang
Copy link
Contributor

jon-chuang commented Aug 8, 2024

Forking of sequences during beam search: would a deep copy of the processor be feasible?

Sounds like there should be a replay of the FSM, at the very least; there should be a parameter passed when initializing new logit_processor, so that, since FSM is essentially autoregressive, it can recover the state from full list of tokens.

Just some more plumbing required to ensure that when you call method clone_with_history or something, it does the replay instead of just pulling the last token_id

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
@maxdebayser
Copy link
Contributor Author

I've synced the branch with main and all the unit tests that are touched by this PR work, but guided decoding in main currently seems to be broken: #7557

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants