
Conversation

@njhill (Member) commented Mar 9, 2025

  • Handle requests not finishing at the same time
  • Coalesce outputs from different child requests
  • Abort n > 1 requests (the parent request ID needs to propagate to child requests; see the sketch below)

Thanks to @himanshujaju for reporting the first two of these bugs in Slack.

cc @markmc @afeldman-nm
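
For reference, here is a rough sketch of the parent/child request relationship these fixes revolve around. The class shape and the child-ID format below are illustrative assumptions based on this thread, not vLLM's exact parallel_sampling.py implementation.

# Illustrative only: an n > 1 "parent" request fans out into child requests,
# children may finish at different times, and aborting the parent must
# propagate to every still-running child. The child-ID format here is an
# assumption for illustration.
from dataclasses import dataclass, field


@dataclass
class ParentRequest:
    request_id: str
    n: int
    child_requests: set[str] = field(default_factory=set)

    def make_child_ids(self) -> list[str]:
        # One child request per sample, with the parent ID recoverable
        # from each child ID.
        ids = [f"{idx}_{self.request_id}" for idx in range(self.n)]
        self.child_requests.update(ids)
        return ids

    def finish_child(self, child_id: str) -> bool:
        # The parent only counts as finished once its last child finishes.
        self.child_requests.discard(child_id)
        return not self.child_requests


parent = ParentRequest(request_id="req-42", n=3)
children = parent.make_child_ids()  # ['0_req-42', '1_req-42', '2_req-42']

# Aborting the parent means aborting whichever children are still running.
ids_to_abort = list(parent.child_requests)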

Signed-off-by: Nick Hill <nhill@redhat.com>
@github-actions bot commented Mar 9, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small, essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

self.num_cached_tokens = num_cached_tokens

@classmethod
def new(
Member Author

This method is now unused

Comment on lines -172 to -174
self.prompt = next_output.prompt
self.prompt_token_ids = next_output.prompt_token_ids
self.prompt_logprobs = next_output.prompt_logprobs
Member Author

This was another preexisting bug

Comment on lines +155 to +156
completion.finish_reason = next_completion.finish_reason
completion.stop_reason = next_completion.stop_reason
Member Author

These were further preexisting omissions

Member

I don't love that we have a higher layer aggregating delta RequestOutputs using this method (not least because it seems I basically missed it completely!)

There's no particular reason why OutputProcessor.process_outputs() couldn't ensure that it has aggregated before pushing the RequestOutput to the per-request queue - it's not like there's any advantage to pushing to the queue early, since this method doesn't yield

Member

This would be tricky and require some care though - I wouldn't attempt my suggestion as part of fixing this bug 👍

Member Author

Actually, this aggregation is a non-deterministic performance optimization / safety net. If things are running fast enough, i.e. the queue stays empty, then no aggregation will happen; but we have seen circumstances where the queues back up, and then it's better to coalesce the messages rather than incur the overhead of sending them separately (and the additional asyncio task churn that entails).

We could still do it in the output processor, I guess, but that would mean doing something a bit more custom than an asyncio.Queue (e.g. just using an asyncio.Event).
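
For illustration, a rough sketch of this coalescing pattern over a plain asyncio.Queue, using a simplified stand-in Delta type rather than vLLM's actual RequestOutput; it only merges whatever has already queued up, so it is effectively a no-op when the consumer keeps up:

# Sketch only: drain whatever is already waiting in the per-request queue
# and merge the deltas into one message before handing it to the consumer.
import asyncio
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Delta:
    token_ids: list[int] = field(default_factory=list)
    finish_reason: Optional[str] = None


async def next_coalesced(queue: asyncio.Queue) -> Delta:
    merged = await queue.get()          # wait for at least one delta
    while True:
        try:
            nxt = queue.get_nowait()    # absorb anything already queued
        except asyncio.QueueEmpty:
            return merged
        merged.token_ids.extend(nxt.token_ids)
        # Later deltas win for terminal fields such as finish_reason.
        merged.finish_reason = nxt.finish_reason or merged.finish_reason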

Member Author

FWIW I am planning to change this as a follow-on.

Member Author

@markmc FYI here is that follow-on change: #15156

@markmc (Member) commented Mar 10, 2025

It seems the tests for this code path aren't enabled in CI, and when I had been running them locally I wasn't aware I had to set VLLM_USE_V1=1 ... no wonder they never failed for me!

Suggest we include the following in this PR:

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 2af76cb24..f84710b74 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -202,6 +202,8 @@ steps:
     # split the test to avoid interference
     - VLLM_USE_V1=1 pytest -v -s v1/core
     - VLLM_USE_V1=1 pytest -v -s v1/engine
+    # TODO: enable rest of v1/entrypoints
+    - VLLM_USE_V1=1 pytest -v -s v1/entrypoints/openai
     - VLLM_USE_V1=1 pytest -v -s v1/sample
     - VLLM_USE_V1=1 pytest -v -s v1/worker
     - VLLM_USE_V1=1 pytest -v -s v1/structured_output

Adding the TODO for now because attempting to enable all of v1/entrypoints gives:

RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

@markmc (Member) left a comment

I guess all my comments suggest further refactoring, but none of them should take priority over getting the bug fix merged


if not outputs:
return None
request_id = self.parent_req.request_id
finished = not self.parent_req.child_requests
Member

Everything under the 'parent_req is not None' case should be encapsulated in parallel_sampling.py IMO - e.g. the parent request ID and finished handling shouldn't be in this file.
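
One possible shape of that encapsulation, sketched with a hypothetical method name rather than vLLM's actual parallel_sampling.py API: the parent request object owns the "which request ID, and is it finished" decision, and the output processor just delegates.

# Hypothetical sketch: parallel_sampling.py owns the aggregation metadata.
class ParentRequestSketch:
    def __init__(self, request_id: str, child_requests: set[str]):
        self.request_id = request_id
        self.child_requests = child_requests

    def output_metadata(self) -> tuple[str, bool]:
        # The aggregated output is reported under the parent's ID and is
        # only finished once every child request has completed.
        return self.request_id, not self.child_requests


# The output processor side then reduces to:
#     request_id, finished = parent_req.output_metadata()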

self.request_states[request_id] = req_state
self.lora_states.add_request(req_state)
if parent_req:
self.parent_requests[parent_req.request_id] = parent_req
Member

Would this request_id -> ParentRequest mapping make more sense at the AsyncLLM level?

The only place we're using it is in aborts - so we could have this stuff at the AsyncLLM level:

                parent = self.parent_requests.pop(request_id, None)
                if parent and parent.child_requests:
                    self.abort_requests(parent.child_requests)
                    request_ids_to_abort.extend(parent.child_requests)

Member Author

> Would this request_id -> ParentRequest mapping make more sense at the AsyncLLM level?

Kind of, yes, but AsyncLLM doesn't otherwise currently track request-level state, and this logic would then need to be duplicated in AsyncLLM and LLMEngine. So I'm not sure which is best, TBH.

     self,
-    request_ids: list[str],
+    request_ids: Iterable[str],
 ) -> None:
Member

These are parent request IDs, but OutputProcessor.add_request() deals with child request IDs - don't love the asymmetry


@markmc (Member) commented Mar 10, 2025

As per my comment above, it would be good to check we have CI coverage for all three of these if possible:

  • Handle requests not finishing at the same time
  • Coalesce outputs from different child requests
  • Abort n > 1 requests (the parent request ID needs to propagate to child requests)

@markmc (Member) commented Mar 10, 2025

> Adding the TODO for now because attempting to enable all of v1/entrypoints gives:
>
> RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

As discussed with @russellb just now, this is reproducible just running VLLM_USE_V1=1 pytest -s -v tests/v1/entrypoints/llm too ... looks like something in there might be initializing CUDA before forking the engine core

@njhill (Member Author) commented Mar 10, 2025

> Adding the TODO for now because attempting to enable all of v1/entrypoints gives:
>
> RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
>
> As discussed with @russellb just now, this is reproducible just running VLLM_USE_V1=1 pytest -s -v tests/v1/entrypoints/llm too ... looks like something in there might be initializing CUDA before forking the engine core

Can we change this test to use spawn?
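
For reference, the standard-library mechanism being discussed looks roughly like the sketch below (a hypothetical conftest.py placement, not a claim about where the actual fix belongs; vLLM's VLLM_WORKER_MULTIPROC_METHOD env var may also be relevant, though this thread doesn't establish that).

# conftest.py (hypothetical placement): force the 'spawn' start method
# before any engine/worker process is created, since CUDA cannot be
# re-initialized in a forked subprocess.
import multiprocessing


def pytest_configure(config):
    # force=True overrides any previously selected start method.
    multiprocessing.set_start_method("spawn", force=True)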

Signed-off-by: Nick Hill <nhill@redhat.com>
@markmc (Member) commented Mar 10, 2025

> Adding the TODO for now because attempting to enable all of v1/entrypoints gives:
>
> RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
>
> As discussed with @russellb just now, this is reproducible just running VLLM_USE_V1=1 pytest -s -v tests/v1/entrypoints/llm too ... looks like something in there might be initializing CUDA before forking the engine core
>
> Can we change this test to use spawn?

See #14579 - the issue is in the structured output tests, and we're struggling to reproduce it

You should be able to enable tests/v1/entrypoints/openai/test_completion.py in this PR 👍

Signed-off-by: Nick Hill <nhill@redhat.com>
@robertgshaw2-redhat (Collaborator)

This is a genuine failure in the CI

@russellb (Member)

> This is a genuine failure in the CI

Which one, the structured output failure?

@markmc (Member) commented Mar 11, 2025

This is the valid failure:

[2025-03-11T01:02:04Z]     def _verify_greedy_sampling(self) -> None:
[2025-03-11T01:02:04Z]         if self.n > 1:
[2025-03-11T01:02:04Z] >           raise ValueError("n must be 1 when using greedy sampling, "
[2025-03-11T01:02:04Z]                              f"got {self.n}.")
[2025-03-11T01:02:04Z] E           ValueError: n must be 1 when using greedy sampling, got 3.
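
For context, a minimal illustration of the constraint in that traceback, assuming the SamplingParams validation shown above: greedy sampling (temperature=0) with n > 1 would produce n identical completions, so it is rejected at construction time.

from vllm import SamplingParams

# Raises ValueError: "n must be 1 when using greedy sampling, got 3."
try:
    SamplingParams(n=3, temperature=0.0, max_tokens=50)
except ValueError as exc:
    print(exc)

# Non-greedy sampling allows n > 1; this is the kind of request the
# parallel-sampling tests exercise.
params = SamplingParams(n=3, temperature=0.8, max_tokens=50)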

 n = 3
-max_tokens = 5
+max_tokens = 50  # we want some to finish earlier than others

@markmc (Member) Mar 11, 2025

This file isn't actually run in the V1 tests in the buildkite config

See #14579

@njhill (Member Author) commented Mar 11, 2025

Sorry, I will get this fixed up; I had intermittent access on a plane yesterday, so I just pushed without testing.

Signed-off-by: Nick Hill <nhill@redhat.com>
@njhill njhill added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 11, 2025
@simon-mo simon-mo added this to the v0.8.0 milestone Mar 12, 2025
@simon-mo simon-mo merged commit f5d3acd into vllm-project:main Mar 12, 2025
42 checks passed
@njhill njhill deleted the fix-n-gt-1 branch March 12, 2025 20:07
richardsliu pushed a commit to richardsliu/vllm that referenced this pull request Mar 14, 2025
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Richard Liu <ricliu@google.com>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>

Labels

  • bug (Something isn't working)
  • ready (ONLY add when PR is ready to merge/full CI is needed)
  • v1
