-
-
Couldn't load subscription status.
- Fork 10.9k
[V1] AsyncLLM data parallel #13923
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
[V1] AsyncLLM data parallel #13923
Changes from all commits
Commits
Show all changes
65 commits
Select commit
Hold shift + click to select a range
9ca44ce
[V1] AsyncLLM data parallel WIP
njhill 3f51611
Handle pausing loop
njhill d8c591e
More single-node updates
njhill 65e225d
some cleanup
njhill 5ce57b6
fix up utility methods
njhill a3f1102
revert config check
njhill a66fb01
fixes
njhill 67672c2
cleanup
njhill cf52fbf
fixes
njhill a4ec81b
reconcile with LLMEngine DP in decoupled engine case
njhill 292aa00
minor simplification
njhill 4b62ffd
rework
njhill 407c72e
class refactor
njhill 31bf7ea
fix
njhill fde51ce
adjust core engine init
njhill d5a3e68
Merge remote-tracking branch 'refs/remotes/origin/main' into multi-en…
njhill 6d89a1b
fix new typing
njhill 448abd9
fix :facepalm:
njhill a1e513e
bind socket first
njhill 50cf64c
do you have to let it linger
njhill f365998
Merge remote-tracking branch 'origin/main' into multi-engine
njhill b2571f0
add comments
njhill 32c6f24
aggregate stats
njhill 9c30cd7
Merge remote-tracking branch 'origin/main' into multi-engine
njhill 672d07e
Fix test
njhill dea382b
Merge remote-tracking branch 'origin/main' into multi-engine
njhill d24a626
fix and minor cleanup
njhill cd03c80
Add CI test
njhill f1004b7
Merge remote-tracking branch 'origin/main' into multi-engine
njhill d3298fa
Some simplification and fixes
njhill 74dde48
Merge remote-tracking branch 'origin/main' into multi-engine
njhill 5fe1b75
address @markmc's stats suggestion
njhill 648659f
address @tms's arg comment
njhill 119d1ec
fix utility method breakage
njhill 55328ee
rename AsyncMPClient output_processor to output_handler
njhill 4f5330e
Merge remote-tracking branch 'origin/main' into multi-engine
njhill 48770ec
Merge remote-tracking branch 'origin/main' into multi-engine
njhill d229f4d
Fix
njhill 2f91cc4
Merge remote-tracking branch 'origin/main' into multi-engine
njhill 518047a
Remove redundant logic related to removed stats aggregation
njhill cb2b099
Fixes
njhill ff1137a
Merge remote-tracking branch 'refs/remotes/origin/main' into multi-en…
njhill 61f4fcb
fix issue from main merge
njhill 44874c2
remove leftover unused field
njhill 66fc582
Fix offline DP compatibility
njhill 7764466
Add timeout to data_parallel.py
njhill 51e8bf0
Merge remote-tracking branch 'refs/remotes/origin/main' into multi-en…
njhill f692c12
Merge remote-tracking branch 'origin/main' into multi-engine
njhill 47b5e1c
Enable less-frequent all-reduce optimization
njhill f226139
Merge remote-tracking branch 'origin/main' into multi-engine
njhill af47920
Merge remote-tracking branch 'origin/main' into multi-engine
njhill 693c521
Merge remote-tracking branch 'origin/main' into multi-engine
njhill 6e131e3
clean distributed shutdown
njhill d9ac856
address misc loose-ends
njhill 3abbdef
Merge remote-tracking branch 'origin/main' into multi-engine
njhill b18417e
further tweaks
njhill 56b2b78
Merge remote-tracking branch 'refs/remotes/origin/main' into multi-en…
njhill 05ab310
Additional debug
njhill 5295c34
Merge remote-tracking branch 'origin/main' into multi-engine
njhill 4f897b8
Address review comments on tests
njhill 62f32ed
Merge remote-tracking branch 'origin/main' into multi-engine
njhill 771ccf1
Fix env var fallback
njhill 05a0e83
Fix test supports_v1 check
njhill bc41b13
Fix yapf :facepalm:
njhill ccecb42
Merge remote-tracking branch 'origin/main' into multi-engine
njhill File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import asyncio | ||
| import os | ||
| from contextlib import ExitStack | ||
| from typing import Optional | ||
|
|
||
| import pytest | ||
|
|
||
| from vllm import SamplingParams | ||
| from vllm.engine.arg_utils import AsyncEngineArgs | ||
| from vllm.inputs import PromptType | ||
| from vllm.platforms import current_platform | ||
| from vllm.sampling_params import RequestOutputKind | ||
| from vllm.v1.engine.async_llm import AsyncLLM | ||
| from vllm.v1.engine.core_client import DPAsyncMPClient | ||
|
|
||
| engine_args = AsyncEngineArgs( | ||
| model="ibm-research/PowerMoE-3b", | ||
| enforce_eager=True, | ||
| disable_log_requests=True, | ||
| tensor_parallel_size=int(os.getenv("TP_SIZE", 1)), | ||
| data_parallel_size=int(os.getenv("DP_SIZE", 2)), | ||
| ) | ||
|
|
||
| if not current_platform.supports_v1(engine_args.create_model_config()): | ||
| pytest.skip(reason="Requires V1-supporting platform.", | ||
| allow_module_level=True) | ||
|
|
||
|
|
||
| async def generate(engine: AsyncLLM, | ||
| request_id: str, | ||
| prompt: PromptType, | ||
| output_kind: RequestOutputKind, | ||
| max_tokens: int, | ||
| prompt_logprobs: Optional[int] = None) -> tuple[int, str]: | ||
| # Ensure generate doesn't complete too fast for cancellation test. | ||
| await asyncio.sleep(0.2) | ||
|
|
||
| count = 0 | ||
| sampling_params = SamplingParams(max_tokens=max_tokens, | ||
| ignore_eos=True, | ||
| output_kind=output_kind, | ||
| temperature=0, | ||
| prompt_logprobs=prompt_logprobs) | ||
| async for out in engine.generate(request_id=request_id, | ||
| prompt=prompt, | ||
| sampling_params=sampling_params): | ||
|
|
||
| num_tokens = len(out.outputs[0].token_ids) | ||
| if output_kind == RequestOutputKind.DELTA: | ||
| count += num_tokens | ||
| else: | ||
| count = num_tokens | ||
|
|
||
| await asyncio.sleep(0.) | ||
|
|
||
| return count, request_id | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) | ||
| @pytest.mark.asyncio | ||
| async def test_load(output_kind: RequestOutputKind): | ||
|
|
||
| with ExitStack() as after: | ||
|
|
||
| prompt = "This is a test of data parallel" | ||
|
|
||
| engine = AsyncLLM.from_engine_args(engine_args) | ||
| after.callback(engine.shutdown) | ||
|
|
||
| NUM_REQUESTS = 100 | ||
| NUM_EXPECTED_TOKENS = 10 | ||
|
|
||
| request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)] | ||
|
|
||
| # Create concurrent requests. | ||
| tasks = [] | ||
| for request_id in request_ids: | ||
| tasks.append( | ||
| asyncio.create_task( | ||
| generate(engine, request_id, prompt, output_kind, | ||
| NUM_EXPECTED_TOKENS))) | ||
|
|
||
| # Confirm that we got all the EXPECTED tokens from the requests. | ||
| done, pending = await asyncio.wait(tasks, | ||
| return_when=asyncio.FIRST_EXCEPTION) | ||
| for task in pending: | ||
| task.cancel() | ||
| for task in done: | ||
| num_generated_tokens, request_id = await task | ||
| assert num_generated_tokens == NUM_EXPECTED_TOKENS, ( | ||
| f"{request_id} generated {num_generated_tokens} but " | ||
| f"expected {NUM_EXPECTED_TOKENS}") | ||
|
|
||
| assert not engine.output_processor.has_unfinished_requests() | ||
|
|
||
| # testing internals here which may break | ||
| core_client: DPAsyncMPClient = engine.engine_core | ||
| # the engines only synchronize stopping every N steps so | ||
| # allow a small amount of time here. | ||
| for _ in range(10): | ||
| if core_client.num_engines_running == 0: | ||
| break | ||
| await asyncio.sleep(0.5) | ||
|
|
||
| assert core_client.num_engines_running == 0 | ||
| assert not core_client.reqs_in_flight |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.