Skip to content

Commit

Permalink
unify xpu and cpu backend and use paged attention (#1009)
Browse files Browse the repository at this point in the history
* add page attention implementation remove jit logic

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add support in transformers 4.45

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix congif (#935)

* move patch model to init

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* refine class IPEXPagedCache's update method (#945)

* refine class IPEXPagedCache's update method

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* replace tensor on xpu to List to avoid memory copy

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* split IPEXPagedCache's update function into `update_for_prefill` and `update_for_decode`

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* fix bug when doing beam search (#954)

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* enable qkv concat layer (#958)

* enable qkv

* split key value into 2 lists

* add xpu cache optimiztion

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* xpu mlp optimization

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* optimize cache ops in xpu, improve for beam search

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable gpt2, falcon has core dump error in PagedAttention.single_quer… (#979)

* enable gpt2, falcon has core dump error in PagedAttention.single_query_cached_kv_attention

* enable new_decoder_arch falcon

* only keep 1 config

* rm autocast

* fix unit test case, CPU part is OK; Enable Falcon7b for XPU (#992)

* fix bug when run IPEXCausalModel forward directly; fix bug when using `save_pretrain`

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* add LinearGelu Op support for XPU

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* fix unit test error

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* adjust unit test case

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* fix bug

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* skip assited decoding unit test for models using paged attention (#998)

* skip assited decoding unit test for models using paged attention

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* XPU CI tests get almost all passed

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* fix ci config (#1010)

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Fix tests versions (#1011)

* fix ci config

* fix test versions

* fix ipex version

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix torch test version (#1012)

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* use python3.9 test (#1013)

* use python3.9 test

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* change ipex transformers limited verison in setup (#1015)

* change ipex transformers limited verison in setup
* fix inc tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* add XPU LinearAddAdd op (#1017)

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* fix bert and vit patch (#1022)

* fix bert and vit patch
* fix vit and bert save


Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Paged attn (#1024)

* fix reorder cache for non-patch models

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* disable torch < 2.3 tests, we won't use torch < 2.4

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix test beam serach

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix cache selection

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* upgrad to transformers4.46

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* change ipex test yaml transformers version to 4.46

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* set device as the same as origin model (#1031)

* set device as the same as origin model
* fix device

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Simplify IPEXModel (#1032)

* simplify forward and save pretrained since no jit support

* fix format

* rm warmup because no jit mode anymore

* simplify forward for causal lm model

* fix paged pkv  forward

* disable use_cache when just run forward

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* nice code (#1035)

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* Paged attn (#1036)

* nice code
* device type adjustment

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* Enable torch.compile for non-generation tasks in CPU (#1037)

* enable compile for non-generation tasks
* add no_grad in forward
* warmup compiled model
* disable compile not ready models
* set system level optimize for torch.compile
* fix typo
* add comments
* set torch minimum version for compiling

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Fix ipex upload and update readme. (#1045)

* fix readme and push to hub support

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* rm export in tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* test with torch 2.5.*

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Fix tests (#1047)

* fix tests
* fix typo
* add patched tests

* change forward to generate

* fix tests

* fix test model name


---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Patch gpt2 block forward for passing input_lens. (#1050)

* fix forward without pkv
* patch gpt2 block forward
* fix typo
* revert causal lm tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: kaixuanliu <kaixuan.liu@intel.com>
Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
  • Loading branch information
4 people authored Dec 5, 2024
1 parent c94b3f5 commit 41f0a46
Show file tree
Hide file tree
Showing 13 changed files with 1,035 additions and 860 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_inc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
torch-version: ["2.4.*", "2.5.0"]
torch-version: ["2.4.0", "2.5.*"]

runs-on: ubuntu-22.04

Expand Down
8 changes: 2 additions & 6 deletions .github/workflows/test_ipex.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ jobs:
strategy:
fail-fast: false
matrix:
torch-version: ["2.2.0", "2.3.*"]
transformers-version: ["4.39.0", "4.44.*"]
transformers-version: ["4.46.0", "4.46.3"]
torch-version: ["2.4.0", "2.5.*"]

runs-on: ubuntu-22.04

Expand All @@ -38,10 +38,6 @@ jobs:
pip install torch==${{ matrix.torch-version }} torchaudio torchvision --extra-index-url https://download.pytorch.org/whl/cpu
pip install .[ipex,tests] transformers[testing]==${{ matrix.transformers-version }} intel_extension_for_pytorch==${{ matrix.torch-version }}
- if: ${{ matrix.torch-version == '2.2.0' }}
name: Downgrade Numpy
run: pip install numpy==1.*

- name: Assert versions
run: |
python -c "import torch; print(torch.__version__); assert torch.__version__.startswith('${{ matrix.torch-version }}'.replace('.*', ''))"
Expand Down
6 changes: 3 additions & 3 deletions docs/source/ipex/inference.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ Optimum Intel can be used to load models from the [Hub](https://huggingface.co/m

## Loading

You can load your model and apply IPEX optimizations (including weight prepacking and graph mode). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators.
For now, support is only enabled for CPUs and the original model will be exported via TorchScript. In the future `torch.compile` will be used and model exported via TorchScript will get deprecated.
You can load your model and apply IPEX optimizations (apply torch.compile for non-generation tasks). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators.
For now, support is enabled for Intel CPU/GPU. Previous models converted to TorchScript will be deprecated in v1.22.

```diff
import torch
Expand All @@ -25,7 +25,7 @@ For now, support is only enabled for CPUs and the original model will be exporte

model_id = "gpt2"
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+ model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, export=True)
+ model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
results = pipe("He's a dreadful magician and")
Expand Down
238 changes: 238 additions & 0 deletions optimum/exporters/ipex/cache_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
from typing import List, Optional, Tuple

import torch
from intel_extension_for_pytorch.llm.modules import PagedAttention
from transformers import Cache, PretrainedConfig


class IPEXPagedCache(Cache):
"""
A PagedCache that grows dynamically as more tokens are generated. everytime it grows block-size memory, vendor could set the pageCache memory layout.
ipex-xpu:
ipex-cpu:
Example:
```python
>>> from transformers import AutoTokenizer
>>> from optimum.intel import IPEXModelForCausalLM
>>> from optimum.exporters.ipex.cache_utils import IPEXPagedCache
>>> model = IPEXModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", export=True)
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> past_key_values = IPEXPagedCache()
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
```
"""

def __init__(
self,
config: PretrainedConfig,
batch_size: int,
max_cache_len: int,
device,
dtype=None,
layer_device_map=None,
**kwargs,
) -> None:
super().__init__()
self.batch_size = batch_size
# Used in `generate` to keep tally of how many tokens the cache has seen
self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
self.block_size = 16
self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
batch_size, -1
)
self.free_blocks = torch.arange(self.num_blocks, device=device)
self.max_cache_len = max_cache_len
self.num_kv_heads = config.num_key_value_heads
self.num_hidden_layers = config.num_hidden_layers
if hasattr(config, "head_dim"):
head_size = config.head_dim
else:
head_size = config.hidden_size // config.num_attention_heads
self.head_size = head_size
self.max_seq_len = 0

self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []

if device.type == "cpu":
key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
elif device.type == "xpu":
key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
for i in range(config.num_hidden_layers):
new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device)
new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)

def update_for_prefill(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
batch_size: int,
input_lens: torch.Tensor,
):
if layer_idx == 0:
all_block_indices = []
all_slot_offsets = []
num_blocks = (input_lens + self.block_size - 1) // self.block_size
for i in range(batch_size):
for b_idx in range(num_blocks[i]):
if self.block_tables[i][b_idx] == -1:
# need a free block
self.block_tables[i][b_idx] = self.free_blocks[0]
self.free_blocks = self.free_blocks[1:]

slots_range = torch.arange(input_lens[i], device=key_states.device)
block_indices = slots_range // self.block_size
slot_offsets = slots_range % self.block_size
all_block_indices.append(self.block_tables[i][block_indices])
all_slot_offsets.append(slot_offsets)

all_block_indices = torch.cat(all_block_indices)
all_slot_offsets = torch.cat(all_slot_offsets)
self.slots = all_block_indices * self.block_size + all_slot_offsets

# Update the cache
PagedAttention.reshape_and_cache(
key_states,
value_states,
self.key_cache[layer_idx],
self.value_cache[layer_idx],
self.slots,
)

# Update the number of seen tokens
if layer_idx == self.num_hidden_layers - 1:
self._seen_tokens = self._seen_tokens + input_lens
self.max_seq_len, _ = self._seen_tokens.max(dim=0)

def update_for_decode(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
batch_size: int,
):
if layer_idx == 0:
start_block_idx = self._seen_tokens // self.block_size
num_blocks = (self._seen_tokens + self.block_size) // self.block_size
slot_offset_in_block = (self._seen_tokens) % self.block_size
self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32)
for i in range(batch_size):
for b_idx in range(start_block_idx[i], num_blocks[i]):
if self.block_tables[i][b_idx] == -1:
# need a free block
self.block_tables[i][b_idx] = self.free_blocks[0]
self.free_blocks = self.free_blocks[1:]

self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
# Update the cache
PagedAttention.reshape_and_cache(
key_states,
value_states,
self.key_cache[layer_idx],
self.value_cache[layer_idx],
self.slots,
)

# Update the number of seen tokens
if layer_idx == self.num_hidden_layers - 1:
self._seen_tokens = self._seen_tokens + 1
self.max_seq_len = self.max_seq_len + 1

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
attention_mask: torch.Tensor,
input_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
Return:
A tuple containing the updated key and value states.
"""

batch_size = input_lens.shape[-1]
if self.get_seq_length() == 0:
# prefill
self.update_for_prefill(key_states, value_states, layer_idx, batch_size, input_lens)
else:
# decode
self.update_for_decode(key_states, value_states, layer_idx, batch_size)

return self.key_cache[layer_idx], self.value_cache[layer_idx]

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model."""
return self.max_seq_len

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states."""
return self.max_cache_len

def reset(self):
"""Resets the cache values while preserving the objects"""
self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device)
self.block_tables.fill_(-1)
self.free_blocks = torch.arange(self.num_blocks, device=self.block_tables.device)
self.max_seq_len = 0

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
device = self.block_tables.device
origin_table = self.block_tables.clone()
updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device))
mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
num_blocks = mask.cumsum(-1)[:, -1]
updated_table = []
for i in range(beam_idx.shape[0]):
self.block_tables[i, 0 : num_blocks[i] - 1] = updated_block_tables[i, 0 : num_blocks[i] - 1]
updated_table.append(self.block_tables[i : i + 1, num_blocks[i] - 1 : num_blocks[i]])
updated_table = torch.cat(tuple(updated_table), dim=0)
for layer_idx in range(self.num_hidden_layers):
self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]]
self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]]
free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
self.free_blocks = torch.cat((self.free_blocks, free_table))

def crop(self, maximum_length: int):
"""Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""

max_seq_len = self.get_seq_length()
if maximum_length < 0:
maximum_length = max_seq_len - abs(maximum_length)

if max_seq_len <= maximum_length:
return
origin_table = self.block_tables.clone()
for bs in range(self._seen_tokens.shape[0]):
new_tokens = self._seen_tokens[bs] + maximum_length - max_seq_len
num_blocks = (new_tokens + self.block_size - 1) // self.block_size
self.block_tables[bs, num_blocks:] = -1
self._seen_tokens[bs] = new_tokens
self.max_seq_len, _ = self._seen_tokens.max(dim=0)
free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
self.free_blocks = torch.cat((self.free_blocks, free_table))
39 changes: 22 additions & 17 deletions optimum/exporters/ipex/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
# limitations under the License.

from transformers.models.bert.modeling_bert import BertIntermediate
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconForCausalLM
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
)
Expand All @@ -28,7 +27,9 @@

from .modeling_utils import (
_IPEX_MINIMUM_VERSION_FOR_PATCHING,
_falcon_model_forward,
_gpt2_block_forward,
_gpt2_model_forward,
_ipex_rms_layer_norm_forward,
_IPEXFalconDecoderLayer,
_IPEXGPT2Attention,
Expand All @@ -39,8 +40,8 @@


# Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version
_TRANSFORMERS_MIN_VERSION = "4.39.0"
_TRANSFORMERS_MAX_VERSION = "4.44.99"
_TRANSFORMERS_MIN_VERSION = "4.46.0"
_TRANSFORMERS_MAX_VERSION = "4.46.99"

_IPEX_EXPORTED_GENERATION_TASKS = ("text-generation",)

Expand Down Expand Up @@ -75,7 +76,7 @@ def patch_op(m, target_m, new_op_name, new_op):
def _patch_llama_model(model):
"""
Patch llama model:
1. Use IPEX Rope and IAKV cache
1. Use IPEX rope and paged cache
2. Linear fusion with (2 Linears + Silu + Mul) and (Linear + Add)
"""
convert_functions(model, LlamaModel, "forward", _llama_model_forward)
Expand All @@ -87,11 +88,14 @@ def _patch_llama_model(model):
def _patch_falcon_model(model):
"""
Patch falcon model:
1. Disable SDPA so the attention mask will be compatible to ipex attention.
2. Use IPEX Rope and IAKV cache
3. Linear fusion with (Linear + Gelu) and (Linear + Add + Add)
1. Use IPEX rope and paged cache
2. Linear fusion with (Linear + Gelu) and (Linear + Add + Add)
"""
model.transformer._use_sdpa = False
num_key_value_heads = (
model.config.num_kv_heads if (model.config.new_decoder_architecture or not model.config.multi_query) else 1
)
setattr(model.config, "num_key_value_heads", num_key_value_heads)
convert_functions(model, FalconModel, "forward", _falcon_model_forward)
replace_customized_linear_with_linear(model)
convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.config)
return model
Expand All @@ -100,12 +104,13 @@ def _patch_falcon_model(model):
def _patch_gpt2_model(model):
"""
Patch gpt2 model:
1. Disable SDPA so the attention mask will be compatible to ipex attention.
2. Use IAKV cache
1. Use IPEX paged attention
"""
model.transformer._attn_implementation = "eager"
convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
num_key_value_heads = model.config.num_attention_heads
setattr(model.config, "num_key_value_heads", num_key_value_heads)
convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
return model


Expand Down Expand Up @@ -136,11 +141,11 @@ def _patch_model(model):
raise ImportError(
f"Only transformers versions {_TRANSFORMERS_MIN_VERSION} ~ {_TRANSFORMERS_MAX_VERSION} are verified."
)
if isinstance(model, LlamaForCausalLM):
if model.config.model_type == "llama":
model = _patch_llama_model(model)
elif isinstance(model, FalconForCausalLM):
elif model.config.model_type == "falcon":
model = _patch_falcon_model(model)
elif isinstance(model, GPT2LMHeadModel):
elif model.config.model_type == "gpt2":
model = _patch_gpt2_model(model)
elif model.config.model_type == "bert":
model = _patch_bert_model(model)
Expand Down
Loading

0 comments on commit 41f0a46

Please sign in to comment.