Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add IPEX-XPU support for Llama2 model Inference #891

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 213 additions & 47 deletions optimum/exporters/ipex/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import List, Optional, Tuple, Union

import torch
from intel_extension_for_pytorch.llm.functional import rms_norm
from torch import nn
from torch.nn import functional as F
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
Expand All @@ -32,7 +33,6 @@

_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.3.0"


if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
logger.warning(
f"Please upgrade the IPEX version to at least {_IPEX_MINIMUM_VERSION_FOR_PATCHING} if you want to patch the model."
Expand All @@ -48,9 +48,48 @@
)


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
def matmul_add_add(attn_output, weight, bias=None, residual=None):
seq_len, bs, _ = attn_output.size()
if residual is None:
attn_output = torch.matmul(attn_output, weight)
if bias is not None:
attn_output += bias
else:
if bias is not None:
attn_output = torch.ops.torch_ipex.mm_bias_resadd(attn_output, weight, bias, 1.0, residual, 1.0)
else:
attn_output = torch.addmm(
residual.flatten(0, -2),
attn_output.flatten(0, -2),
weight,
beta=1.0,
)
attn_output = attn_output.view(seq_len, bs, -1)
return attn_output


def padding_attn_mask(attn_mask, alignment):
if attn_mask is None:
return None
assert isinstance(
attn_mask, torch.Tensor
), f"attn mask is supposed to be a tensor, instead we got {type(attn_mask)}"
if attn_mask.device == torch.device("cpu"):
return attn_mask
last_dim_size = attn_mask.size(-1)
aligned_size = (last_dim_size + alignment - 1) // alignment * alignment
mask_size = [*attn_mask.size()[:-1], aligned_size]
new_attn_mask = torch.empty(mask_size, dtype=attn_mask.dtype, device=attn_mask.device).fill_(-65504.0)
new_attn_mask[..., :last_dim_size] = attn_mask
return new_attn_mask


def _ipex_rms_layer_norm_forward(self, hidden_states):
return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon)
if hidden_states.device.type == "xpu":
return rms_norm(hidden_states, self.weight, self.variance_epsilon)
else:
# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon)


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130
Expand Down Expand Up @@ -108,19 +147,29 @@ def _llama_model_forward(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)

attention_mask = padding_attn_mask(attention_mask, 8)

# embed positions
hidden_states = inputs_embeds

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None

if hidden_states.device.type == "xpu":
seqlen = hidden_states.size(1)
head_dim = self.layers[0].self_attn.head_dim
sin, cos = self.layers[0].self_attn.ipex_rope.get_sin_cos(seqlen, head_dim // 2)
sin = sin.squeeze()[position_ids].unsqueeze(2)
cos = cos.squeeze()[position_ids].unsqueeze(2)
decoder_layer_kwargs = {"sin": sin, "cos": cos}
else:
decoder_layer_kwargs = {}
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)

past_key_value = past_key_values[idx] if past_key_values is not None else None
past_key_value = past_key_values[idx] if past_key_values is not None and len(past_key_values) > idx else None

layer_outputs = decoder_layer(
hidden_states,
Expand All @@ -129,6 +178,7 @@ def _llama_model_forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
**decoder_layer_kwargs,
)

