14 changes: 14 additions & 0 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -89,6 +89,17 @@ def _is_sampling_metadata_changed(model_runner,
                                    sampling_metadata_before)
 
 
+def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
+    req_index = model_runner.input_batch.req_id_to_index[req_id]
+    block_table = model_runner.input_batch.block_table
+    req_state = model_runner.requests[req_id]
+    if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):
+        return False
+    num_blocks = block_table.num_blocks_per_row[req_index]
+    return (block_table.block_table_np[req_index, :num_blocks] ==
+            req_state.block_ids).all()
+
+
 def test_update_states_new_request(model_runner):
     req_id = "req_0"
 
@@ -100,6 +111,7 @@ def test_update_states_new_request(model_runner):
     assert _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
+    assert _is_req_state_block_table_match(model_runner, req_id)
 
 
 def test_update_states_request_finished(model_runner):
@@ -185,6 +197,7 @@ def test_update_states_request_resumed(model_runner):
     assert _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
+    assert _is_req_state_block_table_match(model_runner, req_id)
 
 
 def test_update_states_no_changes(model_runner):
@@ -215,6 +228,7 @@ def test_update_states_no_changes(model_runner):
     assert not _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
+    assert _is_req_state_block_table_match(model_runner, req_id)
 
 
 def test_update_states_request_unscheduled(model_runner):
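Aside (not part of the PR): the new test helper compares the used prefix of a NumPy block-table row against the request's Python list of block IDs. A minimal self-contained sketch of that comparison pattern, with made-up sizes and block IDs, assuming only numpy:

import numpy as np

# A 2-row block table with room for 6 blocks per request (hypothetical sizes).
block_table_np = np.zeros((2, 6), dtype=np.int32)
num_blocks_per_row = np.zeros(2, dtype=np.int32)

# Row 0 holds blocks [7, 8, 9]; the rest of the row is unused padding.
block_ids = [7, 8, 9]
block_table_np[0, :len(block_ids)] = block_ids
num_blocks_per_row[0] = len(block_ids)

# The helper's check: lengths must agree, and the used prefix must match.
num_blocks = num_blocks_per_row[0]
assert num_blocks == len(block_ids)
assert (block_table_np[0, :num_blocks] == block_ids).all()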
13 changes: 6 additions & 7 deletions vllm/v1/worker/block_table.py
@@ -15,13 +15,11 @@ class BlockTable:
     def __init__(
         self,
         max_num_reqs: int,
-        max_model_len: int,
         max_num_blocks_per_req: int,
         pin_memory: bool,
         device: torch.device,
     ):
         self.max_num_reqs = max_num_reqs
-        self.max_model_len = max_model_len
         self.max_num_blocks_per_req = max_num_blocks_per_req
         self.pin_memory = pin_memory
         self.device = device
@@ -42,18 +40,19 @@ def __init__(
 
     def append_row(
         self,
-        row_idx: int,
-        start: int,
         block_ids: List[int],
+        row_idx: int,
     ) -> None:
         if not block_ids:
             return
         num_blocks = len(block_ids)
+        start = self.num_blocks_per_row[row_idx]
Review thread on the line above:

Collaborator: I'm not very familiar with this. Does this always equal len(req_state.block_ids) - len(req_data.new_block_ids) when it is not 0?

Collaborator (author): I think so, and I added some tests checking that the block table is the same as req_state.block_ids.

+        self.num_blocks_per_row[row_idx] += num_blocks
         self.block_table_np[row_idx, start:start + num_blocks] = block_ids
-        self.num_blocks_per_row[row_idx] = start + num_blocks
 
-    def add_row(self, row_idx: int, block_ids: List[int]) -> None:
-        self.append_row(row_idx, 0, block_ids)
+    def add_row(self, block_ids: List[int], row_idx: int) -> None:
+        self.num_blocks_per_row[row_idx] = 0
+        self.append_row(block_ids, row_idx)
 
     def move_row(self, src: int, tgt: int) -> None:
         num_blocks = self.num_blocks_per_row[src]
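To make the new bookkeeping concrete, here is a minimal, self-contained sketch (a hypothetical TinyBlockTable, not the vLLM class; it mirrors only the two methods in the hunk above). It shows that append_row now derives the write offset from num_blocks_per_row, so callers no longer pass a start argument:

from typing import List

import numpy as np


class TinyBlockTable:
    """Minimal stand-in for BlockTable, mirroring the two methods above."""

    def __init__(self, max_num_reqs: int, max_num_blocks_per_req: int):
        self.block_table_np = np.zeros(
            (max_num_reqs, max_num_blocks_per_req), dtype=np.int32)
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

    def append_row(self, block_ids: List[int], row_idx: int) -> None:
        if not block_ids:
            return
        num_blocks = len(block_ids)
        # The row length doubles as the append offset, so callers no
        # longer need to compute and pass `start` themselves.
        start = self.num_blocks_per_row[row_idx]
        self.num_blocks_per_row[row_idx] += num_blocks
        self.block_table_np[row_idx, start:start + num_blocks] = block_ids

    def add_row(self, block_ids: List[int], row_idx: int) -> None:
        self.num_blocks_per_row[row_idx] = 0
        self.append_row(block_ids, row_idx)


table = TinyBlockTable(max_num_reqs=2, max_num_blocks_per_req=8)
table.add_row([1, 2, 3], row_idx=0)   # fresh request: blocks 1..3
table.append_row([4, 5], row_idx=0)   # two new blocks arrive later
# The internally derived start (3) equals what callers previously computed
# as len(req_state.block_ids) - len(new_block_ids) = 5 - 2, per the review
# thread above.
assert table.num_blocks_per_row[0] == 5
assert list(table.block_table_np[0, :5]) == [1, 2, 3, 4, 5]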
3 changes: 1 addition & 2 deletions vllm/v1/worker/gpu_input_batch.py
@@ -85,7 +85,6 @@ def __init__(
         # Block table.
         self.block_table = BlockTable(
             max_num_reqs=max_num_reqs,
-            max_model_len=max_model_len,
             max_num_blocks_per_req=max_num_blocks_per_req,
             pin_memory=pin_memory,
             device=device,
@@ -242,7 +241,7 @@ def add_request(
         self.num_tokens_no_spec[req_index] = request.num_tokens
 
         self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
-        self.block_table.add_row(req_index, request.block_ids)
+        self.block_table.add_row(request.block_ids, req_index)
 
         sampling_params = request.sampling_params
         if sampling_params.sampling_type == SamplingType.GREEDY:
6 changes: 2 additions & 4 deletions vllm/v1/worker/gpu_model_runner.py
@@ -378,10 +378,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 num_computed_tokens)
-            start_index = (len(req_state.block_ids) -
-                           len(req_data.new_block_ids))
-            self.input_batch.block_table.append_row(req_index, start_index,
-                                                    req_data.new_block_ids)
+            self.input_batch.block_table.append_row(req_data.new_block_ids,
+                                                    req_index)
             # Add new_token_ids to token_ids_cpu.
             start_token_index = num_computed_tokens
             end_token_index = num_computed_tokens + len(req_data.new_token_ids)
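A tiny standalone check (with made-up block IDs) of the identity the deleted start_index computation relied on: once req_state.block_ids has been extended with the newly scheduled blocks, the caller-computed offset equals the number of blocks already stored in the row, which is exactly what append_row now reads from num_blocks_per_row.

old_block_ids = [10, 11, 12]      # blocks already in the table row
new_block_ids = [13, 14]          # blocks scheduled this step
req_state_block_ids = old_block_ids + new_block_ids

start_index = len(req_state_block_ids) - len(new_block_ids)
assert start_index == len(old_block_ids)  # == num_blocks_per_row[req_index]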
6 changes: 2 additions & 4 deletions vllm/v1/worker/tpu_model_runner.py
@@ -280,10 +280,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 req_data.num_computed_tokens)
-            start_index = len(req_state.block_ids) - len(
-                req_data.new_block_ids)
-            self.input_batch.block_table.append_row(req_index, start_index,
-                                                    req_data.new_block_ids)
+            self.input_batch.block_table.append_row(req_data.new_block_ids,
+                                                    req_index)
 
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.