diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index 3d28350b8..ba866c5c7 100644
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -17,6 +17,7 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
+from intel_extension_for_pytorch.llm.functional import rms_norm
 from torch import nn
 from torch.nn import functional as F
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
@@ -32,7 +33,6 @@
 
 _IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.3.0"
 
-
 if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
     logger.warning(
         f"Please upgrade the IPEX version to at least {_IPEX_MINIMUM_VERSION_FOR_PATCHING} if you want to patch the model."
@@ -48,9 +48,48 @@
     )
 
 
-# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
+def matmul_add_add(attn_output, weight, bias=None, residual=None):
+    seq_len, bs, _ = attn_output.size()
+    if residual is None:
+        attn_output = torch.matmul(attn_output, weight)
+        if bias is not None:
+            attn_output += bias
+    else:
+        if bias is not None:
+            attn_output = torch.ops.torch_ipex.mm_bias_resadd(attn_output, weight, bias, 1.0, residual, 1.0)
+        else:
+            attn_output = torch.addmm(
+                residual.flatten(0, -2),
+                attn_output.flatten(0, -2),
+                weight,
+                beta=1.0,
+            )
+    attn_output = attn_output.view(seq_len, bs, -1)
+    return attn_output
+
+
+def padding_attn_mask(attn_mask, alignment):
+    if attn_mask is None:
+        return None
+    assert isinstance(
+        attn_mask, torch.Tensor
+    ), f"attn mask is supposed to be a tensor, instead we got {type(attn_mask)}"
+    if attn_mask.device == torch.device("cpu"):
+        return attn_mask
+    last_dim_size = attn_mask.size(-1)
+    aligned_size = (last_dim_size + alignment - 1) // alignment * alignment
+    mask_size = [*attn_mask.size()[:-1], aligned_size]
+    new_attn_mask = torch.empty(mask_size, dtype=attn_mask.dtype, device=attn_mask.device).fill_(-65504.0)
+    new_attn_mask[..., :last_dim_size] = attn_mask
+    return new_attn_mask
+
+
 def _ipex_rms_layer_norm_forward(self, hidden_states):
-    return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon)
+    if hidden_states.device.type == "xpu":
+        return rms_norm(hidden_states, self.weight, self.variance_epsilon)
+    else:
+        # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
+        return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon)
 
 
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130
@@ -108,6 +147,8 @@ def _llama_model_forward(
         attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
     )
 
+    attention_mask = padding_attn_mask(attention_mask, 8)
+
     # embed positions
     hidden_states = inputs_embeds
 
