tests/v1/worker/test_gpu_model_runner.py (236 additions, 0 deletions)
@@ -0,0 +1,236 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
                                            SchedulerOutput)
from vllm.v1.worker.gpu_model_runner import GPUModelRunner


@pytest.fixture
def model_runner():
    scheduler_config = SchedulerConfig(
        max_num_seqs=10,
        max_num_batched_tokens=512,
        max_model_len=512,
    )
    model_config = ModelConfig(
        model="facebook/opt-125m",
        task="generate",
        tokenizer="facebook/opt-125m",
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="float16",
        seed=42,
    )
    cache_config = CacheConfig(
        block_size=16,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
    )
    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        scheduler_config=scheduler_config,
    )

    device = "cuda"
    return GPUModelRunner(vllm_config, device)


def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
    new_reqs = []
    num_scheduled_tokens = {}
    total_num_scheduled_tokens = 0
    for req_id in req_ids:
        new_reqs.append(
            NewRequestData(
                req_id=req_id,
                prompt_token_ids=[1, 2, 3],
                prompt="test",
                mm_inputs=[],
                mm_hashes=[],
                mm_positions=[],
                sampling_params=SamplingParams(),
                block_ids=[0],
                num_computed_tokens=0,
                lora_request=None,
            ))
        num_scheduled_tokens[req_id] = 3
        total_num_scheduled_tokens += num_scheduled_tokens[req_id]

    return SchedulerOutput(
        scheduled_new_reqs=new_reqs,
        scheduled_cached_reqs=[],
        num_scheduled_tokens=num_scheduled_tokens,
        total_num_scheduled_tokens=total_num_scheduled_tokens,
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
    )


def _is_req_scheduled(model_runner, req_id: str) -> bool:
    return req_id in model_runner.input_batch.req_id_to_index


def _is_req_added(model_runner, req_id: str) -> bool:
    return req_id in model_runner.requests


def test_update_states_new_request(model_runner):
req_id = "req_0"

# new req
scheduler_output = _schedule_new_request(req_id)

batch_changed = model_runner._update_states(scheduler_output)
assert batch_changed is True
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)


def test_update_states_request_finished(model_runner):
req_id = "req_0"

# new req
scheduler_output = _schedule_new_request(req_id)

model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)

# finish req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids={req_id},
free_encoder_input_ids=[],
)

batch_changed = model_runner._update_states(scheduler_output)
assert batch_changed is True
assert not _is_req_added(model_runner, req_id)
assert not _is_req_scheduled(model_runner, req_id)


def test_update_states_request_resumed(model_runner):
req_id = "req_0"

# new req
scheduler_output = _schedule_new_request(req_id)

model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)

    # unschedule req
    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={},
        total_num_scheduled_tokens=0,
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
    )

    model_runner._update_states(scheduler_output)
    assert _is_req_added(model_runner, req_id)
    assert not _is_req_scheduled(model_runner, req_id)

    # resume req
    cached_req_data = CachedRequestData(
        req_id=req_id,
        resumed_from_preemption=False,
        new_block_ids=[],
        num_computed_tokens=0,
    )

    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[cached_req_data],
        num_scheduled_tokens={req_id: 1},
        total_num_scheduled_tokens=1,
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
    )

    batch_changed = model_runner._update_states(scheduler_output)
    assert batch_changed is True
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)


def test_update_states_no_changes(model_runner):
req_id = "req_0"

# new req
scheduler_output = _schedule_new_request(req_id)

model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)

# schedule req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
)

batch_changed = model_runner._update_states(scheduler_output)
assert batch_changed is False
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)


def test_update_states_request_unscheduled(model_runner):
req_ids = ("req_0", "req_1")

# new reqs
scheduler_output = _schedule_new_request(*req_ids)

model_runner._update_states(scheduler_output)

assert _is_req_added(model_runner, req_ids[0])
assert _is_req_scheduled(model_runner, req_ids[0])

assert _is_req_added(model_runner, req_ids[1])
assert _is_req_scheduled(model_runner, req_ids[1])

# unschedule req_1
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={req_ids[0]: 1},
total_num_scheduled_tokens=1,
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
)

batch_changed = model_runner._update_states(scheduler_output)
assert batch_changed is True

assert _is_req_added(model_runner, req_ids[0])
assert _is_req_scheduled(model_runner, req_ids[0])

assert _is_req_added(model_runner, req_ids[1])
assert not _is_req_scheduled(model_runner, req_ids[1])
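
A plain pytest invocation should be enough to run the new module; note that the model_runner fixture builds a GPUModelRunner with device = "cuda", so a CUDA-capable GPU (and the facebook/opt-125m weights, fetched on first use) is assumed:

    pytest tests/v1/worker/test_gpu_model_runner.py -v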
vllm/v1/worker/gpu_model_runner.py (2 additions, 1 deletion)
@@ -361,7 +361,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         # Condense the batched states if there are empty indices.
         if removed_req_indices:
             self.input_batch.condense(removed_req_indices)
-        return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
+        return (len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
+                or len(scheduler_output.finished_req_ids) > 0)
 
     def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
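
A minimal standalone sketch of the condition change that the new tests pin down (illustrative only, not vLLM internals; the parameter names unscheduled, to_add, and finished are hypothetical stand-ins for the runner's local variables): with the old expression, a step that only finishes requests reports no batch change, while the new parenthesized form flags that case as well.

def batch_changed_old(unscheduled, to_add, finished):
    # Old condition: a step that only finishes requests looks unchanged.
    return len(unscheduled) > 0 or len(to_add) > 0


def batch_changed_new(unscheduled, to_add, finished):
    # New condition: finished requests also count as a batch change,
    # matching the assertion in test_update_states_request_finished.
    return (len(unscheduled) > 0 or len(to_add) > 0 or len(finished) > 0)


# The "finished only" case exercised by the new test file.
assert batch_changed_old(set(), [], {"req_0"}) is False
assert batch_changed_new(set(), [], {"req_0"}) is True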