hidden_states = layer_outputs[0]
Expand Down Expand Up @@ -174,14 +224,82 @@ def __init__(self, module, config) -> None:
super().__init__()
_setattr_from_module(self, module)
self.config = config
self.ipex_scale_dot_product = IndirectAccessKVCacheAttention(text_max_length=config.max_position_embeddings)
if hasattr(config, "rope_theta"):
self.ipex_rope = RotaryEmbedding(
config.max_position_embeddings,
config.hidden_size // config.num_attention_heads,
config.rope_theta,
config.architectures[0],
self.module_device = next(module.parameters()).device.type
if self.module_device == "xpu":
from intel_extension_for_pytorch.transformers.models.xpu.fusions.mha_fusion import _IPEXRopeXPU

self.ipex_rope = _IPEXRopeXPU(
module.config.max_position_embeddings,
module.config.hidden_size // module.config.num_attention_heads,
module.config.rope_theta,
module.config.architectures[0],
)
self.port_parameters(module)
torch.xpu.empty_cache()
else:
self.ipex_scale_dot_product = IndirectAccessKVCacheAttention(
text_max_length=config.max_position_embeddings
)
if hasattr(config, "rope_theta"):
self.ipex_rope = RotaryEmbedding(
config.max_position_embeddings,
config.hidden_size // config.num_attention_heads,
config.rope_theta,
config.architectures[0],
)

def port_parameters(self, module):
self.qkv_proj_bias = None
self.qkv_proj_weight = None
if self.num_heads == self.num_key_value_heads:
q_proj = module.q_proj.weight.transpose(0, 1)
k_proj = module.k_proj.weight.transpose(0, 1)
v_proj = module.v_proj.weight.transpose(0, 1)
self.qkv_proj_weight = torch.stack([q_proj, k_proj, v_proj]).contiguous().view([3, -1, q_proj.shape[-1]])
module.q_proj.weight.data = self.qkv_proj_weight[0, :, :].transpose(0, 1)
module.k_proj.weight.data = self.qkv_proj_weight[1, :, :].transpose(0, 1)
module.v_proj.weight.data = self.qkv_proj_weight[2, :, :].transpose(0, 1)
if module.q_proj.bias is not None:
self.qkv_proj_bias = (
torch.stack([module.q_proj.bias, module.k_proj.bias, module.v_proj.bias])
.contiguous()
.view([3, -1])
)
module.q_proj.bias.data = self.qkv_proj_bias[0]
module.k_proj.bias.data = self.qkv_proj_bias[1]
module.v_proj.bias.data = self.qkv_proj_bias[2]
else:
q_proj = module.q_proj.weight.view(
self.num_key_value_heads, self.num_key_value_groups, self.head_dim, self.hidden_size
)
k_proj = module.k_proj.weight.view(self.num_key_value_heads, 1, self.head_dim, self.hidden_size)
v_proj = module.v_proj.weight.view(self.num_key_value_heads, 1, self.head_dim, self.hidden_size)
self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view(
[self.num_key_value_heads, self.num_key_value_groups + 2, self.head_dim, self.hidden_size]
)
module.q_proj.data = self.qkv_proj_weight[:, : self.num_key_value_groups, :, :].reshape(
[self.num_key_value_heads * self.num_key_value_groups * self.head_dim, self.hidden_size]
)
module.k_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups, :, :].reshape(
[self.num_key_value_heads * self.head_dim, self.hidden_size]
)
module.v_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups + 1, :, :].reshape(
[self.num_key_value_heads * self.head_dim, self.hidden_size]
)
self.qkv_proj_weight = self.qkv_proj_weight.permute(3, 0, 1, 2).contiguous()
if module.q_proj.bias is not None:
q_bias = module.q_proj.bias.view(self.num_key_value_heads, self.num_key_value_groups, self.head_dim)
k_bias = module.k_proj.bias.view(self.num_key_value_heads, 1, self.head_dim)
v_bias = module.v_proj.bias.view(self.num_key_value_heads, 1, self.head_dim)
self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(
[self.num_key_value_heads, self.num_key_value_groups + 2, self.head_dim]
)
module.q_proj.bias.data = self.qkv_proj_bias[:, : self.num_key_value_groups, self.head_dim].view(-1)
module.k_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups, self.head_dim].view(-1)
module.v_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups + 1, self.head_dim].view(-1)
self.o_proj_weight = module.o_proj.weight.transpose(0, 1).contiguous()
module.o_proj.weight.data = self.o_proj_weight.transpose(0, 1)
self.o_proj_bias = module.o_proj.bias

def qkv_gemm(self, hidden_states):
raise NotImplementedError("Need to implement in specific model class")
Expand All @@ -192,16 +310,25 @@ def rope(self, *args, **kwargs):
def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask, **kwargs):
# This ipex op pre-allocates buffers for past_key_values and use beam index history
# which to decide which beam should be used to make attention scale dot more efficient.
(attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
query,
key,
value,
math.sqrt(self.head_dim),
past_key_value,
kwargs.get("head_mask", None),
attention_mask,
kwargs.get("alibi", None),
)
if self.module_device == "xpu":
scale = 1.0 / math.sqrt(self.head_dim)
is_causal = False
attn_output = torch.xpu.IpexSDP(
query, key, value, None, attention_mask, None, scale, 1.0, 0.0, is_causal, False
)
attn_weights = None
past_key_value = (key, value)
else:
(attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
query,
key,
value,
math.sqrt(self.head_dim),
past_key_value,
kwargs.get("head_mask", None),
attention_mask,
kwargs.get("alibi", None),
)
return attn_output, past_key_value, attn_weights

def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask, **kwargs):
Expand Down Expand Up @@ -235,10 +362,18 @@ def forward(
qkv_out = self.qkv_gemm(hidden_states)
if isinstance(qkv_out, tuple) and len(qkv_out) == 3:
query, key, value = self.qkv_gemm(hidden_states)
query, key = self.rope(query, key, kv_seq_len, use_cache, position_ids=position_ids)
query, key = self.rope(query, key, kv_seq_len, use_cache, position_ids, **kwargs)
else:
query, key, value = self.rope(qkv_out, kv_seq_len, use_cache, past_len=past_len)

if self.module_device == "xpu":
if past_key_value is not None:
key = torch.cat([past_key_value[0].transpose(1, 2), key], dim=1)
value = torch.cat([past_key_value[1].transpose(1, 2), value], dim=1)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)

