Skip to content

Commit

Permalink
refine class IPEXPagedCache's update method (#945)
Browse files Browse the repository at this point in the history
* refine class IPEXPagedCache's update method

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* replace tensor on xpu to List to avoid memory copy

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* split IPEXPagedCache's update function into `update_for_prefill` and `update_for_decode`

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
  • Loading branch information
kaixuanliu authored Oct 17, 2024
1 parent 541a236 commit 35cd0c1
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 35 deletions.
108 changes: 78 additions & 30 deletions optimum/exporters/ipex/cache_utils.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import List, Optional, Tuple

import torch
from intel_extension_for_pytorch.llm.modules import PagedAttention
Expand Down Expand Up @@ -95,14 +95,87 @@ def __init__(
for _ in range(self.num_hidden_layers)
]

def update_for_prefill(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
batch_size: int,
length_list: Optional[List],
):
all_block_indices = []
all_slot_offsets = []
for i in range(batch_size):
num_blocks = (length_list[i] + self.block_size - 1) // self.block_size
for b_idx in range(num_blocks):
if self.block_tables[i][b_idx] == -1:
# need a free block
self.block_tables[i][b_idx] = self.free_blocks.pop(0)

slots_range = torch.arange(length_list[i], device=key_states.device)
block_indices = slots_range // self.block_size
slot_offsets = slots_range % self.block_size
all_block_indices.append(self.block_tables[i][block_indices])
all_slot_offsets.append(slot_offsets)

all_block_indices = torch.cat(all_block_indices)
all_slot_offsets = torch.cat(all_slot_offsets)
slots_tensor = all_block_indices * self.block_size + all_slot_offsets
# Update the cache
PagedAttention.reshape_and_cache(
key_states,
value_states,
self.kv_cache[layer_idx][0],
self.kv_cache[layer_idx][1],
slots_tensor,
)

# Update the number of seen tokens
if layer_idx == self.num_hidden_layers - 1:
for i in range(batch_size):
self._seen_tokens[i] += length_list[i]

def update_for_decode(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
batch_size: int,
):
slots = []
for i in range(batch_size):
start_block_idx = self._seen_tokens[i] // self.block_size
num_blocks = (self._seen_tokens[i] + self.block_size) // self.block_size
for b_idx in range(start_block_idx, num_blocks):
if self.block_tables[i][b_idx] == -1:
# need a free block
self.block_tables[i][b_idx] = self.free_blocks.pop(0)
block_idx = (self._seen_tokens[i]) // self.block_size
slot_offset_in_block = (self._seen_tokens[i]) % self.block_size
slots.append(self.block_tables[i][block_idx].item() * self.block_size + slot_offset_in_block)

# Update the cache
PagedAttention.reshape_and_cache(
key_states,
value_states,
self.kv_cache[layer_idx][0],
self.kv_cache[layer_idx][1],
torch.tensor(slots, device=key_states.device),
)

# Update the number of seen tokens
if layer_idx == self.num_hidden_layers - 1:
for i in range(batch_size):
self._seen_tokens[i] += 1

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
input_lens: torch.Tensor,
length_list: Optional[List],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Expand All @@ -117,39 +190,14 @@ def update(
Return:
A tuple containing the updated key and value states.
"""

batch_size = position_ids.shape[0]
slots = []
if self.get_seq_length() == 0:
# prefill
num_slots = input_lens.tolist()
self.update_for_prefill(key_states, value_states, layer_idx, batch_size, length_list)
else:
# decode
num_slots = [1] * batch_size
for i in range(batch_size):
start_block_idx = self._seen_tokens[i] // self.block_size
num_blocks = (self._seen_tokens[i] + num_slots[i] + self.block_size - 1) // self.block_size
for b_idx in range(start_block_idx, num_blocks):
if self.block_tables[i][b_idx] == -1:
# need a free block
self.block_tables[i][b_idx] = self.free_blocks.pop(0)
for slot in range(num_slots[i]):
block_idx = (self._seen_tokens[i] + slot) // self.block_size
slot_offset_in_block = (self._seen_tokens[i] + slot) % self.block_size
slots.append(self.block_tables[i][block_idx].item() * self.block_size + slot_offset_in_block)

# Update the cache
PagedAttention.reshape_and_cache(
key_states,
value_states,
self.kv_cache[layer_idx][0],
self.kv_cache[layer_idx][1],
torch.tensor(slots, device=key_states.device),
)

# Update the number of seen tokens
if layer_idx == self.num_hidden_layers - 1:
for i in range(batch_size):
self._seen_tokens[i] += num_slots[i]
self.update_for_decode(key_states, value_states, layer_idx, batch_size)

return self.kv_cache[layer_idx][0], self.kv_cache[layer_idx][1]

Expand Down
10 changes: 5 additions & 5 deletions optimum/exporters/ipex/modeling_utils.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _llama_model_forward(
else:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
input_lens = attention_mask.cumsum(-1)[:, -1]

lens_list = input_lens.tolist()
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
Expand All @@ -137,6 +137,7 @@ def _llama_model_forward(
use_cache=use_cache,
position_embeddings=position_embeddings,
input_lens=input_lens.int(),
lens_list=lens_list,
)

hidden_states = layer_outputs[0]
Expand Down Expand Up @@ -210,6 +211,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
input_lens: Optional[torch.Tensor] = None,
lens_list: Optional[List] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if past_key_value is None and kwargs.get("layer_past", None) is not None:
Expand All @@ -227,15 +229,13 @@ def forward(

if past_key_value is not None:
key_cache, value_cache = past_key_value.update(
key, value, self.layer_idx, attention_mask, position_ids, input_lens
key, value, self.layer_idx, attention_mask, position_ids, lens_list
)

attn_output = torch.empty_like(query)
if past_len == 0:
# prefill, remove padding
seq_len_tensor = torch.cat(
(torch.tensor([0], device=input_lens.device, dtype=torch.int), input_lens.cumsum(-1).int())
)
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
varlen_attention(
query.contiguous() if query.device.type == "xpu" else query,
key.contiguous() if key.device.type == "xpu" else key,
Expand Down

0 comments on commit 35cd0c1

Please sign in to comment.