
Commit 1a942d5

run linter
1 parent e8a7f9b commit 1a942d5

2 files changed: +61 additions, -52 deletions


vllm/v1/attention/backends/pallas.py

Lines changed: 4 additions & 6 deletions
@@ -7,11 +7,9 @@
 import torch_xla.experimental.custom_kernel  # Required to register custom ops.
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 
-
 NUM_QUERIES_PER_BLOCK = 16
 NUM_KV_PAGES_PER_BLOCK = 128
 
@@ -53,7 +51,7 @@ def swap_blocks(
 
 
 @dataclass
-class PallasMetadata():
+class PallasMetadata:
     # NOTE(sang): Definition of context_len, query_len, and seq_len.
     # |---------- N-1 iteration --------|
     # |---------------- N iteration ---------------------|
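For reference, the relationship the NOTE above describes is seq_len = context_len + query_len. A minimal sketch with made-up numbers, only to illustrate the definition:

```python
# Hypothetical request state, not taken from the diff:
context_len = 8                    # tokens already computed before iteration N
query_len = 3                      # tokens scheduled in iteration N
seq_len = context_len + query_len  # tokens covered once iteration N finishes
assert seq_len == 11
```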
@@ -70,7 +68,6 @@ class PallasMetadata():
     num_seqs: int
 
 
-
 class PallasAttentionBackendImpl(AttentionImpl):
 
     def __init__(
@@ -88,7 +85,8 @@ def __init__(
     ) -> None:
         if blocksparse_params is not None:
             raise ValueError(
-                "Paged attention Pallas kernel does not support block-sparse attention.")
+                "Paged attention Pallas kernel does not support block-sparse attention."
+            )
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)

vllm/v1/worker/tpu_model_runner.py

Lines changed: 57 additions & 46 deletions
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-import enum
 import time
-from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
 from unittest.mock import patch
 
@@ -13,7 +11,6 @@
 import torch_xla.core.xla_model as xm
 import torch_xla.runtime as xr
 
-from vllm.attention import AttentionMetadata
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig
@@ -22,13 +19,13 @@
 from vllm.model_executor.model_loader import get_model
 from vllm.sampling_params import SamplingType
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
-from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
-                                               PallasMetadata,
+from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
                                                NUM_QUERIES_PER_BLOCK,
-                                               NUM_KV_PAGES_PER_BLOCK)
+                                               PallasAttentionBackend,
+                                               PallasMetadata)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
-from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
+from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
@@ -48,7 +45,7 @@
 INVALID_TOKEN_ID = -1
 
 
-class TPUModelRunner():
+class TPUModelRunner:
 
     def __init__(
         self,
@@ -80,8 +77,8 @@ def __init__(
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
-        self.max_num_tokens = scheduler_config.max_num_batched_tokens  # 8192
-        self.max_num_reqs = scheduler_config.max_num_seqs  # 16
+        self.max_num_tokens = scheduler_config.max_num_batched_tokens
+        self.max_num_reqs = scheduler_config.max_num_seqs
 
         # Model-related.
         self.num_attn_layers = model_config.get_num_layers_by_block_type(
@@ -115,8 +112,8 @@ def __init__(
         # The pytorch tensor and numpy array share the same buffer.
         # Sometimes the numpy op is faster so we create both.
         self.input_ids_cpu = torch.zeros(self.max_num_tokens,
-                                            dtype=torch.int32,
-                                            device="cpu")
+                                         dtype=torch.int32,
+                                         device="cpu")
         self.input_ids_np = self.input_ids_cpu.numpy()
 
         self.positions_cpu = torch.zeros(self.max_num_tokens,
@@ -132,10 +129,12 @@ def __init__(
         # self.input_batch.block_table has a shape of [max_num_reqs, max_num_blocks_per_req].
         # To reduce the number of recompilation, we want the block_table.shape[0] to be num_tokens.
         # To make the block_table to be compatible with the paged attention kernel, we want the block_table[1] to be multiple of NUM_KV_PAGES_PER_BLOCK.
-        padded_max_num_blocks_per_req = _get_padded_number(self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
-        self.block_table_cpu = torch.zeros((self.max_num_tokens, padded_max_num_blocks_per_req),
-                                           dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
-                                           device="cpu")
+        padded_max_num_blocks_per_req = _get_padded_number(
+            self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
+        self.block_table_cpu = torch.zeros(
+            (self.max_num_tokens, padded_max_num_blocks_per_req),
+            dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
+            device="cpu")
 
         self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
                                                dtype=torch.int32,
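The comments above pad the block-table width to a multiple of NUM_KV_PAGES_PER_BLOCK so the paged-attention kernel always sees a fixed, kernel-friendly shape. A minimal sketch of the rounding helper, assuming `_get_padded_number(x, multiple)` simply rounds up to the nearest multiple (the real helper is defined elsewhere in tpu_model_runner.py):

```python
def _get_padded_number(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple, e.g. 130 -> 256 for multiple=128.
    return ((x + multiple - 1) // multiple) * multiple

# Illustrative values only (not taken from the config): a 2000-block request
# limit is padded so block_table.shape[1] is a multiple of NUM_KV_PAGES_PER_BLOCK.
NUM_KV_PAGES_PER_BLOCK = 128
assert _get_padded_number(2000, NUM_KV_PAGES_PER_BLOCK) == 2048
assert _get_padded_number(130, NUM_KV_PAGES_PER_BLOCK) == 256
```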
@@ -325,9 +324,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
             assert req_id is not None
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
             num_scheduled_tokens_per_req.append(num_tokens)
-            max_num_scheduled_tokens_all_reqs = max(max_num_scheduled_tokens_all_reqs,
-                                                    num_tokens)
-        num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req, dtype=np.int32)
+            max_num_scheduled_tokens_all_reqs = max(
+                max_num_scheduled_tokens_all_reqs, num_tokens)
+        num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req,
+                                                dtype=np.int32)
         assert max_num_scheduled_tokens_all_reqs > 0
 
         # Get request indices.
@@ -341,13 +341,13 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         # For each scheduled token, what is its position in the corresponding req.
         arange = np.concatenate(
             [self.arange_np[:n] for n in num_scheduled_tokens_per_req])
-
+
         # Get positions.
         positions_np = self.positions_np[:total_num_scheduled_tokens]
         np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
                arange,
                out=positions_np)
-
+
         # Get token indices.
         # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
         # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
@@ -362,7 +362,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
             0,
             torch.from_numpy(token_indices),
             out=self.input_ids_cpu[:total_num_scheduled_tokens])
-
+
         # Calculate the slot mapping.
         # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
         # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
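Before the slot mapping, a worked illustration of the token-index arithmetic from the previous hunk's comments (a standalone numpy sketch, not the runner's real buffers; M is a hypothetical per-request row length of the token-id table):

```python
import numpy as np

M = 10  # hypothetical per-request row length of the token-id table
# Three requests scheduled for 2, 5 and 3 tokens in this iteration.
num_scheduled = np.array([2, 5, 3], dtype=np.int32)
req_indices = np.repeat(np.arange(3, dtype=np.int32), num_scheduled)
arange = np.concatenate([np.arange(n, dtype=np.int32) for n in num_scheduled])

# With zero previously computed tokens, each token's position in its request
# is just `arange`; flattening into a [num_reqs, M] table reproduces the
# "[0, 1, 0, 1, 2, ...] -> [0, 1, M, M + 1, ...]" pattern quoted above.
positions = arange
token_indices = positions + req_indices * M
print(token_indices.tolist())  # [0, 1, 10, 11, 12, 13, 14, 20, 21, 22]
```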
@@ -381,27 +381,40 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         np.add(block_numbers * self.block_size,
                block_offsets,
                out=self.slot_mapping_np[:total_num_scheduled_tokens])
-
+
         # Prepare the attention metadata.
         self.query_start_loc_np[0] = 0
         np.cumsum(num_scheduled_tokens_per_req,
                   out=self.query_start_loc_np[1:num_reqs + 1])
-
+
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens_per_req)
 
         # Do the padding and copy the tensors to the TPU.
-        padded_total_num_scheduled_tokens = _get_padded_number(total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK)
-        self.input_ids = self.input_ids_cpu[:padded_total_num_scheduled_tokens].to(self.device)
-        self.position_ids = self.positions_cpu[:padded_total_num_scheduled_tokens].to(self.device)
+        padded_total_num_scheduled_tokens = _get_padded_number(
+            total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK)
+        self.input_ids = self.input_ids_cpu[:
+                                            padded_total_num_scheduled_tokens].to(
+                                                self.device)
+        self.position_ids = self.positions_cpu[:
+                                               padded_total_num_scheduled_tokens].to(
+                                                   self.device)
         self.slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID
-        slot_mapping = self.slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(self.device)
-        padded_block_table = self.block_table_cpu[:padded_total_num_scheduled_tokens]
-        padded_block_table[:num_reqs, :self.max_num_blocks_per_req] = self.input_batch.block_table.get_cpu_tensor()[:num_reqs]
+        slot_mapping = self.slot_mapping_cpu[:
+                                             padded_total_num_scheduled_tokens].to(
+                                                 self.device)
+        padded_block_table = self.block_table_cpu[:
+                                                  padded_total_num_scheduled_tokens]
+        padded_block_table[:num_reqs, :self.
+                           max_num_blocks_per_req] = self.input_batch.block_table.get_cpu_tensor(
+                           )[:num_reqs]
         padded_block_table = padded_block_table.to(self.device)
-        query_start_loc = self.query_start_loc_cpu[:padded_total_num_scheduled_tokens+1].to(self.device)
-        seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to(self.device)
+        query_start_loc = self.query_start_loc_cpu[:
+                                                   padded_total_num_scheduled_tokens
+                                                   + 1].to(self.device)
+        seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to(
+            self.device)
 
         attn_metadata = PallasMetadata(
             slot_mapping=slot_mapping,
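The slot mapping computed above turns each scheduled token into a flat index into the paged KV cache: physical block number times block size, plus the offset inside the block. A self-contained numpy sketch with toy values, mirroring the formula but not the runner's actual block table:

```python
import numpy as np

K = 4  # hypothetical block size
# Per-token positions within their requests (same pattern as the example above).
positions = np.array([0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
req_indices = np.array([0, 0, 1, 1, 1, 1, 1, 2, 2, 2])
# Toy block table: request i owns the physical blocks listed in row i.
block_table = np.array([[0, 1], [2, 3], [4, 5]])

block_numbers = block_table[req_indices, positions // K]
block_offsets = positions % K
slot_mapping = block_numbers * K + block_offsets
print(slot_mapping.tolist())  # [0, 1, 8, 9, 10, 11, 12, 16, 17, 18]
```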
@@ -418,7 +431,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         logits_indices = query_start_loc[1:] - 1
         return attn_metadata, logits_indices
 
-
     @torch.no_grad()
     def execute_model(
         self,
@@ -432,7 +444,7 @@ def execute_model(
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
 
         # Run the decoder
-        with set_forward_context(attn_metadata, self.vllm_config): 
+        with set_forward_context(attn_metadata, self.vllm_config):
             hidden_states = self.model(
                 token_ids=self.input_ids,
                 position_ids=self.position_ids,
@@ -461,14 +473,14 @@ def execute_model(
                 if generator is not None:
                     # This relies on cuda-specific torch-internal impl details
                     generator.set_offset(generator.get_offset() - 4)
-
+
         # num_reqs entries should be non-None
         assert all(
             req_id is not None for req_id in
             self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
         req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
 
-        prompt_logprobs_dict: Dict[str, LogprobsTensors] = {}
+        prompt_logprobs_dict = {}
         for req_id in self.input_batch.req_ids[:num_reqs]:
             prompt_logprobs_dict[req_id] = None
 
@@ -526,7 +538,6 @@ def load_model(self) -> None:
         xm.mark_step()
         xm.wait_device_ops()
         model = ModelWrapperV1(model)
-        # self.model = model
         self.model = torch.compile(model,
                                    backend="openxla",
                                    fullgraph=True,
@@ -546,12 +557,14 @@ def dummy_run(
         slot_mapping = torch.zeros(num_tokens,
                                    dtype=torch.int64,
                                    device=self.device)
-        block_tables = torch.zeros(
-            (num_tokens, self.block_table_cpu.shape[1]),
-            dtype=torch.int32,
-            device=self.device)
+        block_tables = torch.zeros((num_tokens, self.block_table_cpu.shape[1]),
+                                   dtype=torch.int32,
+                                   device=self.device)
         query_lens = [1] * num_tokens
-        query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32).to(self.device)
+        query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
+                                                    dtype=torch.int32),
+                                       dim=0,
+                                       dtype=torch.int32).to(self.device)
         context_lens = torch.ones((num_tokens, ),
                                   dtype=torch.int32,
                                   device=self.device)
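Since dummy_run pads every warm-up request to a single query token, the cumsum above simply yields [0, 1, 2, ..., num_tokens]. A small standalone check (plain PyTorch on CPU, no TPU required):

```python
import torch

num_tokens = 8  # hypothetical warm-up size
query_lens = [1] * num_tokens
query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                            dtype=torch.int32),
                               dim=0,
                               dtype=torch.int32)
print(query_start_loc.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7, 8]
```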
@@ -581,7 +594,6 @@ def capture_model(self) -> None:
 
         start = time.perf_counter()
         num_tokens = 16
-        # The num_tokens_list below is how GPU precompiles.
         while True:
             self.dummy_run(self.kv_caches, num_tokens)
             logger.info(" -- num_tokens: %d", num_tokens)
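This warm-up loop pre-compiles the model for a ladder of token counts: num_tokens starts at 16 and, as the next hunk shows, doubles each pass until the scheduler's maximum is covered, so XLA caches one graph per size instead of recompiling at serving time. A sketch of the sizes visited, assuming max_num_batched_tokens is 8192 (the value noted in the comment removed earlier in this diff) and glossing over the exact exit test:

```python
# Illustrative only: mirrors the shape of the loop, not its exact exit condition.
max_num_tokens = 8192  # assumed scheduler max_num_batched_tokens
num_tokens = 16
compiled_sizes = []
while True:
    compiled_sizes.append(num_tokens)  # one dummy_run / compilation per size
    if num_tokens >= max_num_tokens:
        break
    num_tokens *= 2
print(compiled_sizes)  # [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
```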
@@ -591,8 +603,7 @@ def capture_model(self) -> None:
                 break
             num_tokens *= 2
         end = time.perf_counter()
-        logger.info("Compilation finished in in %.2f [secs].",
-                    end - start)
+        logger.info("Compilation finished in in %.2f [secs].", end - start)
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
@@ -682,7 +693,7 @@ def forward(
             position_ids,
             kv_caches,
         )
-
+
         return hidden_states
 
     def compute_logits(
