From 07b5058356babab66c469fce43e5aeeb638424ea Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Mon, 21 Oct 2024 13:04:45 -0400
Subject: [PATCH 1/2] enable qkv

---
 optimum/exporters/ipex/modeling_utils.py | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index b062528438..8916b03c48 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -278,15 +278,30 @@ def forward(
 class _IPEXLlamaAttention(_IPEXAttention):
     def __init__(self, module, config) -> None:
         super().__init__(module, config)
+        concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight])
+        bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias is not None]
+        use_bias = bias_list != []
+        self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias)
+        self.concat_qkv.weight = nn.Parameter(concat_weight)
+        if use_bias:
+            concat_bias = torch.concat(bias_list, 0)
+            self.concat_qkv.bias = nn.Parameter(concat_bias)
+        self.q_slice = self.q_proj.out_features
+        self.k_slice = self.q_slice + self.k_proj.out_features
+        self.v_slice = self.k_slice + self.v_proj.out_features
+        del self.__dict__["_modules"]["q_proj"]
+        del self.__dict__["_modules"]["k_proj"]
+        del self.__dict__["_modules"]["v_proj"]
         if self.module_device == "cpu":
             if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
                 self.mha_linear_add = LinearAdd(module.o_proj)
                 del self.__dict__["_modules"]["o_proj"]
 
     def qkv_gemm(self, hidden_states):
-        query = self.q_proj(hidden_states).view(-1, self.num_heads, self.head_dim)
-        key = self.k_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
-        value = self.v_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
+        qkv_out = self.concat_qkv(hidden_states)
+        query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim)
+        key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
+        value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
 
         return query, key, value

From 9f0220888f2f7f99dce11e8a36c60965a43f8c9c Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Wed, 23 Oct 2024 04:52:23 -0400
Subject: [PATCH 2/2] split key value into 2 lists

---
 optimum/exporters/ipex/cache_utils.py    | 67 ++++++++----------------
 optimum/exporters/ipex/modeling_utils.py |  7 ++-
 2 files changed, 26 insertions(+), 48 deletions(-)

diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py
index d553ba87d2..0ed0b4368a 100755
--- a/optimum/exporters/ipex/cache_utils.py
+++ b/optimum/exporters/ipex/cache_utils.py
@@ -42,11 +42,8 @@ def __init__(
         super().__init__()
         self.max_batch_size = max_batch_size
         self.batch_size = max_batch_size
-        self.kv_cache = []
-
-        self._seen_tokens = max_batch_size * [
-            0
-        ]  # Used in `generate` to keep tally of how many tokens the cache has seen
+        # Used in `generate` to keep tally of how many tokens the cache has seen
+        self._seen_tokens = max_batch_size * [0]
         self.block_size = 16
         self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * max_batch_size
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
@@ -62,38 +59,20 @@ def __init__(
         head_size = config.hidden_size // config.num_attention_heads
         self.head_size = head_size
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+
         if device.type == "cpu":
-            self.kv_cache = [
-                (
-                    torch.empty(
-                        (self.num_blocks, self.num_kv_heads, self.block_size, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                    torch.empty(
-                        (self.num_blocks, self.num_kv_heads, self.block_size, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                )
-                for _ in range(self.num_hidden_layers)
-            ]
+            key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
+            value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
         elif device.type == "xpu":
-            self.kv_cache = [
-                (
-                    torch.empty(
-                        (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                    torch.empty(
-                        (self.num_blocks, self.num_kv_heads, head_size, self.block_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                )
-                for _ in range(self.num_hidden_layers)
-            ]
+            key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
+            value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
+        for i in range(config.num_hidden_layers):
+            new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device)
+            new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)
 
     def update_for_prefill(
         self,
@@ -125,8 +104,8 @@ def update_for_prefill(
         PagedAttention.reshape_and_cache(
             key_states,
             value_states,
-            self.kv_cache[layer_idx][0],
-            self.kv_cache[layer_idx][1],
+            self.key_cache[layer_idx],
+            self.value_cache[layer_idx],
             slots_tensor,
         )
@@ -158,8 +137,8 @@ def update_for_decode(
         PagedAttention.reshape_and_cache(
             key_states,
             value_states,
-            self.kv_cache[layer_idx][0],
-            self.kv_cache[layer_idx][1],
+            self.key_cache[layer_idx],
+            self.value_cache[layer_idx],
             torch.tensor(slots, device=key_states.device),
         )
@@ -175,7 +154,7 @@ def update(
         layer_idx: int,
         attention_mask: torch.Tensor,
         position_ids: torch.Tensor,
-        length_list: Optional[List],
+        length_list: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
@@ -199,7 +178,7 @@ def update( # decode self.update_for_decode(key_states, value_states, layer_idx, batch_size) - return self.kv_cache[layer_idx][0], self.kv_cache[layer_idx][1] + return self.key_cache[layer_idx], self.value_cache[layer_idx] def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states that were seen by the model.""" @@ -227,9 +206,9 @@ def reorder_cache(self, beam_idx: torch.LongTensor): self.block_tables[i, 0 : num_blocks[i] - 1] = updated_block_tables[i, 0 : num_blocks[i] - 1] updated_table.append(self.block_tables[i : i + 1, num_blocks[i] - 1 : num_blocks[i]]) updated_table = torch.cat(tuple(updated_table), dim=0) - for layer_idx in range(len(self.kv_cache)): - self.kv_cache[layer_idx][0][updated_table] = self.kv_cache[layer_idx][0][updated_table[beam_idx]] - self.kv_cache[layer_idx][1][updated_table] = self.kv_cache[layer_idx][1][updated_table[beam_idx]] + for layer_idx in range(self.num_hidden_layers): + self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]] + self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]] free_table = origin_table[origin_table != self.block_tables] for i in range(free_table.shape[0]): diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 8916b03c48..49f3d34d3a 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -122,8 +122,7 @@ def _llama_model_forward( position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1)) else: hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - input_lens = attention_mask.cumsum(-1)[:, -1] - lens_list = input_lens.tolist() + input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -136,8 +135,8 @@ def _llama_model_forward( output_attentions=output_attentions, use_cache=use_cache, position_embeddings=position_embeddings, - input_lens=input_lens.int(), - lens_list=lens_list, + input_lens=input_lens, + lens_list=input_lens, ) hidden_states = layer_outputs[0]
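
A note on PATCH 1/2: the fused projection works because concatenating the q/k/v weight matrices along the output dimension and slicing the single GEMM's output reproduces the three separate projections. Below is a minimal standalone sketch of that equivalence; the dimensions and the bias-free configuration are illustrative assumptions, not values taken from the patch.

import torch
import torch.nn as nn

# Assumed toy dimensions for a small GQA-style attention block.
hidden_size, num_heads, num_kv_heads, head_dim = 64, 4, 2, 16
q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)

# Fuse the three projections into one weight matrix: one GEMM instead of three.
concat_weight = torch.concat([q_proj.weight, k_proj.weight, v_proj.weight])
concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=False)
concat_qkv.weight = nn.Parameter(concat_weight)

# Slice boundaries, mirroring q_slice/k_slice in the patch.
q_slice = q_proj.out_features
k_slice = q_slice + k_proj.out_features

x = torch.randn(3, hidden_size)
qkv_out = concat_qkv(x)

# The fused path must match the three separate projections.
assert torch.allclose(qkv_out[:, :q_slice], q_proj(x))
assert torch.allclose(qkv_out[:, q_slice:k_slice], k_proj(x))
assert torch.allclose(qkv_out[:, k_slice:], v_proj(x))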
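A note on PATCH 2/2: splitting kv_cache into two parallel per-layer lists means each cache tensor is reached with a single indexing step (key_cache[layer_idx] rather than kv_cache[layer_idx][0]), and the attribute names match the key_cache/value_cache convention of transformers' Cache classes. A minimal sketch of the storage layout follows, with assumed toy shapes rather than the exact IPEX paged-cache geometry.

from typing import List

import torch

# Assumed toy dimensions; the real shapes depend on the device (cpu vs xpu).
num_layers, num_blocks, num_kv_heads, block_size, head_size = 2, 4, 2, 16, 8
cache_shape = (num_blocks, num_kv_heads, block_size, head_size)

# Two parallel lists, one entry per layer.
key_cache: List[torch.Tensor] = []
value_cache: List[torch.Tensor] = []
for _ in range(num_layers):
    key_cache.append(torch.zeros(cache_shape))
    value_cache.append(torch.zeros(cache_shape))

# Old access pattern: kv_cache[layer_idx][0] / kv_cache[layer_idx][1]
# New access pattern: one indexing level per list.
layer_idx = 0
key_states, value_states = key_cache[layer_idx], value_cache[layer_idx]
assert key_states.shape == value_states.shape == cache_shape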
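The modeling_utils.py hunks also stop materializing a separate Python list of lengths: a single int32 tensor is computed once and passed for both input_lens and lens_list. The last column of the attention mask's cumulative sum is each sequence's unpadded length, for example:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
print(input_lens)  # tensor([3, 2], dtype=torch.int32)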