Paged attn (#1024)
* fix reorder cache for non-patched models

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* disable torch < 2.3 tests since we won't use torch < 2.4

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix beam search test

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix cache selection

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* upgrade to transformers 4.46

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* change ipex test yaml transformers version to 4.46

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
jiqing-feng authored Nov 25, 2024
1 parent 0d7f8b6 commit b48192b
Showing 6 changed files with 30 additions and 90 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_ipex.yml
@@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
transformers-version: ["4.45.*"]
transformers-version: ["4.46.*"]
torch-version: ["2.4.0", "2.5.0"]

runs-on: ubuntu-22.04
14 changes: 7 additions & 7 deletions optimum/exporters/ipex/cache_utils.py
@@ -33,21 +33,21 @@ class IPEXPagedCache(Cache):
def __init__(
self,
config: PretrainedConfig,
max_batch_size: int,
batch_size: int,
max_cache_len: int,
device,
dtype=None,
layer_device_map=None,
**kwargs,
) -> None:
super().__init__()
self.max_batch_size = max_batch_size
self.batch_size = max_batch_size
self.batch_size = batch_size
# Used in `generate` to keep tally of how many tokens the cache has seen
self._seen_tokens = torch.zeros([max_batch_size], dtype=torch.int32, device=device)
self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
self.block_size = 16
self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * max_batch_size
self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
max_batch_size, -1
batch_size, -1
)
self.free_blocks = torch.arange(self.num_blocks, device=device)
self.max_cache_len = max_cache_len
@@ -194,7 +194,7 @@ def get_max_length(self) -> Optional[int]:

def reset(self):
"""Resets the cache values while preserving the objects"""
self._seen_tokens = torch.zeros([self.max_batch_size], dtype=torch.int32, device=self.block_tables.device)
self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device)
self.block_tables.fill_(-1)
self.free_blocks = torch.arange(self.num_blocks, device=self.block_tables.device)
self.max_seq_len = 0
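
The block-geometry arithmetic above is easy to sanity-check in isolation. Below is a minimal sketch (not part of the diff; the numbers are made up) of how num_blocks and the block_tables shape follow from batch_size, max_cache_len, and the fixed block size of 16:

import torch

batch_size, max_cache_len, block_size = 4, 100, 16

# Same rounding-up as in __init__: 100 tokens need 7 blocks of 16.
blocks_per_seq = max_cache_len // block_size + (max_cache_len % block_size != 0)
num_blocks = blocks_per_seq * batch_size  # 28 blocks shared by the whole batch

block_tables = -1 * torch.ones([num_blocks], dtype=torch.int32).reshape(batch_size, -1)
free_blocks = torch.arange(num_blocks)

print(block_tables.shape)   # torch.Size([4, 7]) -- one row of block ids per sequence
print(free_blocks.numel())  # 28 -- every block starts out unallocated
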
4 changes: 2 additions & 2 deletions optimum/exporters/ipex/model_patcher.py
@@ -39,8 +39,8 @@


# Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version
_TRANSFORMERS_MIN_VERSION = "4.45.0"
_TRANSFORMERS_MAX_VERSION = "4.45.99"
_TRANSFORMERS_MIN_VERSION = "4.46.0"
_TRANSFORMERS_MAX_VERSION = "4.46.99"

_IPEX_EXPORTED_GENERATION_TASKS = ("text-generation",)

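For context, a minimal sketch (not part of the diff) of how bounds like _TRANSFORMERS_MIN_VERSION / _TRANSFORMERS_MAX_VERSION can be enforced at import time; the helper name _check_transformers_version is made up for illustration:

from packaging import version
import transformers

_TRANSFORMERS_MIN_VERSION = "4.46.0"
_TRANSFORMERS_MAX_VERSION = "4.46.99"

def _check_transformers_version():
    # Compare the installed transformers version against the supported range.
    current = version.parse(transformers.__version__)
    low = version.parse(_TRANSFORMERS_MIN_VERSION)
    high = version.parse(_TRANSFORMERS_MAX_VERSION)
    if not (low <= current <= high):
        raise ImportError(
            f"transformers {transformers.__version__} is installed, but the IPEX exporter "
            f"expects a version between {_TRANSFORMERS_MIN_VERSION} and {_TRANSFORMERS_MAX_VERSION}."
        )
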
89 changes: 16 additions & 73 deletions optimum/intel/ipex/modeling_base.py
@@ -413,8 +413,6 @@ def forward(
class IPEXModelForCausalLM(IPEXModel, GenerationMixin):
auto_model_class = AutoModelForCausalLM
export_feature = "text-generation"
_supports_cache_class = False
_is_stateful = False

def __init__(
self,
@@ -430,6 +428,13 @@ def __init__(
super().__init__(
model, config, export=export, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache
)

self._supports_cache_class = getattr(model, "_supports_cache_class", None)
self._supports_sdpa = getattr(model, "_supports_sdpa", None)
self._supports_cache_class = getattr(model, "_supports_cache_class", None)
self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None)
self._supports_static_cache = getattr(model, "_supports_static_cache", None)

if self._add_patch:
self._supports_cache_class = True
GenerationMixin.__init__(self)
@@ -448,18 +453,6 @@ def __init__(
except AttributeError:
self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)

if is_transformers_version(">=", "4.38.0") and model_type in {
"llama",
"phi",
"persimmon",
"mistral",
"falcon",
"gpt2",
}:
self.prepare_inputs_for_generation = _ipex_prepare_inputs_for_generation
else:
self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self)

if hasattr(self.model_cls, "_convert_to_standard_cache"):
self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
if hasattr(self.model_cls, "_convert_to_bloom_cache"):
@@ -521,6 +514,12 @@ def _prepare_generation_config(

return generation_config, model_kwargs

def _reorder_cache(self, *args, **kwargs):
return self.model._reorder_cache(*args, **kwargs)

def prepare_inputs_for_generation(self, *args, **kwargs):
return self.model.prepare_inputs_for_generation(*args, **kwargs)

def generate(self, *args, **kwargs):
new_kwargs = copy.deepcopy(kwargs)
if is_ipex_version("<", "2.4.0") and self._add_patch and new_kwargs.get("assistant_model", None):
@@ -556,68 +555,12 @@ def generate(self, *args, **kwargs):
return result


def _ipex_prepare_inputs_for_generation(
input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
from transformers.cache_utils import Cache

if past_key_values is not None:
if isinstance(past_key_values, Cache):
past_length = cache_length = past_key_values.get_seq_length()
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}

model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs


def _ipex_crop_past_key_values(model, past_key_values, max_length):
if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"):
if isinstance(past_key_values, IPEXPagedCache):
return past_key_values.crop(max_length)
# .crop is an inplace op, returns None
past_key_values = past_key_values.crop(max_length)
return past_key_values
else:
raise ValueError("only support IPEXPagedCache input now")
return _crop_past_key_values(model, past_key_values, max_length)
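
The modeling_base.py change boils down to two things: the wrapper now mirrors the backbone model's capability flags and forwards prepare_inputs_for_generation / _reorder_cache to it instead of carrying its own copy, and _ipex_crop_past_key_values accounts for IPEXPagedCache.crop being an in-place operation that returns None. A minimal sketch of the delegation pattern (TinyWrapper and backbone are illustrative names, not optimum-intel API):

class TinyWrapper:
    """Forwards generation hooks to the wrapped transformers model."""

    def __init__(self, backbone):
        self.model = backbone
        # Mirror whatever the backbone declares instead of hard-coding False.
        self._supports_cache_class = getattr(backbone, "_supports_cache_class", None)
        self._supports_static_cache = getattr(backbone, "_supports_static_cache", None)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        return self.model.prepare_inputs_for_generation(*args, **kwargs)

    def _reorder_cache(self, *args, **kwargs):
        return self.model._reorder_cache(*args, **kwargs)

Because the forwarding methods take *args/**kwargs, the wrapper stays agnostic to signature changes across transformers releases, which is what allows the hand-rolled _ipex_prepare_inputs_for_generation above to be deleted.
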
2 changes: 1 addition & 1 deletion setup.py
@@ -64,7 +64,7 @@
"nncf": ["nncf>=2.11.0"],
"openvino": ["nncf>=2.11.0", "openvino==2024.5.0", "openvino-tokenizers==2024.5.0"],
"neural-compressor": ["neural-compressor[pt]>3.0", "accelerate", "transformers<4.46"],
"ipex": ["intel-extension-for-pytorch", "transformers>=4.39,<4.46"],
"ipex": ["intel-extension-for-pytorch", "transformers>=4.39,<4.47"],
"diffusers": ["diffusers"],
"quality": QUALITY_REQUIRE,
"tests": TESTS_REQUIRE,
9 changes: 3 additions & 6 deletions tests/ipex/test_modeling.py
@@ -44,7 +44,6 @@
IPEXModelForSequenceClassification,
IPEXModelForTokenClassification,
)
from optimum.intel.utils.import_utils import is_ipex_version
from optimum.utils.testing_utils import grid_parameters
from utils_tests import MODEL_NAMES, IS_XPU

@@ -307,20 +306,19 @@ def test_assisted_decoding(self, model_arch):
@parameterized.expand(
grid_parameters(
{
"model_arch": IPEX_PATCHED_SUPPORTED_ARCHITECTURES,
"model_arch": SUPPORTED_ARCHITECTURES,
"use_cache": [True, False],
}
)
)
@unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version > 2.3.0 supports ipex model patching")
def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache):
def test_ipex_beam_search(self, test_name, model_arch, use_cache):
model_id = MODEL_NAMES[model_arch]
set_seed(SEED)
dtype = torch.float32
if IS_XPU:
dtype = torch.float16
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=use_cache, torch_dtype=dtype)
if use_cache:
if use_cache and model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES:
self.assertTrue(model.add_patch)
device = model.device
transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device)
@@ -346,7 +344,6 @@ def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache):
self.assertIsInstance(outputs, torch.Tensor)
self.assertTrue(torch.equal(outputs, transformers_outputs))

@unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version > 2.3.0 supports ipex model patching")
def test_compare_with_and_without_past_key_values(self):
model_id = "Intel/tiny_random_llama2"
dtype = torch.float32
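
For reference, a rough sketch (not the real grid_parameters implementation) of how a parameter grid like the one in the test above expands into named cases consumed by parameterized.expand; the architecture names are placeholders:

from itertools import product

grid = {"model_arch": ["llama2", "gpt2"], "use_cache": [True, False]}

# Build one named test case per combination of grid values.
cases = [
    ("_".join(str(v) for v in combo), *combo)
    for combo in product(*grid.values())
]
# e.g. ("llama2_True", "llama2", True), ("gpt2_False", "gpt2", False), ...
for test_name, model_arch, use_cache in cases:
    print(test_name, model_arch, use_cache)
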
