Skip to content

Commit 1c6bada

Browse files
Chunk prefill cache writes, remove div_i32 from insert_or_update_cache (#289)
Re-implements the following PRs for the current habana_main:

- #102 (removing div_i32 operations from each layer)
- #115 (removing scatter for reshape&cache in case of prompt)

Accuracy (GSM8K on Llama3.1-8B-Instruct):

| Tasks         |Version| Filter         |n-shot| Metric    |   |Value |   |Stderr|
|---------------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k_cot_llama|      3|flexible-extract|     8|exact_match|↑  |0.8415|±  |0.0101|
|               |       |strict-match    |     8|exact_match|↑  |0.8400|±  |0.0101|

I've benchmarked this change on Llama3.1-8B-Instruct. On average, a +2.50% throughput gain (+558.14 tok/s, ~21594 tok/s -> ~22152 tok/s) can be observed across all prefill buckets on G2, with up to a +4.40% (+956.79 tok/s, ~25031 tok/s -> ~25988 tok/s) throughput increase in compute-bound scenarios.
1 parent 4c8a6c6 commit 1c6bada

File tree

4 files changed

+33
-11
lines changed

4 files changed

+33
-11
lines changed

requirements-hpu.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,4 @@ ray == 2.32.0
66
triton
77
pandas
88
tabulate
9-
10-
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@0a7adab
9+
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@940fdb7

vllm/attention/backends/habana_attn.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import torch
1010
import vllm_hpu_extension.ops as ops
11-
from vllm_hpu_extension import cache_ops
1211
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
1312

1413
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -166,20 +165,22 @@ def forward(
166165
query = query.view(-1, self.num_heads, self.head_size)
167166
key = key.view(-1, self.num_kv_heads, self.head_size)
168167
value = value.view(-1, self.num_kv_heads, self.head_size)
168+
block_indices = attn_metadata.block_indices
169+
block_offsets = attn_metadata.block_offsets
170+
if attn_metadata.is_prompt:
171+
key = key.unflatten(0, (block_indices.size(0), -1))
172+
value = value.unflatten(0, (block_indices.size(0), -1))
169173
if kv_cache is not None:
170174
key_cache, value_cache = HabanaPagedAttention.split_kv_cache(
171175
kv_cache, self.num_kv_heads, self.head_size)
172176

173177
# Reshape the input keys and values and store them in the cache.
174178
# If kv_cache is not provided, the new key and value tensors are
175179
# not cached. This happens during the initial memory profiling run.
176-
num_kv_cache_passes, num_slots_available, indices, offsets = \
177-
cache_ops.prepare_to_cache(key_cache,
178-
attn_metadata.slot_mapping)
179-
key_cache = self.k_cache(key, key_cache, num_kv_cache_passes,
180-
num_slots_available, indices, offsets)
181-
value_cache = self.v_cache(value, value_cache, num_kv_cache_passes,
182-
num_slots_available, indices, offsets)
180+
key_cache = self.k_cache(key, key_cache, block_indices,
181+
block_offsets)
182+
value_cache = self.v_cache(value, value_cache, block_indices,
183+
block_offsets)
183184

184185
if attn_metadata.is_prompt:
185186
# Prompt run.

vllm/attention/ops/habana_paged_attn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ class HabanaPagedAttentionMetadata:
1818
block_list: Optional[torch.Tensor]
1919
block_mapping: Optional[torch.Tensor]
2020
block_usage: Optional[torch.Tensor]
21+
block_indices: Optional[torch.Tensor]
22+
block_offsets: Optional[torch.Tensor]
2123

2224

2325
class HabanaPagedAttention:

vllm/worker/habana_model_runner.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,17 @@ def pad_list(list, k, v):
245245
return list + [v] * padding
246246

247247

248+
def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt):
249+
slot_mapping = slot_mapping.flatten()
250+
indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
251+
if is_prompt:
252+
indices = indices.unflatten(0, (-1, block_size))[:, 0]
253+
offsets = None
254+
else:
255+
offsets = torch.fmod(slot_mapping, block_size)
256+
return indices, offsets
257+
258+
248259
class HpuModelAdapter():
249260

250261
def __init__(self, model, block_size, dtype, enforce_eager):
@@ -890,11 +901,15 @@ def _prepare_prompt(
890901
dtype=torch.long,
891902
device=self.device)
892903

904+
block_indices, block_offsets = precompute_indices_and_offsets(
905+
self.block_size, slot_mapping, True)
893906
attn_metadata = self.attn_backend.make_metadata(
894907
is_prompt=True,
895908
block_list=None,
896909
block_mapping=None,
897910
block_usage=None,
911+
block_indices=block_indices,
912+
block_offsets=block_offsets,
898913
attn_bias=None,
899914
seq_lens_tensor=seq_lens_tensor,
900915
num_prefills=real_num_seqs,
@@ -1044,11 +1059,15 @@ def _prepare_decode(
10441059
dtype=torch.long,
10451060
device=self.device)
10461061

1062+
block_indices, block_offsets = precompute_indices_and_offsets(
1063+
self.block_size, slot_mapping, False)
10471064
attn_metadata = self.attn_backend.make_metadata(
10481065
is_prompt=False,
10491066
block_list=block_list,
10501067
block_mapping=block_mapping,
10511068
block_usage=block_usage,
1069+
block_indices=block_indices,
1070+
block_offsets=block_offsets,
10521071
attn_bias=None,
10531072
seq_lens_tensor=None,
10541073
num_prefills=0,
@@ -1266,7 +1285,8 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
12661285
# input_hash("abc") != input_hash("cba")
12671286
attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
12681287
'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
1269-
'block_usage', 'slot_mapping', 'is_prompt'
1288+
'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
1289+
'block_offsets'
12701290
])
12711291
return attention_metadata
12721292

0 commit comments

Comments
 (0)