From b7bfce38ec62ce6e68ea1aa8740a18769d5f2172 Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Fri, 13 Jun 2025 18:38:21 +0000 Subject: [PATCH 1/3] [TPU] support attention head dim smaller than 128 Signed-off-by: Chengji Yao --- tests/v1/tpu/test_basic.py | 37 ++++++++++++++++++++++++++++ vllm/v1/attention/backends/pallas.py | 27 ++++++++++++++------ 2 files changed, 57 insertions(+), 7 deletions(-) diff --git a/tests/v1/tpu/test_basic.py b/tests/v1/tpu/test_basic.py index 7117a66c2958..fe65976a58a1 100644 --- a/tests/v1/tpu/test_basic.py +++ b/tests/v1/tpu/test_basic.py @@ -67,6 +67,43 @@ def test_basic( assert "1024" in output or "0, 1" in output +@pytest.mark.skipif(not current_platform.is_tpu(), + reason="This is a basic test for TPU only") +@pytest.mark.parametrize("max_tokens", [8]) +@pytest.mark.parametrize("max_num_seqs", [16]) +def test_phi3( + vllm_runner: type[VllmRunner], + monkeypatch: pytest.MonkeyPatch, + max_tokens: int, + max_num_seqs: int, +) -> None: + prompts = [ + "A robot may not injure a human being", + "It is only with the heart that one can see rightly;", + "The greatest glory in living lies not in never falling,", + ] + answers = [ + " or, by violating privacy", + " what is essential is love.", + " but in rising every time we fall.", + ] + # test head dim = 96 + model = "microsoft/Phi-3-mini-128k-instruct" + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + with vllm_runner(model, + max_num_batched_tokens=256, + max_num_seqs=max_num_seqs) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens) + # vllm_outputs is a list of tuples whose first element is the token id + # and the second element is the output (including the prompt). + for output, answer in zip(vllm_outputs, answers): + generated_text = output[1] + assert answer in generated_text + + TP_SIZE_8 = 8 diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 62c72f43f147..015c8c51e0a6 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -43,6 +43,11 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> tuple[int, ...]: + # TPU requires the head size to be a multiple of 128. 
+ if head_size % 128 != 0: + padded_head_size = cdiv(head_size, 128) * 128 + num_blocks = num_blocks * head_size // padded_head_size + head_size = padded_head_size return (num_blocks, block_size, num_kv_heads * 2, head_size) @staticmethod @@ -133,8 +138,6 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - if head_size % 128 != 0: - raise NotImplementedError("Head size must be a multiple of 128.") if alibi_slopes is not None: raise NotImplementedError("Alibi slopes is not supported.") if kv_cache_dtype != "auto": @@ -188,6 +191,16 @@ def forward( assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 num_tokens, hidden_size = query.shape query = query.view(num_tokens, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + if self.head_size % 128 != 0: + padded_head_size = cdiv(self.head_size, 128) * 128 + query = torch.nn.functional.pad( + query, (0, padded_head_size - self.head_size), value=0.0) + key = torch.nn.functional.pad( + key, (0, padded_head_size - self.head_size), value=0.0) + value = torch.nn.functional.pad( + value, (0, padded_head_size - self.head_size), value=0.0) if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0: # Write input keys and values to the KV cache. @@ -214,6 +227,9 @@ def forward( soft_cap=self.logits_soft_cap, ) + if self.head_size % 128 != 0: + output = output[:, :, :self.head_size] + return output.reshape(num_tokens, hidden_size) @@ -232,11 +248,8 @@ def write_to_kv_cache( """ _, _, num_combined_kv_heads, head_size = kv_cache.shape - num_kv_heads = num_combined_kv_heads // 2 - - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - + if head_size % 128 != 0: + head_size = cdiv(head_size, 128) * 128 kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, head_size) From e0f6ae3c20eb6b54eab1d387d87f1ad0d230a43c Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Fri, 13 Jun 2025 19:23:17 +0000 Subject: [PATCH 2/3] fix comment Signed-off-by: Chengji Yao --- vllm/v1/attention/backends/pallas.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 015c8c51e0a6..3f113a05880b 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -17,6 +17,9 @@ logger = init_logger(__name__) +# TPU requires the head size to be a multiple of 128. +TPU_HEAD_SIZE_ALIGNMENT = 128 + class PallasAttentionBackend(AttentionBackend): @@ -43,11 +46,10 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> tuple[int, ...]: - # TPU requires the head size to be a multiple of 128. 
- if head_size % 128 != 0: - padded_head_size = cdiv(head_size, 128) * 128 - num_blocks = num_blocks * head_size // padded_head_size - head_size = padded_head_size + padded_head_size = cdiv( + head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + num_blocks = num_blocks * head_size // padded_head_size + head_size = padded_head_size return (num_blocks, block_size, num_kv_heads * 2, head_size) @staticmethod @@ -193,8 +195,10 @@ def forward( query = query.view(num_tokens, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) - if self.head_size % 128 != 0: - padded_head_size = cdiv(self.head_size, 128) * 128 + if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0: + padded_head_size = cdiv( + self.head_size, + TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT query = torch.nn.functional.pad( query, (0, padded_head_size - self.head_size), value=0.0) key = torch.nn.functional.pad( @@ -227,7 +231,7 @@ def forward( soft_cap=self.logits_soft_cap, ) - if self.head_size % 128 != 0: + if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0: output = output[:, :, :self.head_size] return output.reshape(num_tokens, hidden_size) @@ -248,8 +252,8 @@ def write_to_kv_cache( """ _, _, num_combined_kv_heads, head_size = kv_cache.shape - if head_size % 128 != 0: - head_size = cdiv(head_size, 128) * 128 + head_size = cdiv(head_size, + TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, head_size) From 1a56b0a830f88f09f66f4a2b24cb7d14966fea20 Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Sat, 14 Jun 2025 21:31:09 +0000 Subject: [PATCH 3/3] fix comment Signed-off-by: Chengji Yao --- vllm/v1/attention/backends/pallas.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 3f113a05880b..b89ea7cce680 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -49,6 +49,10 @@ def get_kv_cache_shape( padded_head_size = cdiv( head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT num_blocks = num_blocks * head_size // padded_head_size + if padded_head_size != head_size: + logger.warning_once( + "head size is padded to %d, and num_blocks is adjusted to %d" + " accordingly", padded_head_size, num_blocks) head_size = padded_head_size return (num_blocks, block_size, num_kv_heads * 2, head_size)