@@ -115,12 +156,20 @@ def _llama_model_forward(
     all_hidden_states = () if output_hidden_states else None
     all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None
-
+    if hidden_states.device.type == "xpu":
+        seqlen = hidden_states.size(1)
+        head_dim = self.layers[0].self_attn.head_dim
+        sin, cos = self.layers[0].self_attn.ipex_rope.get_sin_cos(seqlen, head_dim // 2)
+        sin = sin.squeeze()[position_ids].unsqueeze(2)
+        cos = cos.squeeze()[position_ids].unsqueeze(2)
+        decoder_layer_kwargs = {"sin": sin, "cos": cos}
+    else:
+        decoder_layer_kwargs = {}
     for idx, decoder_layer in enumerate(self.layers):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        past_key_value = past_key_values[idx] if past_key_values is not None else None
+        past_key_value = past_key_values[idx] if past_key_values is not None and len(past_key_values) > idx else None
 
         layer_outputs = decoder_layer(
             hidden_states,
@@ -129,6 +178,7 @@ def _llama_model_forward(
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            **decoder_layer_kwargs,
         )
 
         hidden_states = layer_outputs[0]
@@ -174,14 +224,82 @@ def __init__(self, module, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        self.ipex_scale_dot_product = IndirectAccessKVCacheAttention(text_max_length=config.max_position_embeddings)
-        if hasattr(config, "rope_theta"):
-            self.ipex_rope = RotaryEmbedding(
-                config.max_position_embeddings,
-                config.hidden_size // config.num_attention_heads,
-                config.rope_theta,
-                config.architectures[0],
+        self.module_device = next(module.parameters()).device.type
+        if self.module_device == "xpu":
+            from intel_extension_for_pytorch.transformers.models.xpu.fusions.mha_fusion import _IPEXRopeXPU
+
+            self.ipex_rope = _IPEXRopeXPU(
+                module.config.max_position_embeddings,
+                module.config.hidden_size // module.config.num_attention_heads,
+                module.config.rope_theta,
+                module.config.architectures[0],
             )
+            self.port_parameters(module)
+            torch.xpu.empty_cache()
+        else:
+            self.ipex_scale_dot_product = IndirectAccessKVCacheAttention(
+                text_max_length=config.max_position_embeddings
+            )
+            if hasattr(config, "rope_theta"):
+                self.ipex_rope = RotaryEmbedding(
+                    config.max_position_embeddings,
+                    config.hidden_size // config.num_attention_heads,
+                    config.rope_theta,
+                    config.architectures[0],
+                )
+
+    def port_parameters(self, module):
+        self.qkv_proj_bias = None
+        self.qkv_proj_weight = None
+        if self.num_heads == self.num_key_value_heads:
+            q_proj = module.q_proj.weight.transpose(0, 1)
+            k_proj = module.k_proj.weight.transpose(0, 1)
+            v_proj = module.v_proj.weight.transpose(0, 1)
+            self.qkv_proj_weight = torch.stack([q_proj, k_proj, v_proj]).contiguous().view([3, -1, q_proj.shape[-1]])
+            module.q_proj.weight.data = self.qkv_proj_weight[0, :, :].transpose(0, 1)
+            module.k_proj.weight.data = self.qkv_proj_weight[1, :, :].transpose(0, 1)
+            module.v_proj.weight.data = self.qkv_proj_weight[2, :, :].transpose(0, 1)
+            if module.q_proj.bias is not None:
+                self.qkv_proj_bias = (
+                    torch.stack([module.q_proj.bias, module.k_proj.bias, module.v_proj.bias])
+                    .contiguous()
+                    .view([3, -1])
+                )
+                module.q_proj.bias.data = self.qkv_proj_bias[0]
+                module.k_proj.bias.data = self.qkv_proj_bias[1]
+                module.v_proj.bias.data = self.qkv_proj_bias[2]
+        else:
+            q_proj = module.q_proj.weight.view(
+                self.num_key_value_heads, self.num_key_value_groups, self.head_dim, self.hidden_size
+            )
+            k_proj = module.k_proj.weight.view(self.num_key_value_heads, 1, self.head_dim, self.hidden_size)
+            v_proj = module.v_proj.weight.view(self.num_key_value_heads, 1, self.head_dim, self.hidden_size)
+            self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view(
+                [self.num_key_value_heads, self.num_key_value_groups + 2, self.head_dim, self.hidden_size]
+            )
+            module.q_proj.data = self.qkv_proj_weight[:, : self.num_key_value_groups, :, :].reshape(
+                [self.num_key_value_heads * self.num_key_value_groups * self.head_dim, self.hidden_size]
+            )
+            module.k_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups, :, :].reshape(
+                [self.num_key_value_heads * self.head_dim, self.hidden_size]
+            )
+            module.v_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups + 1, :, :].reshape(
+                [self.num_key_value_heads * self.head_dim, self.hidden_size]
+            )
+            self.qkv_proj_weight = self.qkv_proj_weight.permute(3, 0, 1, 2).contiguous()
+            if module.q_proj.bias is not None:
+                q_bias = module.q_proj.bias.view(self.num_key_value_heads, self.num_key_value_groups, self.head_dim)
+                k_bias = module.k_proj.bias.view(self.num_key_value_heads, 1, self.head_dim)
+                v_bias = module.v_proj.bias.view(self.num_key_value_heads, 1, self.head_dim)
+                self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(
+                    [self.num_key_value_heads, self.num_key_value_groups + 2, self.head_dim]
+                )
+                module.q_proj.bias.data = self.qkv_proj_bias[:, : self.num_key_value_groups, self.head_dim].view(-1)
+                module.k_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups, self.head_dim].view(-1)
+                module.v_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups + 1, self.head_dim].view(-1)
+        self.o_proj_weight = module.o_proj.weight.transpose(0, 1).contiguous()
+        module.o_proj.weight.data = self.o_proj_weight.transpose(0, 1)
+        self.o_proj_bias = module.o_proj.bias
 
     def qkv_gemm(self, hidden_states):
         raise NotImplementedError("Need to implement in specific model class")
@@ -192,16 +310,25 @@ def rope(self, *args, **kwargs):
     def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask, **kwargs):
         # This ipex op pre-allocates buffers for past_key_values and use beam index history
         # which to decide which beam should be used to make attention scale dot more efficient.
