diff --git a/.github/workflows/test_inc.yml b/.github/workflows/test_inc.yml
index 2e51cdfbb3..7f622a2da9 100644
--- a/.github/workflows/test_inc.yml
+++ b/.github/workflows/test_inc.yml
@@ -18,7 +18,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        torch-version: ["2.4.*", "2.5.0"]
+        torch-version: ["2.4.0", "2.5.*"]
 
     runs-on: ubuntu-22.04
 
diff --git a/.github/workflows/test_ipex.yml b/.github/workflows/test_ipex.yml
index 41330ed42b..de933e3795 100644
--- a/.github/workflows/test_ipex.yml
+++ b/.github/workflows/test_ipex.yml
@@ -18,8 +18,8 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        torch-version: ["2.2.0", "2.3.*"]
-        transformers-version: ["4.39.0", "4.44.*"]
+        transformers-version: ["4.46.0", "4.46.3"]
+        torch-version: ["2.4.0", "2.5.*"]
 
     runs-on: ubuntu-22.04
 
@@ -38,10 +38,6 @@ jobs:
           pip install torch==${{ matrix.torch-version }} torchaudio torchvision --extra-index-url https://download.pytorch.org/whl/cpu
           pip install .[ipex,tests] transformers[testing]==${{ matrix.transformers-version }} intel_extension_for_pytorch==${{ matrix.torch-version }}
 
-      - if: ${{ matrix.torch-version == '2.2.0' }}
-        name: Downgrade Numpy
-        run: pip install numpy==1.*
-
       - name: Assert versions
         run: |
           python -c "import torch; print(torch.__version__); assert torch.__version__.startswith('${{ matrix.torch-version }}'.replace('.*', ''))"
diff --git a/docs/source/ipex/inference.mdx b/docs/source/ipex/inference.mdx
index c712275e42..54b586924d 100644
--- a/docs/source/ipex/inference.mdx
+++ b/docs/source/ipex/inference.mdx
@@ -14,8 +14,8 @@ Optimum Intel can be used to load models from the [Hub](https://huggingface.co/m
 
 ## Loading
 
-You can load your model and apply IPEX optimizations (including weight prepacking and graph mode). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators.
-For now, support is only enabled for CPUs and the original model will be exported via TorchScript. In the future `torch.compile` will be used and model exported via TorchScript will get deprecated.
+You can load your model and apply IPEX optimizations (apply torch.compile for non-generation tasks). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators.
+For now, support is enabled for Intel CPU/GPU. Previous models converted to TorchScript will be deprecated in v1.22.
 
 ```diff
   import torch
@@ -25,7 +25,7 @@ For now, support is only enabled for CPUs and the original model will be exporte
 
   model_id = "gpt2"
 - model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
-+ model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, export=True)
++ model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
   tokenizer = AutoTokenizer.from_pretrained(model_id)
   pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
   results = pipe("He's a dreadful magician and")
diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py
new file mode 100755
index 0000000000..dec1e81895
--- /dev/null
+++ b/optimum/exporters/ipex/cache_utils.py
@@ -0,0 +1,238 @@
+from typing import List, Optional, Tuple
+
+import torch
+from intel_extension_for_pytorch.llm.modules import PagedAttention
+from transformers import Cache, PretrainedConfig
+
+
+class IPEXPagedCache(Cache):
+    """
+    A PagedCache that grows dynamically as more tokens are generated.
+    Every time it grows, it allocates another block_size chunk of memory; the vendor backend can define the paged-cache memory layout.
+    ipex-xpu:
+    ipex-cpu:
+
+    Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer
+        >>> from optimum.intel import IPEXModelForCausalLM
+        >>> from optimum.exporters.ipex.cache_utils import IPEXPagedCache
+
+        >>> model = IPEXModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", export=True)
+        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+
+        >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
+
+        >>> # Prepare a cache class and pass it to model's forward
+        >>> past_key_values = IPEXPagedCache()
+        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+        >>> past_kv_length = outputs.past_key_values  # access cache filled with key/values from generation
+        ```
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        batch_size: int,
+        max_cache_len: int,
+        device,
+        dtype=None,
+        layer_device_map=None,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        self.batch_size = batch_size
+        # Used in `generate` to keep tally of how many tokens the cache has seen
+        self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
+        self.block_size = 16
+        self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
+        self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
+            batch_size, -1
+        )
+        self.free_blocks = torch.arange(self.num_blocks, device=device)
+        self.max_cache_len = max_cache_len
+        self.num_kv_heads = config.num_key_value_heads
+        self.num_hidden_layers = config.num_hidden_layers
+        if hasattr(config, "head_dim"):
+            head_size = config.head_dim
+        else:
+            head_size = config.hidden_size // config.num_attention_heads
+        self.head_size = head_size
+        self.max_seq_len = 0
+
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+
+        if device.type == "cpu":
+            key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
+            value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
+        elif device.type == "xpu":
+            key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
+            value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
+        for i in range(config.num_hidden_layers):
+            new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device)
+            new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)
+
+    def update_for_prefill(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        batch_size: int,
+        input_lens: torch.Tensor,
+    ):
+        if layer_idx == 0:
+            all_block_indices = []
+            all_slot_offsets = []
+            num_blocks = (input_lens + self.block_size - 1) // self.block_size
+            for i in range(batch_size):
+                for b_idx in range(num_blocks[i]):
+                    if self.block_tables[i][b_idx] == -1:
+                        # need a free block
+                        self.block_tables[i][b_idx] = self.free_blocks[0]
+                        self.free_blocks = self.free_blocks[1:]
+
+                slots_range = torch.arange(input_lens[i], device=key_states.device)
+                block_indices = slots_range // self.block_size
+                slot_offsets = slots_range % self.block_size
+                all_block_indices.append(self.block_tables[i][block_indices])
+                all_slot_offsets.append(slot_offsets)
+
+            all_block_indices = torch.cat(all_block_indices)
+            all_slot_offsets =
torch.cat(all_slot_offsets) + self.slots = all_block_indices * self.block_size + all_slot_offsets + + # Update the cache + PagedAttention.reshape_and_cache( + key_states, + value_states, + self.key_cache[layer_idx], + self.value_cache[layer_idx], + self.slots, + ) + + # Update the number of seen tokens + if layer_idx == self.num_hidden_layers - 1: + self._seen_tokens = self._seen_tokens + input_lens + self.max_seq_len, _ = self._seen_tokens.max(dim=0) + + def update_for_decode( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + batch_size: int, + ): + if layer_idx == 0: + start_block_idx = self._seen_tokens // self.block_size + num_blocks = (self._seen_tokens + self.block_size) // self.block_size + slot_offset_in_block = (self._seen_tokens) % self.block_size + self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32) + for i in range(batch_size): + for b_idx in range(start_block_idx[i], num_blocks[i]): + if self.block_tables[i][b_idx] == -1: + # need a free block + self.block_tables[i][b_idx] = self.free_blocks[0] + self.free_blocks = self.free_blocks[1:] + + self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i] + # Update the cache + PagedAttention.reshape_and_cache( + key_states, + value_states, + self.key_cache[layer_idx], + self.value_cache[layer_idx], + self.slots, + ) + + # Update the number of seen tokens + if layer_idx == self.num_hidden_layers - 1: + self._seen_tokens = self._seen_tokens + 1 + self.max_seq_len = self.max_seq_len + 1 + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + attention_mask: torch.Tensor, + input_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + Return: + A tuple containing the updated key and value states. 
+ """ + + batch_size = input_lens.shape[-1] + if self.get_seq_length() == 0: + # prefill + self.update_for_prefill(key_states, value_states, layer_idx, batch_size, input_lens) + else: + # decode + self.update_for_decode(key_states, value_states, layer_idx, batch_size) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + return self.max_seq_len + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states.""" + return self.max_cache_len + + def reset(self): + """Resets the cache values while preserving the objects""" + self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device) + self.block_tables.fill_(-1) + self.free_blocks = torch.arange(self.num_blocks, device=self.block_tables.device) + self.max_seq_len = 0 + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + device = self.block_tables.device + origin_table = self.block_tables.clone() + updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device)) + mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0) + num_blocks = mask.cumsum(-1)[:, -1] + updated_table = [] + for i in range(beam_idx.shape[0]): + self.block_tables[i, 0 : num_blocks[i] - 1] = updated_block_tables[i, 0 : num_blocks[i] - 1] + updated_table.append(self.block_tables[i : i + 1, num_blocks[i] - 1 : num_blocks[i]]) + updated_table = torch.cat(tuple(updated_table), dim=0) + for layer_idx in range(self.num_hidden_layers): + self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]] + self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]] + free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1)) + self.free_blocks = torch.cat((self.free_blocks, free_table)) + + def crop(self, maximum_length: int): + """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be + negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.""" + + max_seq_len = self.get_seq_length() + if maximum_length < 0: + maximum_length = max_seq_len - abs(maximum_length) + + if max_seq_len <= maximum_length: + return + origin_table = self.block_tables.clone() + for bs in range(self._seen_tokens.shape[0]): + new_tokens = self._seen_tokens[bs] + maximum_length - max_seq_len + num_blocks = (new_tokens + self.block_size - 1) // self.block_size + self.block_tables[bs, num_blocks:] = -1 + self._seen_tokens[bs] = new_tokens + self.max_seq_len, _ = self._seen_tokens.max(dim=0) + free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1)) + self.free_blocks = torch.cat((self.free_blocks, free_table)) diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 484fd38077..03937754a6 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -13,11 +13,10 @@ # limitations under the License. 
from transformers.models.bert.modeling_bert import BertIntermediate -from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconForCausalLM -from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel +from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel +from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model from transformers.models.llama.modeling_llama import ( LlamaDecoderLayer, - LlamaForCausalLM, LlamaModel, LlamaRMSNorm, ) @@ -28,7 +27,9 @@ from .modeling_utils import ( _IPEX_MINIMUM_VERSION_FOR_PATCHING, + _falcon_model_forward, _gpt2_block_forward, + _gpt2_model_forward, _ipex_rms_layer_norm_forward, _IPEXFalconDecoderLayer, _IPEXGPT2Attention, @@ -39,8 +40,8 @@ # Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version -_TRANSFORMERS_MIN_VERSION = "4.39.0" -_TRANSFORMERS_MAX_VERSION = "4.44.99" +_TRANSFORMERS_MIN_VERSION = "4.46.0" +_TRANSFORMERS_MAX_VERSION = "4.46.99" _IPEX_EXPORTED_GENERATION_TASKS = ("text-generation",) @@ -75,7 +76,7 @@ def patch_op(m, target_m, new_op_name, new_op): def _patch_llama_model(model): """ Patch llama model: - 1. Use IPEX Rope and IAKV cache + 1. Use IPEX rope and paged cache 2. Linear fusion with (2 Linears + Silu + Mul) and (Linear + Add) """ convert_functions(model, LlamaModel, "forward", _llama_model_forward) @@ -87,11 +88,14 @@ def _patch_llama_model(model): def _patch_falcon_model(model): """ Patch falcon model: - 1. Disable SDPA so the attention mask will be compatible to ipex attention. - 2. Use IPEX Rope and IAKV cache - 3. Linear fusion with (Linear + Gelu) and (Linear + Add + Add) + 1. Use IPEX rope and paged cache + 2. Linear fusion with (Linear + Gelu) and (Linear + Add + Add) """ - model.transformer._use_sdpa = False + num_key_value_heads = ( + model.config.num_kv_heads if (model.config.new_decoder_architecture or not model.config.multi_query) else 1 + ) + setattr(model.config, "num_key_value_heads", num_key_value_heads) + convert_functions(model, FalconModel, "forward", _falcon_model_forward) replace_customized_linear_with_linear(model) convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.config) return model @@ -100,12 +104,13 @@ def _patch_falcon_model(model): def _patch_gpt2_model(model): """ Patch gpt2 model: - 1. Disable SDPA so the attention mask will be compatible to ipex attention. - 2. Use IAKV cache + 1. Use IPEX paged attention """ - model.transformer._attn_implementation = "eager" - convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config) + num_key_value_heads = model.config.num_attention_heads + setattr(model.config, "num_key_value_heads", num_key_value_heads) + convert_functions(model, GPT2Model, "forward", _gpt2_model_forward) convert_functions(model, GPT2Block, "forward", _gpt2_block_forward) + convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config) return model @@ -136,11 +141,11 @@ def _patch_model(model): raise ImportError( f"Only transformers versions {_TRANSFORMERS_MIN_VERSION} ~ {_TRANSFORMERS_MAX_VERSION} are verified." 
) - if isinstance(model, LlamaForCausalLM): + if model.config.model_type == "llama": model = _patch_llama_model(model) - elif isinstance(model, FalconForCausalLM): + elif model.config.model_type == "falcon": model = _patch_falcon_model(model) - elif isinstance(model, GPT2LMHeadModel): + elif model.config.model_type == "gpt2": model = _patch_gpt2_model(model) elif model.config.model_type == "bert": model = _patch_bert_model(model) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py old mode 100644 new mode 100755 index 3d28350b86..ccd98ce2e9 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -18,19 +18,18 @@ import torch from torch import nn -from torch.nn import functional as F -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.gpt2.modeling_gpt2 import GPT2Block -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv +from transformers.cache_utils import Cache +from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions from optimum.intel.utils.import_utils import is_ipex_version from optimum.intel.utils.modeling_utils import _setattr_from_module +from .cache_utils import IPEXPagedCache + logger = logging.getLogger(__name__) -_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.3.0" +_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.4.0" if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING): @@ -38,28 +37,114 @@ f"Please upgrade the IPEX version to at least {_IPEX_MINIMUM_VERSION_FOR_PATCHING} if you want to patch the model." ) else: + from intel_extension_for_pytorch.llm.functional import rms_norm, rotary_embedding, varlen_attention from intel_extension_for_pytorch.llm.modules import ( - IndirectAccessKVCacheAttention, Linear2SiluMul, LinearAdd, LinearAddAdd, LinearGelu, - RotaryEmbedding, + PagedAttention, ) +# TODO: Following XPULinearXXX op classes will be put into ipex after 2.6.0 version +class XPULinear2SiluMul(torch.nn.Module): + def __init__( + self, + gate_proj: torch.nn.Module, + up_proj: torch.nn.Module, + ): + super().__init__() + self.gate_proj_weight = gate_proj.weight.transpose(0, 1).contiguous() + self.up_proj_weight = up_proj.weight.transpose(0, 1).contiguous() + self.gate_proj_bias = gate_proj.bias + self.up_proj_bias = up_proj.bias + + def forward( + self, + hidden_states, + ): + up = torch.ops.torch_ipex.mm_silu(hidden_states, self.gate_proj_weight) + if self.gate_proj_bias is not None: + up += self.gate_proj_bias + hidden_states = torch.ops.torch_ipex.mm_resmul(hidden_states, self.up_proj_weight, up) + if self.up_proj_bias is not None: + hidden_states += self.up_proj_bias + return hidden_states + + +class XPULinearGelu(torch.nn.Module): + def __init__(self, module: torch.nn.Module): + super().__init__() + self.weight = module.weight.transpose(0, 1).contiguous() + self.bias = module.bias + + def forward(self, x): + return torch.ops.torch_ipex.matmul_gelu(x, self.weight, self.bias, 1.0, "tanh") + + +class XPULinearAdd(torch.nn.Module): + def __init__( + self, + module: torch.nn.Module, + ): + super().__init__() + self.weight = module.weight.transpose(0, 1).contiguous() + self.bias = module.bias + + def forward( + self, + hidden_states, + residual, + ): + token_len, _ = hidden_states.size() + if residual is None: + hidden_states = torch.matmul(hidden_states, self.weight) + if self.bias is not 
None: + hidden_states += self.bias + else: + if self.bias is not None: + hidden_states = torch.ops.torch_ipex.mm_bias_resadd( + hidden_states, self.weight, self.bias, 1.0, residual, 1.0 + ) + else: + hidden_states = torch.addmm( + residual.flatten(0, -2), + hidden_states.flatten(0, -2), + self.weight, + beta=1.0, + ) + hidden_states = hidden_states.view(token_len, -1) + return hidden_states + + +class XPUlinearAddAdd(torch.nn.Module): + def __init__(self, module: torch.nn.Module): + super().__init__() + self.weight = module.weight.transpose(0, 1).contiguous() + self.bias = module.bias + + def forward(self, x, y, z): + if self.bias is not None: + x = torch.ops.torch_ipex.mm_bias_resadd(x, self.weight, self.bias, 1.0, y, 1.0) + x += z + else: + x = torch.ops.torch_ipex.mm_bias_resadd(x, self.weight, z, 1.0, y, 1.0) + return x + + # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83 def _ipex_rms_layer_norm_forward(self, hidden_states): - return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon) + return rms_norm(hidden_states, self.weight, self.variance_epsilon) -# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130 +# Adapted from https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llama/modeling_llama.py#L918 def _llama_model_forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -85,9 +170,10 @@ def _llama_model_forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - past_key_values_length = 0 - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] + if past_key_values is not None and not isinstance(past_key_values, IPEXPagedCache): + raise ValueError("only support IPEXPagedCache input now") + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -99,15 +185,6 @@ def _llama_model_forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if getattr(self.config, "_flash_attn_2_enabled", False): - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - # embed positions hidden_states = inputs_embeds @@ -116,25 +193,41 @@ def _llama_model_forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None + position_embeddings = self.rotary_emb(hidden_states, position_ids) + if past_key_values_length == 0: + # first token, remove the padding from hidden_states, varlen do not accept attention mask + hidden_states_copy = hidden_states + index = attention_mask.view(-1) != 0 + hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index] + cos = position_embeddings[0] + sin = 
position_embeddings[1] + cos = (cos.reshape(-1, cos.shape[-1]))[index] + sin = (sin.reshape(-1, sin.shape[-1]))[index] + position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1)) + else: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + position_embeddings=position_embeddings, + input_lens=input_lens, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -146,6 +239,10 @@ def _llama_model_forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if hidden_states.shape[0] != batch_size * seq_length: + (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states + hidden_states = hidden_states_copy + hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1]) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -156,17 +253,318 @@ def _llama_model_forward( ) -def _gpt2_block_forward(self, hidden_states, *args, **kwargs): - attention_mask = kwargs.get("attention_mask", None) - if attention_mask is not None: - bsz, seq_len, _ = hidden_states.size() - layer_past = kwargs.get("layer_past", None) - past_len = layer_past[0].size(-2) if layer_past is not None else 0 - attention_mask = (1 - attention_mask / torch.finfo(attention_mask.dtype).min).squeeze(1, 2) - attention_mask = _prepare_4d_causal_attention_mask(attention_mask, (bsz, seq_len), hidden_states, past_len) - kwargs["attention_mask"] = attention_mask +# Adapted from https://github.com/huggingface/transformers/blob/v4.46.1/src/transformers/models/falcon/modeling_falcon.py#L945 +def _falcon_model_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, +) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly 
one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + batch_size, seq_length, _ = inputs_embeds.shape + + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + if past_key_values_length == 0: + # first token, remove the padding from hidden_states, varlen do not accept attention mask + hidden_states_copy = hidden_states + index = attention_mask.view(-1) != 0 + hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index] + cos = position_embeddings[0] + sin = position_embeddings[1] + cos = (cos.reshape(-1, cos.shape[-1]))[index] + sin = (sin.reshape(-1, sin.shape[-1]))[index] + position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1)) + else: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + + next_decoder_cache = None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for i, block in enumerate(self.h): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = block( + hidden_states, + layer_past=past_key_values, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=None, + cache_position=cache_position, + position_embeddings=position_embeddings, + input_lens=input_lens, + ) + + hidden_states = outputs[0] + if use_cache is True: + next_decoder_cache = outputs[1] + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) - return GPT2Block.forward(self, hidden_states, *args, **kwargs) + next_cache = next_decoder_cache if use_cache else None + + if hidden_states.shape[0] != batch_size * seq_length: + (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states + hidden_states = hidden_states_copy + + hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +def _gpt2_model_forward( + self, + input_ids: 
Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + batch_size, seq_length, _ = inputs_embeds.shape + position_embeddings = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeddings + + encoder_attention_mask = None + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + if past_length == 0: + # first token, remove the padding from hidden_states, varlen do not accept attention mask + hidden_states_copy = hidden_states + index = attention_mask.view(-1) != 0 + hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index] + else: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + + presents = None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, block in enumerate(self.h): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = block( + hidden_states, + layer_past=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + 
output_attentions=output_attentions, + input_lens=input_lens, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = outputs[1] + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + hidden_states = self.ln_f(hidden_states) + if hidden_states.shape[0] != batch_size * seq_length: + (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states + hidden_states = hidden_states_copy + + hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# To pass input_lens, adapted from https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/gpt2/modeling_gpt2.py#L602 +def _gpt2_block_forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs, +) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + **kwargs, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, 
present, (attentions, cross_attentions) class _IPEXAttention(nn.Module): @@ -174,14 +572,11 @@ def __init__(self, module, config) -> None: super().__init__() _setattr_from_module(self, module) self.config = config - self.ipex_scale_dot_product = IndirectAccessKVCacheAttention(text_max_length=config.max_position_embeddings) - if hasattr(config, "rope_theta"): - self.ipex_rope = RotaryEmbedding( - config.max_position_embeddings, - config.hidden_size // config.num_attention_heads, - config.rope_theta, - config.architectures[0], - ) + self.module_device = next(module.parameters()).device + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device + ).repeat_interleave(self.num_groups) def qkv_gemm(self, hidden_states): raise NotImplementedError("Need to implement in specific model class") @@ -189,29 +584,8 @@ def qkv_gemm(self, hidden_states): def rope(self, *args, **kwargs): raise NotImplementedError("Need to implement in specific model class") - def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask, **kwargs): - # This ipex op pre-allocates buffers for past_key_values and use beam index history - # which to decide which beam should be used to make attention scale dot more efficient. - (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product( - query, - key, - value, - math.sqrt(self.head_dim), - past_key_value, - kwargs.get("head_mask", None), - attention_mask, - kwargs.get("alibi", None), - ) - return attn_output, past_key_value, attn_weights - - def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask, **kwargs): - raise NotImplementedError("Need to implement in specific model class") - - def prepare_attention_mask_float(self, attention_mask, *args): - return attention_mask - - def postprocess_attention_output(self, attn_output, bsz, seq_len): - attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, self.hidden_size) + def postprocess_attention_output(self, attn_output): + attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) return attn_output def forward( @@ -219,40 +593,60 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[IPEXPagedCache] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # For llama inputs: https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/llama/modeling_llama.py#L308 - # For falcon inputs: https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/falcon/modeling_falcon.py#L370 if past_key_value is None and kwargs.get("layer_past", None) is not None: past_key_value = kwargs.pop("layer_past", None) - bsz, seq_len, _ = hidden_states.size() - past_len = past_key_value[0].size(-2) if past_key_value is not None else 0 - kv_seq_len = seq_len + past_len - - qkv_out = self.qkv_gemm(hidden_states) - if isinstance(qkv_out, tuple) and len(qkv_out) == 3: - query, key, value = self.qkv_gemm(hidden_states) - query, key = self.rope(query, key, kv_seq_len, use_cache, position_ids=position_ids) + input_lens = kwargs.pop("input_lens", None) + past_len = 0 + if past_key_value is not None: + past_len = past_key_value.get_seq_length() + query, 
key, value = self.qkv_gemm(hidden_states) + query, key = self.rope(query, key, **kwargs) + + if past_key_value is not None: + key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens) + + attn_output = torch.empty_like(query) + if past_len == 0: + # prefill, remove padding + seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) + varlen_attention( + query.contiguous() if query.device.type == "xpu" else query, + key.contiguous() if key.device.type == "xpu" else key, + value.contiguous() if value.device.type == "xpu" else value, + attn_output, + seq_len_tensor, + seq_len_tensor, + input_lens.max(), + input_lens.max(), + 0.0, + 1.0 / math.sqrt(self.head_dim), + False, + True, + False, + None, + ) else: - query, key, value = self.rope(qkv_out, kv_seq_len, use_cache, past_len=past_len) - - attention_mask = self.prepare_attention_mask_float(attention_mask, query.dtype) - sdpa = self.sdpa_with_cache if use_cache else self.sdpa_without_cache - attn_output, past_key_value, attn_weights = sdpa( - query, - key, - value, - past_key_value, - attention_mask, - position_ids=position_ids, - head_mask=kwargs.get("head_mask", None), - alibi=kwargs.get("alibi", None), - ) - attn_output = self.postprocess_attention_output(attn_output, bsz, seq_len) + # decode + PagedAttention.single_query_cached_kv_attention( + attn_output, + query, + key_cache, + value_cache, + self.kv_head_mapping, + 1.0 / math.sqrt(self.head_dim), + past_key_value.block_tables, + input_lens, + past_key_value.block_size, + input_lens.max(), + None, + ) + attn_output = self.postprocess_attention_output(attn_output) if not output_attentions: attn_weights = None @@ -262,105 +656,83 @@ def forward( class _IPEXLlamaAttention(_IPEXAttention): def __init__(self, module, config) -> None: super().__init__(module, config) - if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]: - self.mha_linear_add = LinearAdd(module.o_proj) - del self.__dict__["_modules"]["o_proj"] + concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous() + bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias] + use_bias = bias_list != [] + self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias) + self.concat_qkv.weight = nn.Parameter(concat_weight) + if use_bias: + concat_bias = torch.concat(bias_list, 0).contiguous() + self.concat_linear.bias = nn.Parameter(concat_bias) + self.q_slice = self.q_proj.out_features + self.k_slice = self.q_slice + self.k_proj.out_features + self.v_slice = self.k_slice + self.v_proj.out_features + if self.module_device.type == "cpu": + if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]: + self.mha_linear_add = LinearAdd(module.o_proj) + + elif self.module_device.type == "xpu": + if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]: + self.mha_linear_add = XPULinearAdd(module.o_proj) def qkv_gemm(self, hidden_states): - bsz, seq_len, _ = hidden_states.size() - query = self.q_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim) - key = self.k_proj(hidden_states).view(bsz, seq_len, self.num_key_value_heads, self.head_dim) - value = self.v_proj(hidden_states).view(bsz, seq_len, self.num_key_value_heads, self.head_dim) + qkv_out = self.concat_qkv(hidden_states) + query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim) + key = qkv_out[:, self.q_slice : self.k_slice].view(-1, 
self.num_key_value_heads, self.head_dim) + value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim) return query, key, value - def rope(self, query, key, kv_seq_len, use_cache, position_ids): - if use_cache: - args = (self.head_dim, self.head_dim // 2, self.head_dim, kv_seq_len) - key = self.ipex_rope(key, position_ids, self.num_key_value_heads, *args) - query = self.ipex_rope(query, position_ids, self.num_heads, *args) + def rope(self, query, key, **kwargs): + position_embeddings = kwargs.pop("position_embeddings", None) + rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True) return query, key - # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L341 - def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask, position_ids, **kwargs): - query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) - cos, sin = self.rotary_emb(value, position_ids) - query, key = apply_rotary_pos_emb(query, key, cos, sin) - # repeat k/v heads if n_kv_heads < n_heads - key = repeat_kv(key, self.num_key_value_groups) - value = repeat_kv(value, self.num_key_value_groups) - attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: - attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask) - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_output = torch.matmul(attn_weights, value) - - return attn_output, None, attn_weights - class _IPEXFalconAttention(_IPEXAttention): - def qkv_gemm(self, hidden_states): - return self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + def __init__(self, module, config): + self.num_key_value_heads = config.num_key_value_heads + super().__init__(module, config) + self.q_slice = self.head_dim * config.num_kv_heads + self.k_slice = self.q_slice + self.head_dim + self.v_slice = self.k_slice + self.head_dim - def rope(self, fused_qkv, seq_len, use_cache, past_len): - if use_cache: - query, key, value = self.ipex_rope( - fused_qkv, - torch.tensor(past_len), - self.num_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - seq_len, - 3, - ) + def qkv_gemm(self, hidden_states): + qkv_out = self.query_key_value(hidden_states) + if self.new_decoder_architecture: + qkv_out = qkv_out.view(qkv_out.shape[0], -1, self.num_heads // self.num_kv_heads + 2, self.head_dim) + query = qkv_out[:, :, :-2, :].flatten(1, 2) + key = qkv_out[:, :, [-2], :].flatten(1, 2) + value = qkv_out[:, :, [-1], :].flatten(1, 2) else: - (query, key, value) = self._split_heads(fused_qkv) + query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim) + key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim) + value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim) return query, key, value - def prepare_attention_mask_float(self, attention_mask, dtype): - attention_mask_float = ( - (attention_mask * 1.0).masked_fill(attention_mask.to(torch.bool), float("-1e9")).to(dtype) - ) - return attention_mask_float - - def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask, **kwargs): - bs, q_len = query.shape[0], query.shape[1] - query, key, value = 
query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) - attn_output = F.scaled_dot_product_attention(query, key, value, attention_mask, 0.0, is_causal=False) - attn_output = attn_output.view(bs, self.num_heads, q_len, self.head_dim) - - return attn_output, None, None + def rope(self, query, key, **kwargs): + position_embeddings = kwargs.pop("position_embeddings", None) + rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True) + return query, key class _IPEXGPT2Attention(_IPEXAttention): def __init__(self, module, config) -> None: + self.num_key_value_heads = config.num_key_value_heads super().__init__(module, config) - def _split_heads_ipex(self, tensor, num_heads, attn_head_size): - new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - return tensor.view(new_shape) # (batch, seq_length, head, head_features) - def qkv_gemm(self, hidden_states): - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - query = self._split_heads_ipex(query, self.num_heads, self.head_dim) - key = self._split_heads_ipex(key, self.num_heads, self.head_dim) - value = self._split_heads_ipex(value, self.num_heads, self.head_dim) + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1) + query = query.view(-1, self.num_heads, self.head_dim) + key = key.view(-1, self.num_heads, self.head_dim) + value = value.view(-1, self.num_heads, self.head_dim) return query, key, value def rope(self, query, key, *args, **kwargs): return query, key - def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask, **kwargs): - query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) - attn_output = F.scaled_dot_product_attention(query, key, value, attention_mask, 0.0, is_causal=True) - - return attn_output, None, None - - def postprocess_attention_output(self, attn_output, bsz, seq_len): - attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, self.embed_dim) + def postprocess_attention_output(self, attn_output): + attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) return attn_output @@ -372,13 +744,17 @@ def __init__(self, module, config) -> None: super().__init__() _setattr_from_module(self, module) self.config = config - # LinearAllreduce and LinearLayer cannot use fused op LinearAdd - if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]: - self.mlp_linear_add = LinearAdd(module.down_proj) - del self.__dict__["_modules"]["down_proj"] - self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj) - del self.__dict__["_modules"]["gate_proj"] - del self.__dict__["_modules"]["up_proj"] + self.module_device = next(module.parameters()).device.type + if self.module_device == "cpu": + # LinearAllreduce and LinearLayer cannot use fused op LinearAdd + if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]: + self.mlp_linear_add = LinearAdd(module.down_proj) + self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj) + elif self.module_device == "xpu": + # LinearAllreduce and LinearLayer cannot use fused op LinearAdd + if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]: + self.mlp_linear_add = XPULinearAdd(module.down_proj) + self.linear_silu_mul = XPULinear2SiluMul(module.gate_proj, module.up_proj) def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs): if 
hasattr(self, "linear_silu_mul"): @@ -401,11 +777,16 @@ def __init__(self, module, config) -> None: _setattr_from_module(self, module) self.config = config # LinearAllreduce and LinearLayer cannot use fused op LinearAdd - self.linear_gelu = LinearGelu(module.dense_h_to_4h) - del self.__dict__["_modules"]["dense_h_to_4h"] + self.module_device = next(module.parameters()).device.type + if self.module_device == "cpu": + self.linear_gelu = LinearGelu(module.dense_h_to_4h) + elif self.module_device == "xpu": + self.linear_gelu = XPULinearGelu(module.dense_h_to_4h) if module.dense_4h_to_h.__class__.__name__ not in ["LinearAllreduce"]: - self.linear_add_add = LinearAddAdd(module.dense_4h_to_h) - del self.__dict__["_modules"]["dense_4h_to_h"] + if self.module_device == "cpu": + self.linear_add_add = LinearAddAdd(module.dense_4h_to_h) + elif self.module_device == "xpu": + self.linear_add_add = XPUlinearAddAdd(module.dense_4h_to_h) def forward( self, @@ -490,7 +871,6 @@ def __init__(self, module, config): super().__init__() _setattr_from_module(self, module) self.linear_gelu = LinearGelu(module.dense) - del self.__dict__["_modules"]["dense"] def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.linear_gelu(hidden_states) diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py index 22a4745f0c..a6e8a76f4f 100644 --- a/optimum/intel/generation/modeling.py +++ b/optimum/intel/generation/modeling.py @@ -373,6 +373,7 @@ def _from_pretrained( file_name: Optional[str] = WEIGHTS_NAME, local_files_only: bool = False, use_cache: bool = True, + subfolder: str = None, **kwargs, ): if use_auth_token is not None: @@ -402,6 +403,7 @@ def _from_pretrained( cache_dir=cache_dir, force_download=force_download, local_files_only=local_files_only, + subfolder=subfolder, ) model_save_dir = Path(model_cache_path).parent model = cls.load_model(model_cache_path) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 739a2f2b44..8611bddd21 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -16,17 +16,12 @@ import inspect import logging import os -import warnings from pathlib import Path from tempfile import TemporaryDirectory from typing import Dict, Optional, Tuple, Union -import intel_extension_for_pytorch as ipex import torch import transformers -from huggingface_hub import hf_hub_download -from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE -from intel_extension_for_pytorch.cpu._auto_kernel_selection import _enable_tpp from transformers import ( AutoConfig, AutoModel, @@ -40,29 +35,23 @@ GenerationConfig, GenerationMixin, PretrainedConfig, - is_torch_xpu_available, ) from transformers.dynamic_module_utils import get_class_from_dynamic_module from transformers.generation.candidate_generator import _crop_past_key_values -from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput +from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.auto.auto_factory import _get_model_class as get_model_class -from transformers.utils import WEIGHTS_NAME -from optimum.exporters import TasksManager -from optimum.exporters.tasks import make_backend_config_constructor_for_task from optimum.modeling_base import OptimizedModel from optimum.utils import NormalizedConfigManager -from ...exporters.ipex.model_config import ipex_onnx_config +from ...exporters.ipex.cache_utils import IPEXPagedCache from ...exporters.ipex.model_patcher import ( 
_IPEX_EXPORTED_GENERATION_TASKS, _IPEX_MINIMUM_VERSION_FOR_PATCHING, _patch_model, ) -from ..generation.modeling import get_float_type -from ..utils.constant import _TASK_ALIASES -from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version -from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, recursive_to_device +from ..generation.modeling import prepare_jit_inputs +from ..utils.import_utils import is_ipex_version, is_transformers_version logger = logging.getLogger(__name__) @@ -70,91 +59,19 @@ _IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2") _IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation") +_IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0" +# TODO: Already fixed in torch 2.6, will enable when torch upgrading to 2.6 +_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "beit") -def _is_patched_with_ipex(model, task): +def _is_patched_with_ipex(model, task, use_cache: bool = True): if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING): return False - - if isinstance(model, torch.jit.ScriptModule): - for node in model.graph.nodes(): - # Only patched model enabled fusion linear. - if "/fusions/" in node.__str__(): - return True - return False - elif task in _IPEX_EXPORTED_GENERATION_TASKS and model.config.hidden_size < 64: - # The ipex IAKV op in patched model requires the hidden size at least 64 + if not use_cache and task in _IPEX_EXPORTED_GENERATION_TASKS: return False - return model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES -def _prepare_inputs_for_ipex_model(model, task, use_cache): - task = _TASK_ALIASES.get(task, task) - signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__) - if _is_patched_with_ipex(model, task) and model.config.model_type in ipex_onnx_config: - onnx_config_class = make_backend_config_constructor_for_task( - ipex_onnx_config[model.config.model_type], task=task - ) - else: - onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task) - float_dtype = get_float_type(model.dtype) - if "text-generation" in task: - onnx_config = onnx_config_class( - model.config, use_past=use_cache, use_past_in_inputs=use_cache, float_dtype=float_dtype - ) - else: - onnx_config = onnx_config_class(model.config) - - dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt") - - # Check attention_mask shape - if _is_patched_with_ipex(model, task) and model.config.model_type in ipex_onnx_config and use_cache: - past_len = dummy_inputs["past_key_values"][0][0].shape[-2] - input_len = dummy_inputs["input_ids"].shape[-1] - attention_len = dummy_inputs["attention_mask"].shape[-1] - if attention_len != input_len + past_len: - dummy_inputs["attention_mask"] = torch.ones([dummy_inputs["input_ids"].shape[0], input_len + past_len]).to( - dummy_inputs["input_ids"].dtype - ) - - return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None} - - -def ipex_jit_trace(model, task, use_cache): - # Only support torch version >= 2.1.0 to support example_kwarg_inputs in jit.trace - if is_torch_version("<", "2.1.0"): - raise ImportError("`torch>=2.1.0` is needed to trace your model") - - if _is_patched_with_ipex(model, task): - model = _patch_model(model) - - sample_inputs = _prepare_inputs_for_ipex_model(model, task, use_cache) - - model.config.return_dict = False - model.config.use_cache = use_cache - - # Use Tensor 
Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755. - # Only ipex >= 2.3.0 supports tpp. The tpp is only verified for llm in generation tasks. - if is_ipex_version(">=", "2.3.0") and task in _IPEX_EXPORTED_GENERATION_TASKS: - _enable_tpp() - model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True) - # Disable repack while jit tracing to reduce the memory - ipex._C.disable_jit_linear_repack() - with torch.no_grad(): - trace_model = torch.jit.trace( - model, - example_kwarg_inputs=sample_inputs, - strict=False, - check_trace=False, - ) - trace_model = torch.jit.freeze(trace_model) - trace_model(**sample_inputs) - trace_model(**sample_inputs) - - return trace_model - - class IPEXModel(OptimizedModel): auto_model_class = AutoModel export_feature = "feature-extraction" @@ -166,49 +83,46 @@ def __init__( self, model, config: PretrainedConfig = None, - export: bool = False, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, - warmup: bool = True, **kwargs, ): - if is_torch_xpu_available(check_device=True): - self._device = torch.device("xpu:0") - elif torch.cuda.is_available(): - self._device = torch.device("cuda:0") - else: - self._device = torch.device("cpu") - - # CPU only support jit model for now. - if export: - if isinstance(model, torch.jit.RecursiveScriptModule): - logger.warning("The model has been exported already.") - else: - config = model.config if config is None else config - use_cache = kwargs.get("use_cache", True) - model = ipex_jit_trace(model, self.export_feature, use_cache) - config.torchscript = True - + config = config or model.config OptimizedModel.__init__(self, model=model, config=config) - self.model.to(self._device) - self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32 + self._dtype = self.model.dtype if self.model.dtype is not None else torch.float32 + self.use_cache = kwargs.get("use_cache", False) self.model_save_dir = model_save_dir - self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature) + self._add_patch = _is_patched_with_ipex(model, self.export_feature, self.use_cache) - if isinstance(model, torch.jit.RecursiveScriptModule): - self.input_names = { - inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self" - } - else: - self.input_names = set(inspect.signature(model.forward).parameters) + self.input_names = set(inspect.signature(model.forward).parameters) + if self._add_patch: + model = _patch_model(model) # Registers the IPEXModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating # a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863 AutoConfig.register(self.base_model_prefix, AutoConfig) if hasattr(self.auto_model_class, "register"): self.auto_model_class.register(AutoConfig, self.__class__) - if warmup: - self._init_warmup() + + # Non-generation tasks can use torch.compile to get acceleration. 
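The comment above introduces the new compile path: for non-generation tasks on CPU, `IPEXModel.__init__` enables inductor's C++ wrapper and freezing, compiles `model.forward`, and runs two warm-up forwards (the class itself builds the warm-up inputs with `prepare_jit_inputs`). As a standalone illustration only, not code from this patch, the same pattern looks roughly like this; the checkpoint id, tokenizer-based warm-up inputs, and prompt are placeholders:

```python
import os

import torch
from torch._inductor import config as inductor_config
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Illustrative tiny checkpoint; any encoder-style (non-generation) model follows the same pattern.
model_id = "hf-internal-testing/tiny-random-bert"
model = AutoModelForSequenceClassification.from_pretrained(model_id).eval()

# System-level knobs enabled before compiling, mirroring the hunk above.
inductor_config.cpp_wrapper = True
os.environ["TORCHINDUCTOR_FREEZING"] = "1"

# Compile only the forward, then warm it up twice so later calls reuse the compiled graph.
model.forward = torch.compile(model.forward)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("warm up", return_tensors="pt")
with torch.no_grad():
    model(**inputs)
    model(**inputs)
```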
+ if ( + model.device.type == "cpu" + and self.export_feature not in _IPEX_EXPORTED_GENERATION_TASKS + and config.model_type not in _COMPILE_NOT_READY_MODEL_TYPES + and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_COMPILE) + ): + from torch._inductor import config + + # System level optimization + torch._inductor.config.cpp_wrapper = True + os.environ["TORCHINDUCTOR_FREEZING"] = "1" + logger.info("Enable torch.compile optimization, start warm up") + self.model.forward = torch.compile(self.model.forward) + inputs = prepare_jit_inputs(model, self.export_feature, False) + with torch.no_grad(): + self.model(**inputs) + self.model(**inputs) + logger.info("Warm up end") @classmethod def _from_transformers(cls, *args, **kwargs): @@ -219,16 +133,6 @@ def _from_pretrained( cls, model_id: Union[str, Path], config: PretrainedConfig, - use_auth_token: Optional[Union[bool, str]] = None, - token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, - force_download: bool = False, - cache_dir: Union[str, Path] = HUGGINGFACE_HUB_CACHE, - subfolder: str = "", - local_files_only: bool = False, - torch_dtype: Optional[Union[str, "torch.dtype"]] = None, - trust_remote_code: bool = False, - file_name: Optional[str] = WEIGHTS_NAME, **kwargs, ): """ @@ -240,121 +144,23 @@ def _from_pretrained( Can be either: - The model id of a pretrained model hosted inside a model repo on huggingface.co. - The path to a directory containing the model weights. - use_auth_token (Optional[Union[bool, str]], defaults to `None`): - Deprecated. Please use `token` instead. - token (Optional[Union[bool, str]], defaults to `None`): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `huggingface-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*): - The specific model version to use. It can be a branch name, a tag name, or a commit id. - force_download (`bool`, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - cache_dir (`Union[str, Path]`, *optional*): - The path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - subfolder (`str`, *optional*) - In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can specify the folder name here. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - torch_dtype (`Optional[Union[str, "torch.dtype"]]`, *optional*) - float16 or bfloat16 or float32: load in a specified dtype, ignoring the model config.torch_dtype if one exists. If not specified, the model will get loaded in float32. - trust_remote_code (`bool`, *optional*) - Allows to use custom code for the modeling hosted in the model repository. This option should only be set for repositories you trust and in which you have read the code, as it will execute on your local machine arbitrary code present in the model repository. - file_name (`str`, *optional*): - The file name of the model to load. Overwrites the default file name and allows one to load the model - with a different name. """ - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. 
Please use `token` instead.", - FutureWarning, - ) - if token is not None: - raise ValueError( - "Both the arguments `use_auth_token` and `token` were specified, which is not supported. Please specify only `token`." - ) - token = use_auth_token - - commit_hash = kwargs.pop("_commit_hash", None) - - model_kwargs = { - "revision": revision, - "token": token, - "cache_dir": cache_dir, - "subfolder": subfolder, - "local_files_only": local_files_only, - "force_download": force_download, - } - - if not getattr(config, "torchscript", False): - logger.warning("Detect torchscript is false. Convert to torchscript model!") - - if is_torch_version("<", "2.1.0"): - raise ImportError("`torch>=2.0.0` is needed to trace your model") - - task = cls.export_feature - config.torch_dtype = torch_dtype - model = TasksManager.get_model_from_task( - task, - model_id, - library_name="transformers", - trust_remote_code=trust_remote_code, - torch_dtype=torch_dtype, - _commit_hash=commit_hash, - **model_kwargs, - ) + if getattr(config, "torchscript", False): + raise ValueError("IPEXModel is no longer support torchscript models.") - return cls(model, config=config, export=True, **kwargs) - - # Load the model from local directory - if os.path.isdir(model_id): - model_cache_path = os.path.join(model_id, file_name) - model_save_dir = model_id - # Download the model from the hub - else: - model_cache_path = hf_hub_download(repo_id=model_id, filename=file_name, **model_kwargs) - model_save_dir = Path(model_cache_path).parent - - model = torch.jit.load(model_cache_path) - torch.jit.freeze(model.eval()) - - return cls(model, config=config, model_save_dir=model_save_dir, **kwargs) + model = cls.auto_model_class.from_pretrained(model_id, **kwargs) + return cls(model, config=model.config, **kwargs) def _save_pretrained(self, save_directory: Union[str, Path]): - output_path = os.path.join(save_directory, WEIGHTS_NAME) - if getattr(self.config, "torchscript", None): - torch.jit.save(self.model, output_path) - else: - logger.warning("The module is not a torchscript model, will be treated as a transformers model.") - self.model.save_pretrained(output_path) - - def forward( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: torch.Tensor = None, - position_ids: torch.Tensor = None, - **kwargs, - ): - inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - - if "token_type_ids" in self.input_names: - inputs["token_type_ids"] = token_type_ids + self.model.save_pretrained(save_directory, safe_serialization=False) - if "position_ids" in self.input_names: - inputs["position_ids"] = position_ids + def push_to_hub(self, *args, **kwargs): + kwargs["safe_serialization"] = False + return self.model.push_to_hub(*args, **kwargs) - outputs = self._call_model(**inputs) - if isinstance(outputs, dict): - model_output = ModelOutput(**outputs) - else: - model_output = ModelOutput() - model_output[self.output_name] = outputs[0] - return model_output + @torch.no_grad() + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) def eval(self): self.model.eval() @@ -362,7 +168,7 @@ def eval(self): @property def device(self) -> torch.device: - return self._device + return self.model.device @property def dtype(self) -> torch.dtype: @@ -375,34 +181,17 @@ def model_dtype(self): ) return self._dtype + @property + def add_patch(self) -> bool: + return self._add_patch + def to(self, device: Union[torch.device, str]): - self._device = device if isinstance(device, torch.device) else 
torch.device(device) - self.model.to(self._device) + self.model.to(device) return self def can_generate(self): return isinstance(self, GenerationMixin) - def _call_model(self, *args, **kwargs): - try: - with torch.autocast(self.device.type, self.dtype), torch.no_grad(): - out = self.model(*args, **kwargs) - except RuntimeError: - out = self.model(*args, **kwargs) - return out - - def _init_warmup(self): - # warmup, the first 2 forwards of an IPEX model include some preprocessing steps and - # the results of the compute are unpredictable - # TODO : add warmup for IPEX exported model - if not self._is_ipex_exported: - use_cache = "past_key_values" in self.input_names - dummy_inputs = _prepare_inputs_for_ipex_model(self, self.export_feature, use_cache) - if self._device.type != "cpu": - dummy_inputs = recursive_to_device(value=dummy_inputs, device=self._device) - for _ in range(2): - self(**dummy_inputs) - class IPEXModelForSequenceClassification(IPEXModel): auto_model_class = AutoModelForSequenceClassification @@ -426,98 +215,44 @@ class IPEXModelForImageClassification(IPEXModel): auto_model_class = AutoModelForImageClassification export_feature = "image-classification" - def forward( - self, - pixel_values: torch.Tensor, - **kwargs, - ): - inputs = { - "pixel_values": pixel_values, - } - - outputs = self._call_model(**inputs) - return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) - class IPEXModelForAudioClassification(IPEXModel): auto_model_class = AutoModelForAudioClassification export_feature = "audio-classification" - def forward( - self, - input_values: torch.Tensor, - attention_mask: torch.Tensor = None, - **kwargs, - ): - inputs = { - "input_values": input_values, - } - - if "attention_mask" in self.input_names: - inputs["attention_mask"] = attention_mask - - outputs = self._call_model(**inputs) - return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) - class IPEXModelForQuestionAnswering(IPEXModel): auto_model_class = AutoModelForQuestionAnswering export_feature = "question-answering" - def forward( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: torch.Tensor = None, - **kwargs, - ): - inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - - if "token_type_ids" in self.input_names: - inputs["token_type_ids"] = token_type_ids - - outputs = self._call_model(**inputs) - start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0] - end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1] - return ModelOutput(start_logits=start_logits, end_logits=end_logits) - class IPEXModelForCausalLM(IPEXModel, GenerationMixin): auto_model_class = AutoModelForCausalLM export_feature = "text-generation" - _supports_cache_class = False - _is_stateful = False def __init__( self, model, config: PretrainedConfig = None, - export: bool = False, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, use_cache: bool = True, - warmup: bool = True, **kwargs, ): - # Perform the initial warmup at the end of __init__ - super().__init__( - model, config, export=export, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache - ) + super().__init__(model, config, model_save_dir=model_save_dir, use_cache=use_cache) + + self._supports_cache_class = getattr(model, "_supports_cache_class", None) + self._supports_sdpa = getattr(model, "_supports_sdpa", None) + self._supports_cache_class = getattr(model, 
"_supports_cache_class", None) + self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None) + self._supports_static_cache = getattr(model, "_supports_static_cache", None) + + if self._add_patch: + self._supports_cache_class = True GenerationMixin.__init__(self) model_type = self.config.model_type.replace("_", "-") self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(self.config) - self.use_cache = "past_key_values" in self.input_names - if isinstance(model, torch.jit.RecursiveScriptModule) and use_cache ^ self.use_cache: - raise ValueError( - f"`use_cache` was set to `{use_cache}` but the loaded model only supports `use_cache={self.use_cache}`. " - f"Please load your current model with `use_cache={self.use_cache}` or export the original model " - f"once again with `use_cache={use_cache}` when calling the `from_pretrained` method. " - "To export your model, simply set `export=True`." - ) self.config.is_decoder = True self.config.is_encoder_decoder = False @@ -529,140 +264,19 @@ def __init__( except AttributeError: self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping) - if self._is_ipex_exported: - self._reorder_cache = _ipex_reorder_cache - else: - # Check if _reorder_cache is a static method - if "_reorder_cache" in self.model_cls.__dict__ and isinstance( - self.model_cls.__dict__["_reorder_cache"], staticmethod - ): - self._reorder_cache = self.model_cls._reorder_cache - elif "_reorder_cache" in self.model_cls.__dict__: - self._reorder_cache = self.model_cls._reorder_cache.__get__(self) - - if is_transformers_version(">=", "4.38.0") and model_type in { - "llama", - "phi", - "persimmon", - "mistral", - "falcon", - "gpt2", - }: - self.prepare_inputs_for_generation = _ipex_prepare_inputs_for_generation - else: - self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self) - if hasattr(self.model_cls, "_convert_to_standard_cache"): self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache if hasattr(self.model_cls, "_convert_to_bloom_cache"): self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache - if warmup: - self._init_warmup() - - def _prepare_past_key_values(self, input_ids): - model_type = self.config.model_type.replace("_", "-") - nb_pkv = 2 - num_layers = self.normalized_config.num_layers - d_k = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads - batch_size = input_ids.shape[0] - - if model_type in {"mistral", "llama", "falcon"}: - num_attention_heads = getattr(self.normalized_config, "num_key_value_heads", 1) - else: - num_attention_heads = self.normalized_config.num_attention_heads - - if self._is_ipex_exported: - # Indirect access kv cache has a different data layout compared with most transformers model, - # see https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/llm.html#indirect-access-kv-cache - beam_idx_tmp = torch.zeros( - (self.config.max_position_embeddings, input_ids.shape[0]), dtype=torch.long - ).contiguous() - past_key_values = tuple( - [ - ( - torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), - torch.zeros([1, 1, 1, 1]).contiguous(), - torch.zeros([1, 1, 1, 1]).contiguous(), - beam_idx_tmp, - ) - for i in range(num_layers) - ] - ) - return past_key_values - elif model_type == "bloom" and is_transformers_version("<", "4.44"): - shape_key = (batch_size * num_attention_heads, d_k, 0) - shape_value = (batch_size * num_attention_heads, 0, d_k) - key = 
torch.empty(size=shape_key, dtype=self.model_dtype, device=self._device) - value = torch.empty(size=shape_value, dtype=self.model_dtype, device=self._device) - past_key_values = tuple( - tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv)) for _ in range(num_layers) - ) - elif model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS: - shape = (batch_size, 0, d_k * 2) - pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device) - past_key_values = tuple(pkv for _ in range(num_layers)) - else: - shape = (batch_size, num_attention_heads, 0, d_k) - pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device) - past_key_values = tuple(tuple(pkv for _ in range(nb_pkv)) for _ in range(num_layers)) - - return past_key_values - - # Temporary fix, will delete when https://github.com/huggingface/transformers/pull/31226 release. - def _get_initial_cache_position(self, input_ids, model_kwargs): - """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" - if not model_kwargs.get("use_cache", True): - model_kwargs["cache_position"] = None - return model_kwargs - - past_length = 0 - if "past_key_values" in model_kwargs: - past_length = model_kwargs["past_key_values"][0][0].shape[-2] - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] - else: - cur_len = input_ids.shape[-1] - model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device) - return model_kwargs + @torch.no_grad() def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - position_ids: Optional[torch.FloatTensor] = None, **kwargs, ) -> CausalLMOutputWithPast: - # 1. Prepare model inputs - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) - - inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - - if "position_ids" in self.input_names or not self.input_names: - inputs["position_ids"] = position_ids - - if self.use_cache: - if past_key_values is None: - past_key_values = self._prepare_past_key_values(input_ids) - - inputs["past_key_values"] = past_key_values - - # 2. Model forward - outputs = self._call_model(**inputs) - - # 3. 
Process model outputs - if isinstance(outputs, (list, tuple)): - logits = outputs[0] - past_key_values = outputs[1] if self.use_cache else None - else: - logits = outputs["logits"] - past_key_values = outputs["past_key_values"] if self.use_cache else None - - return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values) + return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs) def _prepare_generation_config( self, generation_config: Optional[GenerationConfig], **kwargs: Dict @@ -676,15 +290,32 @@ def _prepare_generation_config( return generation_config, model_kwargs + def _reorder_cache(self, *args, **kwargs): + return self.model._reorder_cache(*args, **kwargs) + + def prepare_inputs_for_generation(self, *args, **kwargs): + return self.model.prepare_inputs_for_generation(*args, **kwargs) + def generate(self, *args, **kwargs): - if is_ipex_version("<", "2.4.0") and self._is_ipex_exported and kwargs.get("assistant_model", None): + if is_ipex_version("<", "2.4.0") and self._add_patch and kwargs.get("assistant_model", None): raise ValueError( f"Assisted decoding is not supported for patched models if ipex < 2.4, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}" ) - # Patch functions to support IAKV cache - if self._is_ipex_exported and kwargs.get("assistant_model", None): + # Patch functions to support ipex_paged cache + if self._add_patch: + transformers.generation.utils.NEED_SETUP_CACHE_CLASSES_MAPPING["ipex_paged"] = IPEXPagedCache + self.generation_config.cache_implementation = "ipex_paged" + if is_transformers_version(">=", "4.45.0"): + if "ipex_paged" not in transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS: + transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS.append("ipex_paged") + if kwargs.get("generation_config", None): + # Change cache implementation temporarily + orig_cache_implementation = kwargs["generation_config"].cache_implementation + kwargs["generation_config"].cache_implementation = "ipex_paged" + + if self._add_patch and kwargs.get("assistant_model", None): transformers.generation.utils._crop_past_key_values = _ipex_crop_past_key_values - elif self._is_ipex_exported: + elif self._add_patch: transformers.generation.candidate_generator._crop_past_key_values = _ipex_crop_past_key_values try: @@ -694,100 +325,23 @@ def generate(self, *args, **kwargs): transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values raise e - if self._is_ipex_exported and kwargs.get("assistant_model", None): + if self._add_patch and kwargs.get("assistant_model", None): transformers.generation.utils._crop_past_key_values = _crop_past_key_values transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values - return result - - -def _ipex_prepare_inputs_for_generation( - input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs -): - from transformers.cache_utils import Cache + # change back cache_implementation + if self._add_patch and kwargs.get("generation_config", None): + kwargs["generation_config"].cache_implementation = orig_cache_implementation - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the 
attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - -def _ipex_reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor -) -> Tuple[Tuple[torch.Tensor]]: - # Ipex patched model uses indirect access kv cache which has a different shape with other transformers models - if len(past_key_values[0]) == 4 and past_key_values[0][0].shape[-1] == 1: - for layer_past in past_key_values: - layer_past[3][layer_past[0].size(-2) - 1] = beam_idx - return past_key_values - elif len(past_key_values[0]) == 8: - for layer_past in past_key_values: - layer_past[3][layer_past[0].size(-2) - 1] = beam_idx - layer_past[7][layer_past[0].size(-2) - 1] = beam_idx - return past_key_values - else: - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) + return result def _ipex_crop_past_key_values(model, past_key_values, max_length): if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"): - new_past_key_values = [] - for i in range(len(past_key_values)): - pkv = [] - pkv.append(past_key_values[i][0][:, :max_length, :max_length, :]) - pkv += [past_key_values[i][_] for _ in range(1, 4)] - new_past_key_values.append(tuple(pkv)) - new_past_key_values = tuple(new_past_key_values) - return new_past_key_values + if isinstance(past_key_values, IPEXPagedCache): + # .crop is an inplace op, returns None + past_key_values = past_key_values.crop(max_length) + return past_key_values + else: + raise ValueError("only support IPEXPagedCache input now") return _crop_past_key_values(model, past_key_values, max_length) diff --git a/setup.py b/setup.py index cd49ea041a..7f7e91df33 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,7 @@ "nncf": ["nncf>=2.14.0"], "openvino": ["nncf>=2.14.0", 
"openvino>=2024.5.0", "openvino-tokenizers>=2024.5.0"], "neural-compressor": ["neural-compressor[pt]>3.0", "accelerate", "transformers<4.46"], - "ipex": ["intel-extension-for-pytorch>=2.2,<2.4", "transformers>=4.39,<4.45"], + "ipex": ["intel-extension-for-pytorch>=2.4", "transformers>4.45,<4.47"], "diffusers": ["diffusers"], "quality": QUALITY_REQUIRE, "tests": TESTS_REQUIRE, diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 53c733c4f5..7f1104d7f7 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -44,12 +44,12 @@ IPEXModelForSequenceClassification, IPEXModelForTokenClassification, ) -from optimum.intel.utils.import_utils import is_ipex_version from optimum.utils.testing_utils import grid_parameters -from utils_tests import MODEL_NAMES +from utils_tests import MODEL_NAMES, IS_XPU_AVAILABLE SEED = 42 +torch.use_deterministic_algorithms(True) class Timer(object): @@ -74,17 +74,21 @@ class IPEXModelTest(unittest.TestCase): "squeezebert", "xlm", ) + IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("bert",) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] set_seed(SEED) - ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True) + ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id) + if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES: + self.assertTrue(ipex_model.add_patch) + device = ipex_model.device self.assertIsInstance(ipex_model.config, PretrainedConfig) - transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id) + transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id).to(device) tokenizer = AutoTokenizer.from_pretrained(model_id) inputs = "This is a sample input" - tokens = tokenizer(inputs, return_tensors="pt") + tokens = tokenizer(inputs, return_tensors="pt").to(device) with torch.no_grad(): transformers_outputs = transformers_model(**tokens) outputs = ipex_model(**tokens) @@ -95,21 +99,20 @@ def test_compare_to_transformers(self, model_arch): loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname) loaded_model_outputs = loaded_model(**tokens) # Test init method - init_model = self.IPEX_MODEL_CLASS(transformers_model, export=True) + init_model = self.IPEX_MODEL_CLASS(transformers_model) init_model_outputs = init_model(**tokens) - self.assertIsInstance(init_model.model, torch.jit.RecursiveScriptModule) # Compare tensor outputs for output_name in {"logits", "last_hidden_state"}: if output_name in transformers_outputs: - self.assertTrue(torch.allclose(outputs[output_name], transformers_outputs[output_name], atol=1e-4)) + self.assertTrue(torch.allclose(outputs[output_name], transformers_outputs[output_name], atol=1e-3)) self.assertTrue(torch.allclose(outputs[output_name], loaded_model_outputs[output_name])) self.assertTrue(torch.allclose(outputs[output_name], init_model_outputs[output_name])) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True) + model = self.IPEX_MODEL_CLASS.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline(self.IPEX_MODEL_CLASS.export_feature, model=model, tokenizer=tokenizer) text = "This restaurant is awesome" @@ -144,12 +147,13 @@ class IPEXModelForQuestionAnsweringTest(unittest.TestCase): def test_compare_to_transformers(self, model_arch): model_id = 
MODEL_NAMES[model_arch] set_seed(SEED) - ipex_model = IPEXModelForQuestionAnswering.from_pretrained(model_id, export=True) + ipex_model = IPEXModelForQuestionAnswering.from_pretrained(model_id) + device = ipex_model.device self.assertIsInstance(ipex_model.config, PretrainedConfig) - transformers_model = AutoModelForQuestionAnswering.from_pretrained(model_id) + transformers_model = AutoModelForQuestionAnswering.from_pretrained(model_id).to(device) tokenizer = AutoTokenizer.from_pretrained(model_id) inputs = "This is a sample input" - tokens = tokenizer(inputs, return_tensors="pt") + tokens = tokenizer(inputs, return_tensors="pt").to(device) with torch.no_grad(): transformers_outputs = transformers_model(**tokens) outputs = ipex_model(**tokens) @@ -161,9 +165,8 @@ def test_compare_to_transformers(self, model_arch): loaded_model_outputs = loaded_model(**tokens) # Test init method - init_model = self.IPEX_MODEL_CLASS(transformers_model, export=True) + init_model = self.IPEX_MODEL_CLASS(transformers_model) init_model_outputs = init_model(**tokens) - self.assertIsInstance(init_model.model, torch.jit.RecursiveScriptModule) self.assertIn("start_logits", outputs) self.assertIn("end_logits", outputs) @@ -178,7 +181,7 @@ def test_compare_to_transformers(self, model_arch): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = IPEXModelForQuestionAnswering.from_pretrained(model_id, export=True) + model = IPEXModelForQuestionAnswering.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline("question-answering", model=model, tokenizer=tokenizer) question = "What's my name?" @@ -188,11 +191,8 @@ def test_pipeline(self, model_arch): self.assertGreaterEqual(outputs["score"], 0.0) self.assertIsInstance(outputs["answer"], str) - @unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version > 2.3.0 supports ipex model patching") def test_patched_model(self): - ipex_model = IPEXModelForQuestionAnswering.from_pretrained( - "Jiqing/patched_tiny_random_bert_for_question_answering" - ) + ipex_model = IPEXModelForQuestionAnswering.from_pretrained("Intel/tiny-random-bert_ipex_model") transformers_model = AutoModelForQuestionAnswering.from_pretrained("hf-internal-testing/tiny-random-bert") tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert") inputs = "This is a sample input" @@ -225,7 +225,7 @@ class IPEXModelForCausalLMTest(unittest.TestCase): "mpt", "opt", ) - IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama2", "distilgpt2", "falcon") + IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama2", "falcon", "gpt2") GENERATION_LENGTH = 100 SPEEDUP_CACHE = 1.0 @@ -233,9 +233,10 @@ class IPEXModelForCausalLMTest(unittest.TestCase): def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] set_seed(SEED) - ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True) + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + # Test model forward do not need cache. 
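Outside the test harness, the loading path these causal-LM tests now exercise is a plain `from_pretrained` call with no `export=True`, with the dtype chosen from XPU availability. A minimal sketch, assuming the tiny checkpoint listed in `utils_tests.py` and an arbitrary prompt:

```python
import torch
from transformers import AutoTokenizer, is_torch_xpu_available
from optimum.intel import IPEXModelForCausalLM

model_id = "Intel/tiny-random-llama2"  # tiny test checkpoint from utils_tests.py
dtype = torch.float16 if is_torch_xpu_available(check_device=True) else torch.float32

# No export=True anymore: the transformers model is loaded eagerly and patched in __init__ when supported.
model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("This is a sample", return_tensors="pt").to(model.device)
generated = model.generate(**inputs, do_sample=False, max_new_tokens=10)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```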
+ ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype) self.assertIsInstance(ipex_model.config, PretrainedConfig) - self.assertTrue(ipex_model.use_cache) tokenizer = AutoTokenizer.from_pretrained(model_id) tokens = tokenizer( "This is a sample", @@ -246,22 +247,20 @@ def test_compare_to_transformers(self, model_arch): outputs = ipex_model(**inputs) self.assertIsInstance(outputs.logits, torch.Tensor) - self.assertIsInstance(outputs.past_key_values, (tuple, list)) - transformers_model = AutoModelForCausalLM.from_pretrained(model_id) + transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype) with torch.no_grad(): transformers_outputs = transformers_model(**tokens) # Test re-load model with tempfile.TemporaryDirectory() as tmpdirname: ipex_model.save_pretrained(tmpdirname) - loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname) + loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, torch_dtype=dtype) loaded_model_outputs = loaded_model(**inputs) # Test init method - init_model = self.IPEX_MODEL_CLASS(transformers_model, export=True) + init_model = self.IPEX_MODEL_CLASS(transformers_model) init_model_outputs = init_model(**inputs) - self.assertIsInstance(init_model.model, torch.jit.RecursiveScriptModule) # Compare tensor outputs self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4)) @@ -271,26 +270,30 @@ def test_compare_to_transformers(self, model_arch): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 model_id = MODEL_NAMES[model_arch] tokenizer = AutoTokenizer.from_pretrained(model_id) - model = IPEXModelForCausalLM.from_pretrained(model_id, export=True) + model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype) model.config.encoder_no_repeat_ngram_size = 0 - model.to("cpu") + # model.to("cpu") pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) outputs = pipe("This is a sample", max_new_tokens=10) self.assertEqual(pipe.device, model.device) self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs)) @parameterized.expand(SUPPORTED_ARCHITECTURES) + @unittest.skip(reason="Paged attention do not support assisted decoding for now") def test_assisted_decoding(self, model_arch): - # Patched models are not support assisted decoding if ipex < 2.5. 
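The beam-search parity test a little further down compares IPEX output with vanilla transformers token for token. A condensed sketch of that check, with an illustrative tiny checkpoint, prompt, and generation settings rather than the exact grid the test iterates over:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from optimum.intel import IPEXModelForCausalLM

model_id = "Intel/tiny-random-llama2"  # illustrative tiny checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, use_cache=True)
reference_model = AutoModelForCausalLM.from_pretrained(model_id).to(ipex_model.device)

generation_config = GenerationConfig(max_new_tokens=4, num_beams=2, do_sample=False)
tokens = tokenizer("This is a sample", padding=True, return_tensors="pt").to(ipex_model.device)
ipex_out = ipex_model.generate(**tokens, generation_config=generation_config)
reference_out = reference_model.generate(**tokens, generation_config=generation_config)
assert torch.equal(ipex_out, reference_out)
```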
- if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES and is_ipex_version("<", "2.4.0"): + # assist decoding does not support static cache now + if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES: return model_id = MODEL_NAMES[model_arch] + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 tokenizer = AutoTokenizer.from_pretrained(model_id) - ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True) - transformers_model = AutoModelForCausalLM.from_pretrained(model_id) - tokens = tokenizer("This is a sample input", return_tensors="pt") + ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype) + device = ipex_model.device + transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device) + tokens = tokenizer("This is a sample input", return_tensors="pt").to(device) ipex_output = ipex_model.generate(**tokens, do_sample=False, max_new_tokens=4) ipex_output_assisted = ipex_model.generate( **tokens, do_sample=False, assistant_model=transformers_model, max_new_tokens=4 @@ -309,17 +312,20 @@ def test_assisted_decoding(self, model_arch): @parameterized.expand( grid_parameters( { - "model_arch": IPEX_PATCHED_SUPPORTED_ARCHITECTURES, + "model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True, False], } ) ) - @unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version > 2.3.0 supports ipex model patching") - def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache): + def test_ipex_beam_search(self, test_name, model_arch, use_cache): model_id = MODEL_NAMES[model_arch] set_seed(SEED) - model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=use_cache) - trasnformers_model = AutoModelForCausalLM.from_pretrained(model_id) + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + model = IPEXModelForCausalLM.from_pretrained(model_id, use_cache=use_cache, torch_dtype=dtype) + if use_cache and model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES: + self.assertTrue(model.add_patch) + device = model.device + transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device) self.assertEqual(model.use_cache, use_cache) tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer.pad_token = tokenizer.eos_token @@ -335,46 +341,27 @@ def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache): ), ) for text in texts: - tokens = tokenizer(text, padding=True, return_tensors="pt") + tokens = tokenizer(text, padding=True, return_tensors="pt").to(device) for generation_config in generation_configs: outputs = model.generate(**tokens, generation_config=generation_config) - transformers_outputs = trasnformers_model.generate(**tokens, generation_config=generation_config) + transformers_outputs = transformers_model.generate(**tokens, generation_config=generation_config) self.assertIsInstance(outputs, torch.Tensor) self.assertTrue(torch.equal(outputs, transformers_outputs)) - @parameterized.expand(IPEX_PATCHED_SUPPORTED_ARCHITECTURES) - @unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version > 2.3.0 supports ipex model patching") - def test_patched_model(self, model_arch): - model_id = MODEL_NAMES[model_arch] - patched_model_id = MODEL_NAMES["patched_" + model_arch] - ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True) - exported_model = IPEXModelForCausalLM.from_pretrained(patched_model_id) - tokenizer = AutoTokenizer.from_pretrained(model_id) - tokens = tokenizer( - "This is a 
sample", - return_tensors="pt", - return_token_type_ids=False if model_arch in ("llama", "llama2") else None, - ) - inputs = ipex_model.prepare_inputs_for_generation(**tokens) - ipex_outputs = ipex_model(**inputs) - exported_outputs = exported_model(**inputs) - self.assertTrue(torch.allclose(ipex_outputs.logits, exported_outputs.logits, atol=1e-7)) - def test_compare_with_and_without_past_key_values(self): - model_id = "echarlaix/tiny-random-gpt2-torchscript" + model_id = "Intel/tiny_random_llama2_ipex_model" + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + model_with_pkv = IPEXModelForCausalLM.from_pretrained(model_id, use_cache=True, torch_dtype=dtype) + device = model_with_pkv.device tokenizer = AutoTokenizer.from_pretrained(model_id) - tokens = tokenizer("This is a sample input", return_tensors="pt") - - model_with_pkv = IPEXModelForCausalLM.from_pretrained(model_id, use_cache=True, subfolder="model_with_pkv") + tokens = tokenizer("This is a sample input", return_tensors="pt").to(device) # Warmup model_with_pkv.generate(**tokens) with Timer() as with_pkv_timer: outputs_model_with_pkv = model_with_pkv.generate( **tokens, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1 ) - model_without_pkv = IPEXModelForCausalLM.from_pretrained( - model_id, use_cache=False, subfolder="model_without_pkv" - ) + model_without_pkv = IPEXModelForCausalLM.from_pretrained(model_id, use_cache=False, torch_dtype=dtype) # Warmup model_without_pkv.generate(**tokens) with Timer() as without_pkv_timer: @@ -385,6 +372,22 @@ def test_compare_with_and_without_past_key_values(self): self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH + tokens.input_ids.shape[1]) self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH + tokens.input_ids.shape[1]) + @parameterized.expand(IPEX_PATCHED_SUPPORTED_ARCHITECTURES) + def test_patched_model(self, model_arch): + model_id = MODEL_NAMES[model_arch] + patched_model_id = MODEL_NAMES["patched_" + model_arch] + ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True) + exported_model = IPEXModelForCausalLM.from_pretrained(patched_model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer("This is a sample", return_tensors="pt") + ipex_outputs = ipex_model.generate( + **tokens, max_new_tokens=1, return_dict_in_generate=True, output_logits=True + ) + exported_outputs = exported_model.generate( + **tokens, max_new_tokens=1, return_dict_in_generate=True, output_logits=True + ) + self.assertTrue(torch.allclose(ipex_outputs.logits[0], exported_outputs.logits[0], atol=1e-7)) + class IPEXModelForAudioClassificationTest(unittest.TestCase): IPEX_MODEL_CLASS = IPEXModelForAudioClassification @@ -403,11 +406,12 @@ def _generate_random_audio_data(self): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] - ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True) + ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id) + device = ipex_model.device self.assertIsInstance(ipex_model.config, PretrainedConfig) - transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id) + transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id).to(device) preprocessor = AutoFeatureExtractor.from_pretrained(model_id) - inputs = preprocessor(self._generate_random_audio_data(), return_tensors="pt") + inputs = 
preprocessor(self._generate_random_audio_data(), return_tensors="pt").to(device) with torch.no_grad(): transformers_outputs = transformers_model(**inputs) outputs = ipex_model(**inputs) @@ -419,9 +423,8 @@ def test_compare_to_transformers(self, model_arch): loaded_model_outputs = loaded_model(**inputs) # Test init method - init_model = self.IPEX_MODEL_CLASS(transformers_model, export=True) + init_model = self.IPEX_MODEL_CLASS(transformers_model) init_model_outputs = init_model(**inputs) - self.assertIsInstance(init_model.model, torch.jit.RecursiveScriptModule) # Compare tensor outputs self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-3)) @@ -431,7 +434,7 @@ def test_compare_to_transformers(self, model_arch): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True) + model = self.IPEX_MODEL_CLASS.from_pretrained(model_id) preprocessor = AutoFeatureExtractor.from_pretrained(model_id) pipe = pipeline("audio-classification", model=model, feature_extractor=preprocessor) outputs = pipe([np.random.random(16000)]) @@ -443,25 +446,28 @@ class IPEXModelForImageClassificationIntegrationTest(unittest.TestCase): IPEX_MODEL_CLASS = IPEXModelForImageClassification SUPPORTED_ARCHITECTURES = ( "beit", - # "levit", "mobilenet_v1", "mobilenet_v2", "mobilevit", "resnet", "vit", ) + IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("vit",) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] set_seed(SEED) - ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True) + ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id) + if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES: + self.assertTrue(ipex_model.add_patch) + device = ipex_model.device self.assertIsInstance(ipex_model.config, PretrainedConfig) - transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id) + transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id).to(device) preprocessor = AutoFeatureExtractor.from_pretrained(model_id) url = "http://images.cocodataset.org/val2017/000000039769.jpg" image = Image.open(requests.get(url, stream=True).raw) - inputs = preprocessor(images=image, return_tensors="pt") + inputs = preprocessor(images=image, return_tensors="pt").to(device) with torch.no_grad(): transformers_outputs = transformers_model(**inputs) outputs = ipex_model(**inputs) @@ -473,20 +479,19 @@ def test_compare_to_transformers(self, model_arch): loaded_model_outputs = loaded_model(**inputs) # Test init method - init_model = self.IPEX_MODEL_CLASS(transformers_model, export=True) + init_model = self.IPEX_MODEL_CLASS(transformers_model) init_model_outputs = init_model(**inputs) - self.assertIsInstance(init_model.model, torch.jit.RecursiveScriptModule) self.assertIn("logits", outputs) # Compare tensor outputs self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4)) - self.assertTrue(torch.equal(outputs.logits, loaded_model_outputs.logits)) + self.assertTrue(torch.allclose(outputs.logits, loaded_model_outputs.logits, atol=1e-4)) self.assertTrue(torch.allclose(init_model_outputs.logits, transformers_outputs.logits, atol=1e-4)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, 
export=True) + model = self.IPEX_MODEL_CLASS.from_pretrained(model_id) preprocessor = AutoFeatureExtractor.from_pretrained(model_id) pipe = pipeline("image-classification", model=model, feature_extractor=preprocessor) outputs = pipe("http://images.cocodataset.org/val2017/000000039769.jpg") @@ -494,11 +499,8 @@ def test_pipeline(self, model_arch): self.assertGreaterEqual(outputs[0]["score"], 0.0) self.assertTrue(isinstance(outputs[0]["label"], str)) - @unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version > 2.3.0 supports ipex model patching") def test_patched_model(self): - ipex_model = IPEXModelForImageClassification.from_pretrained( - "Jiqing/patched_tiny_random_vit_for_image_classification" - ) + ipex_model = IPEXModelForImageClassification.from_pretrained("Intel/tiny-random-vit_ipex_model") transformers_model = self.IPEX_MODEL_CLASS.from_pretrained("hf-internal-testing/tiny-random-vit") preprocessor = AutoFeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-vit") url = "http://images.cocodataset.org/val2017/000000039769.jpg" diff --git a/tests/ipex/test_pipelines.py b/tests/ipex/test_pipelines.py index 767097a5dd..77790e19f4 100644 --- a/tests/ipex/test_pipelines.py +++ b/tests/ipex/test_pipelines.py @@ -20,7 +20,7 @@ from parameterized import parameterized from transformers import AutoTokenizer from transformers.pipelines import pipeline as transformers_pipeline -from utils_tests import MODEL_NAMES +from utils_tests import IS_XPU_AVAILABLE, MODEL_NAMES from optimum.intel.ipex.modeling_base import ( IPEXModelForAudioClassification, @@ -34,6 +34,9 @@ from optimum.intel.pipelines import pipeline as ipex_pipeline +torch.use_deterministic_algorithms(True) + + class PipelinesIntegrationTest(unittest.TestCase): COMMON_SUPPORTED_ARCHITECTURES = ( "albert", @@ -92,7 +95,6 @@ def test_token_classification_pipeline_inference(self, model_arch): ipex_output = ipex_generator(inputs) self.assertEqual(len(transformers_output), len(ipex_output)) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForTokenClassification)) - self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) for i in range(len(transformers_output)): self.assertAlmostEqual(transformers_output[i]["score"], ipex_output[i]["score"], delta=1e-4) @@ -107,7 +109,6 @@ def test_sequence_classification_pipeline_inference(self, model_arch): with torch.inference_mode(): ipex_output = ipex_generator(inputs) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSequenceClassification)) - self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) self.assertEqual(transformers_output[0]["label"], ipex_output[0]["label"]) self.assertAlmostEqual(transformers_output[0]["score"], ipex_output[0]["score"], delta=1e-4) @@ -125,7 +126,6 @@ def test_fill_mask_pipeline_inference(self, model_arch): ipex_output = ipex_generator(inputs) self.assertEqual(len(transformers_output), len(ipex_output)) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForMaskedLM)) - self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) for i in range(len(transformers_output)): self.assertEqual(transformers_output[i]["token"], ipex_output[i]["token"]) self.assertAlmostEqual(transformers_output[i]["score"], ipex_output[i]["score"], delta=1e-4) @@ -133,15 +133,15 @@ def test_fill_mask_pipeline_inference(self, model_arch): @parameterized.expand(TEXT_GENERATION_SUPPORTED_ARCHITECTURES) def test_text_generation_pipeline_inference(self, 
model_arch): model_id = MODEL_NAMES[model_arch] - transformers_generator = transformers_pipeline("text-generation", model_id) - ipex_generator = ipex_pipeline("text-generation", model_id, accelerator="ipex") + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + transformers_generator = transformers_pipeline("text-generation", model_id, torch_dtype=dtype) + ipex_generator = ipex_pipeline("text-generation", model_id, accelerator="ipex", torch_dtype=dtype) inputs = "Describe a real-world application of AI." with torch.inference_mode(): - transformers_output = transformers_generator(inputs, max_new_tokens=10) + transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10) with torch.inference_mode(): - ipex_output = ipex_generator(inputs, max_new_tokens=10) + ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForCausalLM)) - self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) self.assertEqual(transformers_output[0]["generated_text"], ipex_output[0]["generated_text"]) @parameterized.expand(QUESTION_ANSWERING_SUPPORTED_ARCHITECTURES) @@ -156,7 +156,6 @@ def test_question_answering_pipeline_inference(self, model_arch): with torch.inference_mode(): ipex_output = ipex_generator(question=question, context=context) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForQuestionAnswering)) - self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) self.assertAlmostEqual(transformers_output["score"], ipex_output["score"], delta=1e-4) self.assertEqual(transformers_output["start"], ipex_output["start"]) self.assertEqual(transformers_output["end"], ipex_output["end"]) @@ -172,7 +171,6 @@ def test_audio_classification_pipeline_inference(self, model_arch): with torch.inference_mode(): ipex_output = ipex_generator(inputs) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForAudioClassification)) - self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) self.assertAlmostEqual(transformers_output[0][0]["score"], ipex_output[0][0]["score"], delta=1e-2) self.assertAlmostEqual(transformers_output[0][1]["score"], ipex_output[0][1]["score"], delta=1e-2) @@ -188,7 +186,6 @@ def test_image_classification_pipeline_inference(self, model_arch): ipex_output = ipex_generator(inputs) self.assertEqual(len(transformers_output), len(ipex_output)) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForImageClassification)) - self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) for i in range(len(transformers_output)): self.assertEqual(transformers_output[i]["label"], ipex_output[i]["label"]) self.assertAlmostEqual(transformers_output[i]["score"], ipex_output[i]["score"], delta=1e-4) @@ -196,20 +193,19 @@ def test_image_classification_pipeline_inference(self, model_arch): @parameterized.expand(COMMON_SUPPORTED_ARCHITECTURES) def test_pipeline_load_from_ipex_model(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = IPEXModelForSequenceClassification.from_pretrained(model_id, export=True) + model = IPEXModelForSequenceClassification.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) ipex_generator = ipex_pipeline("text-classification", model, tokenizer=tokenizer, accelerator="ipex") inputs = "This restaurant is awesome" with torch.inference_mode(): ipex_output = ipex_generator(inputs) 
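These pipeline tests all go through `optimum.intel.pipelines.pipeline` with `accelerator="ipex"`, which now wraps an eager (optionally compiled) model instead of a TorchScript module, hence the dropped `RecursiveScriptModule` assertions. A minimal sketch, assuming a tiny checkpoint id and prompt of the same kind used by the tests:

```python
import torch
from optimum.intel.pipelines import pipeline as ipex_pipeline

# accelerator="ipex" makes the pipeline load the model through the IPEXModelFor* classes.
generator = ipex_pipeline(
    "text-generation",
    "Intel/tiny-random-gpt2",  # illustrative tiny checkpoint from the test utilities
    accelerator="ipex",
    torch_dtype=torch.float32,
)
with torch.inference_mode():
    outputs = generator("Describe a real-world application of AI.", do_sample=False, max_new_tokens=10)
print(outputs[0]["generated_text"])
```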
self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSequenceClassification)) - self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) self.assertGreaterEqual(ipex_output[0]["score"], 0.0) @parameterized.expand(COMMON_SUPPORTED_ARCHITECTURES) def test_pipeline_load_from_jit_model(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = IPEXModelForSequenceClassification.from_pretrained(model_id, export=True) + model = IPEXModelForSequenceClassification.from_pretrained(model_id) save_dir = TemporaryDirectory().name model.save_pretrained(save_dir) tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -218,5 +214,4 @@ def test_pipeline_load_from_jit_model(self, model_arch): with torch.inference_mode(): ipex_output = ipex_generator(inputs) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSequenceClassification)) - self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) self.assertGreaterEqual(ipex_output[0]["score"], 0.0) diff --git a/tests/ipex/utils_tests.py b/tests/ipex/utils_tests.py index 595bc0246f..e92ef37fd6 100644 --- a/tests/ipex/utils_tests.py +++ b/tests/ipex/utils_tests.py @@ -11,8 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from transformers import is_torch_xpu_available +IS_XPU_AVAILABLE = is_torch_xpu_available(check_device=True) + MODEL_NAMES = { "albert": "hf-internal-testing/tiny-random-albert", "beit": "hf-internal-testing/tiny-random-BeitForImageClassification", @@ -25,18 +28,18 @@ "codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM", "convnext": "hf-internal-testing/tiny-random-convnext", "distilbert": "hf-internal-testing/tiny-random-distilbert", - "distilgpt2": "Jiqing/tiny_random_distilgpt2", + "distilgpt2": "Intel/tiny-random-distilgpt2", "electra": "hf-internal-testing/tiny-random-electra", "flaubert": "hf-internal-testing/tiny-random-flaubert", - "falcon": "Jiqing/tiny_random_falcon", + "falcon": "Intel/tiny-random-falcon", "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", - "gpt2": "hf-internal-testing/tiny-random-gpt2", + "gpt2": "Intel/tiny-random-gpt2", "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", "gptj": "hf-internal-testing/tiny-random-GPTJModel", "levit": "hf-internal-testing/tiny-random-LevitModel", "llama": "fxmarty/tiny-llama-fast-tokenizer", - "llama2": "Jiqing/tiny_random_llama2", + "llama2": "Intel/tiny-random-llama2", "marian": "sshleifer/tiny-marian-en-de", "mbart": "hf-internal-testing/tiny-random-mbart", "mistral": "echarlaix/tiny-random-mistral", @@ -56,7 +59,7 @@ "vit": "hf-internal-testing/tiny-random-vit", "wav2vec2": "anton-l/wav2vec2-random-tiny-classifier", "xlm": "hf-internal-testing/tiny-random-xlm", - "patched_falcon": "Jiqing/patched_tiny_random_falcon_for_causal_lm", - "patched_distilgpt2": "Jiqing/patched_tiny_random_distilgpt2_for_causal_lm", - "patched_llama2": "Jiqing/patched_tiny_random_llama2_for_causal_lm", + "patched_falcon": "Intel/tiny-random-falcon_ipex_model", + "patched_gpt2": "Intel/tiny-random-gpt2_ipex_model", + "patched_llama2": "Intel/tiny-random-llama2_ipex_model", } diff --git a/tests/neural_compressor/test_ipex.py b/tests/neural_compressor/test_ipex.py index ef1f19812e..2a230f23dd 100644 --- a/tests/neural_compressor/test_ipex.py +++ b/tests/neural_compressor/test_ipex.py @@ 
-52,7 +52,7 @@ class IPEXQuantizationTest(INCTestMixin):
     SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (("text-classification", "bert", 21),)
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS)
-    def test_ipex_static_quantization_with_smoothquant(self, task, model_arch, expected_quantized_matmuls):
+    def test_static_quantization_with_smoothquant(self, task, model_arch, expected_quantized_matmuls):
         recipes = {"smooth_quant": True, "smooth_quant_args": {"alpha": 0.5}}
         num_samples = 10
         model_name = MODEL_NAMES[model_arch]
@@ -79,5 +79,5 @@ def test_ipex_static_quantization_with_smoothquant(self, task, model_arch, expec
             is_static=True,
             num_samples=num_samples,
             load_inc_model=False,
-            load_ipex_model=True,
+            load_ipex_model=False,
         )
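For reference, the `generate()` override in `modeling_base.py` earlier in this diff enables the new paged KV cache by registering it with transformers' cache machinery before delegating to `GenerationMixin.generate()`. The lines below are condensed from that hunk; users normally never touch these internals directly, and the explicit submodule imports are only there to make the snippet self-contained:

```python
from optimum.exporters.ipex.cache_utils import IPEXPagedCache
from transformers.generation import configuration_utils as generation_configuration_utils
from transformers.generation import utils as generation_utils

# Patched models route generation through the paged KV cache implementation.
generation_utils.NEED_SETUP_CACHE_CLASSES_MAPPING["ipex_paged"] = IPEXPagedCache
if "ipex_paged" not in generation_configuration_utils.ALL_CACHE_IMPLEMENTATIONS:
    generation_configuration_utils.ALL_CACHE_IMPLEMENTATIONS.append("ipex_paged")

# With that in place, setting generation_config.cache_implementation = "ipex_paged"
# makes GenerationMixin build an IPEXPagedCache for the decoding forward passes.
```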