attention_mask = self.prepare_attention_mask_float(attention_mask, query.dtype)
sdpa = self.sdpa_with_cache if use_cache else self.sdpa_without_cache
attn_output, past_key_value, attn_weights = sdpa(
Expand All @@ -251,6 +386,7 @@ def forward(
head_mask=kwargs.get("head_mask", None),
alibi=kwargs.get("alibi", None),
)

attn_output = self.postprocess_attention_output(attn_output, bsz, seq_len)

if not output_attentions:
Expand All @@ -262,9 +398,10 @@ def forward(
class _IPEXLlamaAttention(_IPEXAttention):
def __init__(self, module, config) -> None:
super().__init__(module, config)
if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mha_linear_add = LinearAdd(module.o_proj)
del self.__dict__["_modules"]["o_proj"]
if self.module_device == "cpu":
if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mha_linear_add = LinearAdd(module.o_proj)
del self.__dict__["_modules"]["o_proj"]

def qkv_gemm(self, hidden_states):
bsz, seq_len, _ = hidden_states.size()
Expand All @@ -274,11 +411,16 @@ def qkv_gemm(self, hidden_states):

return query, key, value

def rope(self, query, key, kv_seq_len, use_cache, position_ids):
if use_cache:
args = (self.head_dim, self.head_dim // 2, self.head_dim, kv_seq_len)
key = self.ipex_rope(key, position_ids, self.num_key_value_heads, *args)
query = self.ipex_rope(query, position_ids, self.num_heads, *args)
def rope(self, query, key, kv_seq_len, use_cache, position_ids, **kwargs):
if self.module_device == "xpu":
sin = kwargs.pop("sin", None)
cos = kwargs.pop("cos", None)
self.ipex_rope.apply_embedding(query, sin, cos, self.head_dim // 2, key)
else:
if use_cache:
args = (self.head_dim, self.head_dim // 2, self.head_dim, kv_seq_len)
key = self.ipex_rope(key, position_ids, self.num_key_value_heads, *args)
query = self.ipex_rope(query, position_ids, self.num_heads, *args)
return query, key

# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L341
Expand Down Expand Up @@ -372,28 +514,52 @@ def __init__(self, module, config) -> None:
super().__init__()
_setattr_from_module(self, module)
self.config = config
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mlp_linear_add = LinearAdd(module.down_proj)
del self.__dict__["_modules"]["down_proj"]
self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
del self.__dict__["_modules"]["gate_proj"]
del self.__dict__["_modules"]["up_proj"]
self.module_device = next(module.parameters()).device.type
if self.module_device == "xpu":
self.port_parameter(module)
torch.xpu.empty_cache()
else:
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mlp_linear_add = LinearAdd(module.down_proj)
del self.__dict__["_modules"]["down_proj"]
self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
del self.__dict__["_modules"]["gate_proj"]
del self.__dict__["_modules"]["up_proj"]

def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs):
if hasattr(self, "linear_silu_mul"):
mlp_gate = self.linear_silu_mul(hidden_states)
if hasattr(self, "mlp_linear_add"):
hidden_states = self.mlp_linear_add(mlp_gate, residual)
if self.module_device == "xpu":
up = torch.ops.torch_ipex.mm_silu(hidden_states, self.gate_proj_weight)
hidden_states = torch.ops.torch_ipex.mm_resmul(hidden_states, self.up_proj_weight, up)
hidden_states = matmul_add_add(hidden_states, self.down_proj_weight, self.down_proj_bias, residual)
else:
if hasattr(self, "linear_silu_mul"):
mlp_gate = self.linear_silu_mul(hidden_states)
if hasattr(self, "mlp_linear_add"):
hidden_states = self.mlp_linear_add(mlp_gate, residual)
else:
hidden_states = self.down_proj(
self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)
)
hidden_states = residual + hidden_states
else:
hidden_states = self.down_proj(mlp_gate)
hidden_states = self.down_proj(
self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)
)
hidden_states = residual + hidden_states
else:
hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
hidden_states = residual + hidden_states

return hidden_states

def port_parameter(self, module):
self.up_proj_weight = module.up_proj.weight.transpose(0, 1).contiguous()
module.up_proj.weight.data = self.up_proj_weight.transpose(0, 1)
self.gate_proj_weight = module.gate_proj.weight.transpose(0, 1).contiguous()
module.gate_proj.weight.data = self.gate_proj_weight.transpose(0, 1)
self.down_proj_weight = module.down_proj.weight.transpose(0, 1).contiguous()
module.down_proj.weight.data = self.down_proj_weight.transpose(0, 1)
self.up_proj_bias = module.up_proj.bias
self.gate_proj_bias = module.gate_proj.bias
self.down_proj_bias = module.down_proj.bias


class _IPEXFalconMLP(nn.Module):
def __init__(self, module, config) -> None:
Expand Down
Loading