Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable qkv concat layer #958

Merged
merged 2 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 23 additions & 44 deletions optimum/exporters/ipex/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,8 @@ def __init__(
super().__init__()
self.max_batch_size = max_batch_size
self.batch_size = max_batch_size
self.kv_cache = []

self._seen_tokens = max_batch_size * [
0
] # Used in `generate` to keep tally of how many tokens the cache has seen
# Used in `generate` to keep tally of how many tokens the cache has seen
self._seen_tokens = max_batch_size * [0]
self.block_size = 16
self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * max_batch_size
self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
Expand All @@ -62,38 +59,20 @@ def __init__(
head_size = config.hidden_size // config.num_attention_heads
self.head_size = head_size

self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []

if device.type == "cpu":
self.kv_cache = [
(
torch.empty(
(self.num_blocks, self.num_kv_heads, self.block_size, head_size),
dtype=dtype,
device=device,
),
torch.empty(
(self.num_blocks, self.num_kv_heads, self.block_size, head_size),
dtype=dtype,
device=device,
),
)
for _ in range(self.num_hidden_layers)
]
key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
elif device.type == "xpu":
self.kv_cache = [
(
torch.empty(
(self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1),
dtype=dtype,
device=device,
),
torch.empty(
(self.num_blocks, self.num_kv_heads, head_size, self.block_size),
dtype=dtype,
device=device,
),
)
for _ in range(self.num_hidden_layers)
]
key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
for i in range(config.num_hidden_layers):
new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device)
new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)

def update_for_prefill(
self,
Expand Down Expand Up @@ -125,8 +104,8 @@ def update_for_prefill(
PagedAttention.reshape_and_cache(
key_states,
value_states,
self.kv_cache[layer_idx][0],
self.kv_cache[layer_idx][1],
self.key_cache[layer_idx],
self.value_cache[layer_idx],
slots_tensor,
)

Expand Down Expand Up @@ -158,8 +137,8 @@ def update_for_decode(
PagedAttention.reshape_and_cache(
key_states,
value_states,
self.kv_cache[layer_idx][0],
self.kv_cache[layer_idx][1],
self.key_cache[layer_idx],
self.value_cache[layer_idx],
torch.tensor(slots, device=key_states.device),
)

Expand All @@ -175,7 +154,7 @@ def update(
layer_idx: int,
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
length_list: Optional[List],
length_list: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Expand All @@ -199,7 +178,7 @@ def update(
# decode
self.update_for_decode(key_states, value_states, layer_idx, batch_size)

return self.kv_cache[layer_idx][0], self.kv_cache[layer_idx][1]
return self.key_cache[layer_idx], self.value_cache[layer_idx]

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model."""
Expand Down Expand Up @@ -227,9 +206,9 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
self.block_tables[i, 0 : num_blocks[i] - 1] = updated_block_tables[i, 0 : num_blocks[i] - 1]
updated_table.append(self.block_tables[i : i + 1, num_blocks[i] - 1 : num_blocks[i]])
updated_table = torch.cat(tuple(updated_table), dim=0)
for layer_idx in range(len(self.kv_cache)):
self.kv_cache[layer_idx][0][updated_table] = self.kv_cache[layer_idx][0][updated_table[beam_idx]]
self.kv_cache[layer_idx][1][updated_table] = self.kv_cache[layer_idx][1][updated_table[beam_idx]]
for layer_idx in range(self.num_hidden_layers):
self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]]
self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]]

free_table = origin_table[origin_table != self.block_tables]
for i in range(free_table.shape[0]):
Expand Down
28 changes: 21 additions & 7 deletions optimum/exporters/ipex/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ def _llama_model_forward(
position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
else:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
input_lens = attention_mask.cumsum(-1)[:, -1]
lens_list = input_lens.tolist()
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
Expand All @@ -136,8 +135,8 @@ def _llama_model_forward(
output_attentions=output_attentions,
use_cache=use_cache,
position_embeddings=position_embeddings,
input_lens=input_lens.int(),
lens_list=lens_list,
input_lens=input_lens,
lens_list=input_lens,
)

hidden_states = layer_outputs[0]
Expand Down Expand Up @@ -278,15 +277,30 @@ def forward(
class _IPEXLlamaAttention(_IPEXAttention):
def __init__(self, module, config) -> None:
super().__init__(module, config)
concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight])
bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias]
use_bias = bias_list != []
self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias)
self.concat_qkv.weight = nn.Parameter(concat_weight)
if use_bias:
concat_bias = torch.concat(bias_list, 0)
self.concat_linear.bias = nn.Parameter(concat_bias)
self.q_slice = self.q_proj.out_features
self.k_slice = self.q_slice + self.k_proj.out_features
self.v_slice = self.k_slice + self.v_proj.out_features
del self.__dict__["_modules"]["q_proj"]
del self.__dict__["_modules"]["k_proj"]
del self.__dict__["_modules"]["v_proj"]
if self.module_device == "cpu":
if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mha_linear_add = LinearAdd(module.o_proj)
del self.__dict__["_modules"]["o_proj"]

def qkv_gemm(self, hidden_states):
query = self.q_proj(hidden_states).view(-1, self.num_heads, self.head_dim)
key = self.k_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
value = self.v_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
qkv_out = self.concat_qkv(hidden_states)
query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim)
key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)

return query, key, value

Expand Down