-        (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
-            query,
-            key,
-            value,
-            math.sqrt(self.head_dim),
-            past_key_value,
-            kwargs.get("head_mask", None),
-            attention_mask,
-            kwargs.get("alibi", None),
-        )
+        if self.module_device == "xpu":
+            scale = 1.0 / math.sqrt(self.head_dim)
+            is_causal = False
+            attn_output = torch.xpu.IpexSDP(
+                query, key, value, None, attention_mask, None, scale, 1.0, 0.0, is_causal, False
+            )
+            attn_weights = None
+            past_key_value = (key, value)
+        else:
+            (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
+                query,
+                key,
+                value,
+                math.sqrt(self.head_dim),
+                past_key_value,
+                kwargs.get("head_mask", None),
+                attention_mask,
+                kwargs.get("alibi", None),
+            )
         return attn_output, past_key_value, attn_weights
 
     def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask, **kwargs):
@@ -235,10 +362,18 @@ def forward(
         qkv_out = self.qkv_gemm(hidden_states)
         if isinstance(qkv_out, tuple) and len(qkv_out) == 3:
             query, key, value = self.qkv_gemm(hidden_states)
-            query, key = self.rope(query, key, kv_seq_len, use_cache, position_ids=position_ids)
+            query, key = self.rope(query, key, kv_seq_len, use_cache, position_ids, **kwargs)
         else:
             query, key, value = self.rope(qkv_out, kv_seq_len, use_cache, past_len=past_len)
 
+        if self.module_device == "xpu":
+            if past_key_value is not None:
+                key = torch.cat([past_key_value[0].transpose(1, 2), key], dim=1)
+                value = torch.cat([past_key_value[1].transpose(1, 2), value], dim=1)
+            query = query.transpose(1, 2)
+            key = key.transpose(1, 2)
+            value = value.transpose(1, 2)
+
         attention_mask = self.prepare_attention_mask_float(attention_mask, query.dtype)
         sdpa = self.sdpa_with_cache if use_cache else self.sdpa_without_cache
         attn_output, past_key_value, attn_weights = sdpa(
@@ -251,6 +386,7 @@ def forward(
             head_mask=kwargs.get("head_mask", None),
             alibi=kwargs.get("alibi", None),
         )
+
         attn_output = self.postprocess_attention_output(attn_output, bsz, seq_len)
 
         if not output_attentions:
@@ -262,9 +398,10 @@ def forward(
 class _IPEXLlamaAttention(_IPEXAttention):
     def __init__(self, module, config) -> None:
         super().__init__(module, config)
-        if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
-            self.mha_linear_add = LinearAdd(module.o_proj)
-            del self.__dict__["_modules"]["o_proj"]
+        if self.module_device == "cpu":
+            if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                self.mha_linear_add = LinearAdd(module.o_proj)
+                del self.__dict__["_modules"]["o_proj"]
 
     def qkv_gemm(self, hidden_states):
         bsz, seq_len, _ = hidden_states.size()
@@ -274,11 +411,16 @@ def qkv_gemm(self, hidden_states):
 
         return query, key, value
 
-    def rope(self, query, key, kv_seq_len, use_cache, position_ids):
-        if use_cache:
-            args = (self.head_dim, self.head_dim // 2, self.head_dim, kv_seq_len)
-            key = self.ipex_rope(key, position_ids, self.num_key_value_heads, *args)
-            query = self.ipex_rope(query, position_ids, self.num_heads, *args)
+    def rope(self, query, key, kv_seq_len, use_cache, position_ids, **kwargs):
+        if self.module_device == "xpu":
+            sin = kwargs.pop("sin", None)
+            cos = kwargs.pop("cos", None)
+            self.ipex_rope.apply_embedding(query, sin, cos, self.head_dim // 2, key)
+        else:
+            if use_cache:
+                args = (self.head_dim, self.head_dim // 2, self.head_dim, kv_seq_len)
+                key = self.ipex_rope(key, position_ids, self.num_key_value_heads, *args)
+                query = self.ipex_rope(query, position_ids, self.num_heads, *args)
         return query, key
 
     # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L341
@@ -372,28 +514,52 @@ def __init__(self, module, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
-        if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
-            self.mlp_linear_add = LinearAdd(module.down_proj)
-            del self.__dict__["_modules"]["down_proj"]
-        self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
-        del self.__dict__["_modules"]["gate_proj"]
-        del self.__dict__["_modules"]["up_proj"]
+        self.module_device = next(module.parameters()).device.type
+        if self.module_device == "xpu":
+            self.port_parameter(module)
+            torch.xpu.empty_cache()
+        else:
+            # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
+            if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                self.mlp_linear_add = LinearAdd(module.down_proj)
+                del self.__dict__["_modules"]["down_proj"]
+            self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
+            del self.__dict__["_modules"]["gate_proj"]
+            del self.__dict__["_modules"]["up_proj"]
 
     def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs):
-        if hasattr(self, "linear_silu_mul"):
-            mlp_gate = self.linear_silu_mul(hidden_states)
-            if hasattr(self, "mlp_linear_add"):
-                hidden_states = self.mlp_linear_add(mlp_gate, residual)
+        if self.module_device == "xpu":
+            up = torch.ops.torch_ipex.mm_silu(hidden_states, self.gate_proj_weight)
+            hidden_states = torch.ops.torch_ipex.mm_resmul(hidden_states, self.up_proj_weight, up)
+            hidden_states = matmul_add_add(hidden_states, self.down_proj_weight, self.down_proj_bias, residual)
+        else:
+            if hasattr(self, "linear_silu_mul"):
+                mlp_gate = self.linear_silu_mul(hidden_states)
+                if hasattr(self, "mlp_linear_add"):
+                    hidden_states = self.mlp_linear_add(mlp_gate, residual)
+                else:
+                    hidden_states = self.down_proj(
+                        self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)
+                    )
+                    hidden_states = residual + hidden_states
             else:
-                hidden_states = self.down_proj(mlp_gate)
+                hidden_states = self.down_proj(
+                    self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)
+                )
                 hidden_states = residual + hidden_states
-        else:
-            hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
-            hidden_states = residual + hidden_states
-
         return hidden_states
 
+    def port_parameter(self, module):
+        self.up_proj_weight = module.up_proj.weight.transpose(0, 1).contiguous()
+        module.up_proj.weight.data = self.up_proj_weight.transpose(0, 1)
+        self.gate_proj_weight = module.gate_proj.weight.transpose(0, 1).contiguous()
+        module.gate_proj.weight.data = self.gate_proj_weight.transpose(0, 1)
+        self.down_proj_weight = module.down_proj.weight.transpose(0, 1).contiguous()
+        module.down_proj.weight.data = self.down_proj_weight.transpose(0, 1)
+        self.up_proj_bias = module.up_proj.bias
+        self.gate_proj_bias = module.gate_proj.bias
+        self.down_proj_bias = module.down_proj.bias
+
 
 class _IPEXFalconMLP(nn.Module):
     def __init__(self, module, config) -> None:
diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index 739a2f2b4..0d275777a 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -178,12 +178,12 @@ def __init__(
         else:
             self._device = torch.device("cpu")
 
+        config = model.config if config is None else config
         # CPU only support jit model for now.
-        if export:
+        if export and self._device.type == "cpu":
             if isinstance(model, torch.jit.RecursiveScriptModule):
                 logger.warning("The model has been exported already.")
             else:
-                config = model.config if config is None else config
                 use_cache = kwargs.get("use_cache", True)
                 model = ipex_jit_trace(model, self.export_feature, use_cache)
                 config.torchscript = True
@@ -291,7 +291,7 @@ def _from_pretrained(
             logger.warning("Detect torchscript is false. Convert to torchscript model!")
 
         if is_torch_version("<", "2.1.0"):
-            raise ImportError("`torch>=2.0.0` is needed to trace your model")
+            raise ImportError("`torch>=2.1.0` is needed to trace your model")
 
         task = cls.export_feature
         config.torch_dtype = torch_dtype
@@ -304,6 +304,15 @@ def _from_pretrained(
             _commit_hash=commit_hash,
             **model_kwargs,
         )
+        if is_torch_xpu_available(check_device=True):
+            model.to("xpu:0")
+        if _is_patched_with_ipex(model, task):
+            model = _patch_model(model)
+        else:
+            use_cache = kwargs.get("use_cache", True)
+            model = ipex_jit_trace(model, task, use_cache)
+            config.torchscript = True
+            config.torch_dtype = torch_dtype
 
         return cls(model, config=config, export=True, **kwargs)
 
@@ -529,7 +538,7 @@ def __init__(
         except AttributeError:
             self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)
 
-        if self._is_ipex_exported:
+        if self._is_ipex_exported and self._device.type == "cpu":
             self._reorder_cache = _ipex_reorder_cache
         else:
             # Check if _reorder_cache is a static method
@@ -646,7 +655,7 @@ def forward(
             inputs["position_ids"] = position_ids
 
         if self.use_cache:
-            if past_key_values is None:
+            if past_key_values is None and self._device.type == "cpu":
                 past_key_values = self._prepare_past_key_values(input_ids)
 
             inputs["past_key_values"] = past_key_values
@@ -760,6 +769,19 @@ def _ipex_prepare_inputs_for_generation(
     return model_inputs
 
 
+def _ipex_crop_past_key_values(model, past_key_values, max_length):
+    if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"):
+        new_past_key_values = []
+        for i in range(len(past_key_values)):
+            pkv = []
+            pkv.append(past_key_values[i][0][:, :max_length, :max_length, :])
+            pkv += [past_key_values[i][_] for _ in range(1, 4)]
+            new_past_key_values.append(tuple(pkv))
+        new_past_key_values = tuple(new_past_key_values)
+        return new_past_key_values
+    return _crop_past_key_values(model, past_key_values, max_length)
+
+
 def _ipex_reorder_cache(
     past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
 ) -> Tuple[Tuple[torch.Tensor]]:
@@ -778,16 +800,3 @@ def _ipex_reorder_cache(
         tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
         for layer_past in past_key_values
     )
-
-
-def _ipex_crop_past_key_values(model, past_key_values, max_length):
-    if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"):
-        new_past_key_values = []
-        for i in range(len(past_key_values)):
-            pkv = []
-            pkv.append(past_key_values[i][0][:, :max_length, :max_length, :])
-            pkv += [past_key_values[i][_] for _ in range(1, 4)]
-            new_past_key_values.append(tuple(pkv))
-        new_past_key_values = tuple(new_past_key_values)
-        return new_past_key_values
-    return _crop_past_key_values(model, past_key_values, max_length)