Paged attn #1024

Merged: 6 commits, Nov 25, 2024
Changes from all commits
2 changes: 1 addition & 1 deletion .github/workflows/test_ipex.yml
@@ -18,7 +18,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        transformers-version: ["4.45.*"]
+        transformers-version: ["4.46.*"]
         torch-version: ["2.4.0", "2.5.0"]

     runs-on: ubuntu-22.04
14 changes: 7 additions & 7 deletions optimum/exporters/ipex/cache_utils.py
@@ -33,21 +33,21 @@ class IPEXPagedCache(Cache):
     def __init__(
         self,
         config: PretrainedConfig,
-        max_batch_size: int,
+        batch_size: int,
         max_cache_len: int,
         device,
         dtype=None,
         layer_device_map=None,
         **kwargs,
     ) -> None:
         super().__init__()
-        self.max_batch_size = max_batch_size
-        self.batch_size = max_batch_size
+        self.batch_size = batch_size
         # Used in `generate` to keep tally of how many tokens the cache has seen
-        self._seen_tokens = torch.zeros([max_batch_size], dtype=torch.int32, device=device)
+        self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
         self.block_size = 16
-        self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * max_batch_size
+        self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
-            max_batch_size, -1
+            batch_size, -1
         )
         self.free_blocks = torch.arange(self.num_blocks, device=device)
+        self.max_cache_len = max_cache_len
@@ -194,7 +194,7 @@ def get_max_length(self) -> Optional[int]:

     def reset(self):
         """Resets the cache values while preserving the objects"""
-        self._seen_tokens = torch.zeros([self.max_batch_size], dtype=torch.int32, device=self.block_tables.device)
+        self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device)
         self.block_tables.fill_(-1)
         self.free_blocks = torch.arange(self.num_blocks, device=self.block_tables.device)
         self.max_seq_len = 0
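As a note on the sizing logic above: the paged cache allocates a fixed number of 16-token blocks per sequence, so the rename from max_batch_size to batch_size only changes which argument drives the allocation. A minimal sketch of the arithmetic, with example values that are assumptions rather than part of the diff:

# Illustrative arithmetic only, mirroring the constructor in the hunk above;
# the concrete batch_size and max_cache_len values are assumptions.
block_size = 16                       # hard-coded in IPEXPagedCache.__init__
batch_size, max_cache_len = 4, 100    # example values

# Ceil division written exactly as in the diff (bool promotes to int).
blocks_per_sequence = max_cache_len // block_size + (max_cache_len % block_size != 0)
num_blocks = blocks_per_sequence * batch_size

print(blocks_per_sequence, num_blocks)  # 7 28
# block_tables is then a (batch_size, blocks_per_sequence) = (4, 7) tensor filled
# with -1, and free_blocks initially lists all 28 block indices.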
4 changes: 2 additions & 2 deletions optimum/exporters/ipex/model_patcher.py
@@ -39,8 +39,8 @@


 # Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version
-_TRANSFORMERS_MIN_VERSION = "4.45.0"
-_TRANSFORMERS_MAX_VERSION = "4.45.99"
+_TRANSFORMERS_MIN_VERSION = "4.46.0"
+_TRANSFORMERS_MAX_VERSION = "4.46.99"

 _IPEX_EXPORTED_GENERATION_TASKS = ("text-generation",)

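Since the supported transformers window has to stay in sync across setup.py, the CI workflow, and these two constants, a guard along the following lines is the usual consumer of the values. This is a hedged sketch, not the code in model_patcher.py: the helper name _check_transformers_version is hypothetical, and it assumes is_transformers_version from optimum.intel.utils.import_utils, which this PR also relies on elsewhere.

from optimum.intel.utils.import_utils import is_transformers_version

_TRANSFORMERS_MIN_VERSION = "4.46.0"
_TRANSFORMERS_MAX_VERSION = "4.46.99"

def _check_transformers_version():
    # Hypothetical guard: fail fast when the installed transformers falls
    # outside the [MIN, MAX] window declared above.
    if is_transformers_version("<", _TRANSFORMERS_MIN_VERSION) or is_transformers_version(
        ">", _TRANSFORMERS_MAX_VERSION
    ):
        raise ImportError(
            f"IPEX model patching requires transformers>={_TRANSFORMERS_MIN_VERSION} "
            f"and <={_TRANSFORMERS_MAX_VERSION}."
        )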
89 changes: 16 additions & 73 deletions optimum/intel/ipex/modeling_base.py
@@ -413,8 +413,6 @@ def forward(
 class IPEXModelForCausalLM(IPEXModel, GenerationMixin):
     auto_model_class = AutoModelForCausalLM
     export_feature = "text-generation"
-    _supports_cache_class = False
-    _is_stateful = False

     def __init__(
         self,
@@ -430,6 +428,13 @@ def __init__(
         super().__init__(
             model, config, export=export, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache
         )
+
+        self._supports_cache_class = getattr(model, "_supports_cache_class", None)
+        self._supports_sdpa = getattr(model, "_supports_sdpa", None)
+        self._supports_cache_class = getattr(model, "_supports_cache_class", None)
+        self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None)
+        self._supports_static_cache = getattr(model, "_supports_static_cache", None)
+
         if self._add_patch:
             self._supports_cache_class = True
         GenerationMixin.__init__(self)
@@ -448,18 +453,6 @@ def __init__(
         except AttributeError:
             self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)

-        if is_transformers_version(">=", "4.38.0") and model_type in {
-            "llama",
-            "phi",
-            "persimmon",
-            "mistral",
-            "falcon",
-            "gpt2",
-        }:
-            self.prepare_inputs_for_generation = _ipex_prepare_inputs_for_generation
-        else:
-            self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self)
-
         if hasattr(self.model_cls, "_convert_to_standard_cache"):
             self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
         if hasattr(self.model_cls, "_convert_to_bloom_cache"):
@@ -521,6 +514,12 @@ def _prepare_generation_config(

         return generation_config, model_kwargs

+    def _reorder_cache(self, *args, **kwargs):
+        return self.model._reorder_cache(*args, **kwargs)
+
+    def prepare_inputs_for_generation(self, *args, **kwargs):
+        return self.model.prepare_inputs_for_generation(*args, **kwargs)
+
     def generate(self, *args, **kwargs):
         new_kwargs = copy.deepcopy(kwargs)
         if is_ipex_version("<", "2.4.0") and self._add_patch and new_kwargs.get("assistant_model", None):
@@ -556,68 +555,12 @@ def generate(self, *args, **kwargs):
         return result


-def _ipex_prepare_inputs_for_generation(
-    input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
-):
-    from transformers.cache_utils import Cache
-
-    if past_key_values is not None:
-        if isinstance(past_key_values, Cache):
-            past_length = cache_length = past_key_values.get_seq_length()
-            max_cache_length = past_key_values.get_max_length()
-        else:
-            cache_length = past_length = past_key_values[0][0].shape[2]
-            max_cache_length = None
-
-        # Keep only the unprocessed tokens:
-        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-        # input)
-        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-        # input_ids based on the past_length.
-        elif past_length < input_ids.shape[1]:
-            input_ids = input_ids[:, past_length:]
-        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
-        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-        if (
-            max_cache_length is not None
-            and attention_mask is not None
-            and cache_length + input_ids.shape[1] > max_cache_length
-        ):
-            attention_mask = attention_mask[:, -max_cache_length:]
-
-    position_ids = kwargs.get("position_ids", None)
-    if attention_mask is not None and position_ids is None:
-        # create position_ids on the fly for batch generation
-        position_ids = attention_mask.long().cumsum(-1) - 1
-        position_ids.masked_fill_(attention_mask == 0, 1)
-        if past_key_values:
-            position_ids = position_ids[:, -input_ids.shape[1] :]
-
-    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-    if inputs_embeds is not None and past_key_values is None:
-        model_inputs = {"inputs_embeds": inputs_embeds}
-    else:
-        model_inputs = {"input_ids": input_ids}
-
-    model_inputs.update(
-        {
-            "position_ids": position_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-            "attention_mask": attention_mask,
-        }
-    )
-    return model_inputs
-
-
 def _ipex_crop_past_key_values(model, past_key_values, max_length):
     if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"):
         if isinstance(past_key_values, IPEXPagedCache):
-            return past_key_values.crop(max_length)
+            # .crop is an inplace op, returns None
+            past_key_values = past_key_values.crop(max_length)
+            return past_key_values
         else:
             raise ValueError("only support IPEXPagedCache input now")
     return _crop_past_key_values(model, past_key_values, max_length)
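The _ipex_crop_past_key_values change is subtle: IPEXPagedCache.crop() works in place and returns None, so the old `return past_key_values.crop(max_length)` handed None back to callers such as assisted decoding. A toy sketch of the corrected pattern; the class below is purely illustrative, not the real cache:

class _ToyPagedCache:
    """Stand-in for IPEXPagedCache, only to illustrate the in-place crop contract."""

    def __init__(self, tokens):
        self.tokens = tokens

    def crop(self, max_length):
        # Mutates the cache and returns None, like IPEXPagedCache.crop.
        self.tokens = self.tokens[:max_length]

def crop_past_key_values(cache, max_length):
    # Pattern from the diff: call crop for its side effect, then return the
    # (mutated) cache object itself rather than the call's result.
    cache.crop(max_length)
    return cache

cache = _ToyPagedCache(list(range(10)))
assert crop_past_key_values(cache, 4) is cache
assert cache.tokens == [0, 1, 2, 3]
# The previous `return cache.crop(4)` would have evaluated to None.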
2 changes: 1 addition & 1 deletion setup.py
@@ -64,7 +64,7 @@
     "nncf": ["nncf>=2.11.0"],
     "openvino": ["nncf>=2.11.0", "openvino==2024.5.0", "openvino-tokenizers==2024.5.0"],
     "neural-compressor": ["neural-compressor[pt]>3.0", "accelerate", "transformers<4.46"],
-    "ipex": ["intel-extension-for-pytorch", "transformers>=4.39,<4.46"],
+    "ipex": ["intel-extension-for-pytorch", "transformers>=4.39,<4.47"],
     "diffusers": ["diffusers"],
     "quality": QUALITY_REQUIRE,
     "tests": TESTS_REQUIRE,
9 changes: 3 additions & 6 deletions tests/ipex/test_modeling.py
@@ -44,7 +44,6 @@
     IPEXModelForSequenceClassification,
     IPEXModelForTokenClassification,
 )
-from optimum.intel.utils.import_utils import is_ipex_version
 from optimum.utils.testing_utils import grid_parameters
 from utils_tests import MODEL_NAMES, IS_XPU

@@ -307,20 +306,19 @@ def test_assisted_decoding(self, model_arch):
     @parameterized.expand(
         grid_parameters(
             {
-                "model_arch": IPEX_PATCHED_SUPPORTED_ARCHITECTURES,
+                "model_arch": SUPPORTED_ARCHITECTURES,
                 "use_cache": [True, False],
             }
         )
     )
-    @unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version > 2.3.0 supports ipex model patching")
-    def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache):
+    def test_ipex_beam_search(self, test_name, model_arch, use_cache):
         model_id = MODEL_NAMES[model_arch]
         set_seed(SEED)
         dtype = torch.float32
         if IS_XPU:
             dtype = torch.float16
         model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=use_cache, torch_dtype=dtype)
-        if use_cache:
+        if use_cache and model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES:
             self.assertTrue(model.add_patch)
         device = model.device
         transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device)
@@ -346,7 +344,6 @@ def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache):
         self.assertIsInstance(outputs, torch.Tensor)
         self.assertTrue(torch.equal(outputs, transformers_outputs))

-    @unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version > 2.3.0 supports ipex model patching")
     def test_compare_with_and_without_past_key_values(self):
         model_id = "Intel/tiny_random_llama2"
         dtype = torch.float32
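Outside the test suite, the behaviour exercised by the renamed test_ipex_beam_search boils down to checking that beam search through IPEXModelForCausalLM matches plain transformers. A usage sketch under that assumption; the checkpoint and generation settings are examples, not values fixed by this PR:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "Intel/tiny_random_llama2"  # example checkpoint, also used in the tests
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("This is a sample input", return_tensors="pt")

ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, torch_dtype=torch.float32)
reference = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)

generation_kwargs = {"max_new_tokens": 16, "num_beams": 4, "do_sample": False}
ipex_output = ipex_model.generate(**inputs, **generation_kwargs)
reference_output = reference.generate(**inputs, **generation_kwargs)

# Expected to print True when the paged-cache path reproduces eager transformers.
print(torch.equal(ipex_output, reference_output))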