diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index abcfac64d038..1c7f62ec6ea7 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -296,6 +296,8 @@
       title: Trainer
     - local: main_classes/deepspeed
       title: DeepSpeed
+    - local: main_classes/executorch
+      title: ExecuTorch
     - local: main_classes/feature_extractor
       title: Feature Extractor
     - local: main_classes/image_processor
diff --git a/docs/source/en/main_classes/executorch.md b/docs/source/en/main_classes/executorch.md
new file mode 100644
index 000000000000..28e0a445e79f
--- /dev/null
+++ b/docs/source/en/main_classes/executorch.md
@@ -0,0 +1,33 @@
+
+
+# ExecuTorch
+
+[`ExecuTorch`](https://github.com/pytorch/executorch) is an end-to-end solution for enabling on-device inference capabilities across mobile and edge devices, including wearables, embedded devices, and microcontrollers. It is part of the PyTorch ecosystem and supports the deployment of PyTorch models with a focus on portability, productivity, and performance.
+
+ExecuTorch introduces well-defined entry points to perform model, device, and/or use-case-specific optimizations such as backend delegation, user-defined compiler transformations, memory planning, and more. The first step in preparing a PyTorch model for execution on an edge device with ExecuTorch is to export the model, which is done with the PyTorch API [`torch.export`](https://pytorch.org/docs/stable/export.html).
+
+
+## ExecuTorch Integration
+
+An integration point is being developed to ensure that 🤗 Transformers models can be exported using `torch.export`. The goal of this integration is not only to enable export but also to ensure that the exported artifact can be further lowered and optimized to run efficiently in `ExecuTorch`, particularly for mobile and edge use cases.
+
+[[autodoc]] integrations.executorch.TorchExportableModuleWithStaticCache
+    - forward
+
+[[autodoc]] integrations.executorch.convert_and_export_with_cache
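Reviewer note, not part of the patch: a minimal usage sketch of the new export path documented above, mirroring the configuration exercised in `test_static_cache_exportability` further down in this diff. The checkpoint name, dtype, and cache sizes are illustrative choices, not requirements of the API.

```python
import torch

from transformers import AutoModelForCausalLM, GenerationConfig, convert_and_export_with_cache

# Load the model with an explicit generation config that enables the static cache;
# "sdpa" attention is the implementation targeted by the export path.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    torch_dtype=torch.float32,
    attn_implementation="sdpa",
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        cache_config={"batch_size": 1, "max_cache_len": 128},
    ),
)

# Produce a torch.export.ExportedProgram that can be further lowered to ExecuTorch.
exported_program = convert_and_export_with_cache(model)
```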
+ """ + + cache_implementation = "static" + + def __init__(self, batch_size: int, max_cache_len: int, device="cpu"): + self.batch_size = batch_size + self.max_cache_len = max_cache_len + self.device = device + + def validate(self): + """Validates if the arguments passed are correct""" + + incorrect_arg_msg = ( + "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` " + "but found {found_value}" + ) + + if self.batch_size <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="batch_size", + correct_value="> 0", + found_value=self.batch_size, + ), + ) + + if self.max_cache_len <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="max_cache_len", + correct_value="> 0", + found_value=self.max_cache_len, + ), + ) + + class DynamicCache(Cache): """ A cache that grows dynamically as more tokens are generated. This is the default for generative models. diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 773ef0ccfe55..d7571cbb5dc0 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -57,9 +57,11 @@ QuantoQuantizedCache, SlidingWindowCache, StaticCache, + StaticCacheConfig, ) NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig + NEEDS_CACHE_CONFIG["static"] = StaticCacheConfig NEED_SETUP_CACHE_CLASSES_MAPPING = { "static": StaticCache, "offloaded_static": OffloadedStaticCache, diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 4c756a23ae0a..0a28ff022a53 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import TYPE_CHECKING -from ..utils import _LazyModule +from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available _import_structure = { @@ -98,6 +98,17 @@ "quanto": ["replace_with_quanto_layers"], } +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["executorch"] = [ + "TorchExportableModuleWithStaticCache", + "convert_and_export_with_cache", + ] + if TYPE_CHECKING: from .aqlm import replace_with_aqlm_linear from .awq import ( @@ -178,6 +189,15 @@ ) from .peft import PeftAdapterMixin from .quanto import replace_with_quanto_layers + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .executorch import TorchExportableModuleWithStaticCache, convert_and_export_with_cache + else: import sys diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py new file mode 100644 index 000000000000..afcba5ebd069 --- /dev/null +++ b/src/transformers/integrations/executorch.py @@ -0,0 +1,159 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 773ef0ccfe55..d7571cbb5dc0 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -57,9 +57,11 @@ QuantoQuantizedCache,
         SlidingWindowCache,
         StaticCache,
+        StaticCacheConfig,
     )
 
     NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig
+    NEEDS_CACHE_CONFIG["static"] = StaticCacheConfig
 
 NEED_SETUP_CACHE_CLASSES_MAPPING = {
     "static": StaticCache,
     "offloaded_static": OffloadedStaticCache,
diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py
index 4c756a23ae0a..0a28ff022a53 100755
--- a/src/transformers/integrations/__init__.py
+++ b/src/transformers/integrations/__init__.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from typing import TYPE_CHECKING
 
-from ..utils import _LazyModule
+from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
 
 
 _import_structure = {
@@ -98,6 +98,17 @@
     "quanto": ["replace_with_quanto_layers"],
 }
 
+try:
+    if not is_torch_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["executorch"] = [
+        "TorchExportableModuleWithStaticCache",
+        "convert_and_export_with_cache",
+    ]
+
 if TYPE_CHECKING:
     from .aqlm import replace_with_aqlm_linear
     from .awq import (
@@ -178,6 +189,15 @@
     )
     from .peft import PeftAdapterMixin
     from .quanto import replace_with_quanto_layers
+
+    try:
+        if not is_torch_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .executorch import TorchExportableModuleWithStaticCache, convert_and_export_with_cache
+
 else:
     import sys
diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
new file mode 100644
index 000000000000..afcba5ebd069
--- /dev/null
+++ b/src/transformers/integrations/executorch.py
@@ -0,0 +1,159 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+import torch
+
+from transformers import (
+    PreTrainedModel,
+    StaticCache,
+)
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
+
+
+class TorchExportableModuleWithStaticCache(torch.nn.Module):
+    """
+    A wrapper module designed to make a `PreTrainedModel` exportable with `torch.export`,
+    specifically for use with static caching. This module ensures that the exported model
+    is compatible with further lowering and execution in `ExecuTorch`.
+
+    Note:
+        This class is specifically designed to support the export process using `torch.export`
+        in a way that ensures the model can be further lowered and run efficiently in `ExecuTorch`.
+    """
+
+    def __init__(self, model: PreTrainedModel):
+        """
+        Initializes the wrapper module with the pretrained model.
+
+        Args:
+            model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching
+                enabled and use a 'static' caching implementation.
+
+        Raises:
+            AssertionError: If the pretrained model does not have caching enabled or if it does
+                not use a 'static' caching implementation in `model.generation_config`.
+        """
+        super().__init__()
+
+        # Sanity checks
+        if model.generation_config is None:
+            raise AssertionError(
+                "The model must have a generation config to be exported with static caching. "
+                "Please set `generation_config`."
+            )
+
+        if not model.generation_config.use_cache:
+            raise AssertionError(
+                "The model must have caching enabled to be exported with static caching. "
+                "Please set `generation_config.use_cache=True`."
+            )
+
+        if model.generation_config.cache_implementation != "static":
+            raise AssertionError(
+                "The model must use a 'static' caching implementation to be exported with static caching. "
+                "Please set `generation_config.cache_implementation='static'`."
+            )
+
+        self.model = model
+        self.static_cache = StaticCache(
+            config=self.model.config,
+            batch_size=self.model.generation_config.cache_config.batch_size,
+            max_cache_len=self.model.generation_config.cache_config.max_cache_len,
+            dtype=self.model.config.torch_dtype,
+        )
+        self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
+        if self.is_causal:
+            causal_mask = torch.tril(
+                torch.ones(
+                    self.static_cache.max_cache_len,
+                    self.static_cache.max_cache_len,
+                    dtype=torch.bool,
+                )
+            )
+            self.register_buffer("mask", causal_mask, persistent=False)
+
+    def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
+        """
+        Forward pass of the module, which is compatible with the ExecuTorch runtime.
+
+        Args:
+            input_ids (`torch.Tensor`): Tensor representing the current input token ids to the module.
+            cache_position (`torch.Tensor`): Tensor representing the current input position in the cache.
+
+        Returns:
+            torch.Tensor: Logits output from the model.
+
+        This forward adapter serves two primary purposes:
+
+        1. **Making the model `torch.export`-compatible**:
+            The adapter hides unsupported objects, such as the `Cache`, from the graph inputs and outputs,
+            enabling the model to be exported with `torch.export` without encountering issues.
+
+        2. **Ensuring compatibility with the `ExecuTorch` runtime**:
+            The adapter matches the model's forward signature with that in `executorch/extension/llm/runner`,
+            ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
+        """
+        _, seqlen = input_ids.shape
+        attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
+        outs = self.model(
+            input_ids=input_ids,
+            attention_mask=attn_mask,
+            position_ids=cache_position.unsqueeze(0),
+            cache_position=cache_position,
+            past_key_values=self.static_cache,
+            use_cache=True,
+        )
+        return outs.logits
+
+
+def convert_and_export_with_cache(
+    model: PreTrainedModel,
+    example_input_ids: torch.Tensor = None,
+    example_cache_position: torch.Tensor = None,
+):
+    """
+    Convert a `PreTrainedModel` into an exportable module and export it using `torch.export`,
+    ensuring the exported model is compatible with `ExecuTorch`.
+
+    Args:
+        model (`PreTrainedModel`): The pretrained model to be exported.
+        example_input_ids (`torch.Tensor`): Example input token ids used by `torch.export`.
+        example_cache_position (`torch.Tensor`): Example current cache position used by `torch.export`.
+
+    Returns:
+        Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
+    """
+
+    if not is_torch_greater_or_equal_than_2_3:
+        raise ImportError("torch >= 2.3 is required.")
+
+    import torch.export._trace
+
+    with torch.no_grad():
+        # TODO: The default inputs only work for text models. We need to add support for vision/audio models.
+        example_input_ids = (
+            example_input_ids if example_input_ids is not None else torch.tensor([[1]], dtype=torch.long)
+        )
+        example_cache_position = (
+            example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
+        )
+
+        # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to use an internal export API
+        # with pre_dispatch=False. Switch to the public API once the fix is included in the 2.5 release.
+        exported_program = torch.export._trace._export(
+            TorchExportableModuleWithStaticCache(model),
+            args=(example_input_ids,),
+            kwargs={"cache_position": example_cache_position},
+            pre_dispatch=False,
+            strict=True,
+        )
+        return exported_program
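Reviewer note, not part of the patch: a sketch of how the resulting `ExportedProgram` can be sanity-checked in eager PyTorch before lowering it to ExecuTorch. `ExportedProgram.module()` returns a callable graph module with the same `(input_ids, cache_position)` signature as the wrapper above; the tokenizer, prompt, and greedy decode loop here are illustrative assumptions.

```python
import torch

from transformers import AutoTokenizer

# Assumes `exported_program` was produced as in the earlier sketch (max_cache_len >= 30).
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
tokens = tokenizer("The best color is", return_tensors="pt").input_ids[0].tolist()
prompt_len = len(tokens)

graph_module = exported_program.module()  # eager-executable form of the exported graph

# The program was exported with a (1, 1) example input, so feed one token per step:
# first the prompt (prefill), then greedily decoded tokens, tracking the cache position.
for pos in range(30):
    logits = graph_module(
        torch.tensor([[tokens[pos]]], dtype=torch.long),
        cache_position=torch.tensor([pos], dtype=torch.long),
    )
    if pos >= prompt_len - 1:  # start appending predictions once the whole prompt is consumed
        tokens.append(int(torch.argmax(logits[0, -1])))

print(tokenizer.decode(tokens))
```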
+ """ + _, seqlen = input_ids.shape + attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None + outs = self.model( + input_ids=input_ids, + attention_mask=attn_mask, + position_ids=cache_position.unsqueeze(0), + cache_position=cache_position, + past_key_values=self.static_cache, + use_cache=True, + ) + return outs.logits + + +def convert_and_export_with_cache( + model: PreTrainedModel, + example_input_ids: torch.Tensor = None, + example_cache_position: torch.Tensor = None, +): + """ + Convert a `PreTrainedModel` into an exportable module and export it using `torch.export`, + ensuring the exported model is compatible with `ExecuTorch`. + + Args: + model (`PreTrainedModel`): The pretrained model to be exported. + example_input_ids (`torch.Tensor`): Example input token id used by `torch.export`. + example_cache_position (`torch.Tensor`): Example current cache position used by `torch.export`. + + Returns: + Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`. + """ + + if not is_torch_greater_or_equal_than_2_3: + raise ImportError("torch >= 2.3 is required.") + + import torch.export._trace + + with torch.no_grad(): + # TODO: The default inputs only work for text models. We need to add support for vision/audio models. + example_input_ids = ( + example_input_ids if example_input_ids is not None else torch.tensor([[1]], dtype=torch.long) + ) + example_cache_position = ( + example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long) + ) + + # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal + # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release. 
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index ddf7608155e8..276b06760ae3 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -513,6 +513,17 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
+class TorchExportableModuleWithStaticCache(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+def convert_and_export_with_cache(*args, **kwargs):
+    requires_backends(convert_and_export_with_cache, ["torch"])
+
+
 ROPE_INIT_FUNCTIONS = None
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py
index fb5459be10cb..6ab821231fd5 100644
--- a/tests/utils/test_cache_utils.py
+++ b/tests/utils/test_cache_utils.py
@@ -16,7 +16,6 @@
 import copy
 import unittest
 
-from packaging import version
 from parameterized import parameterized
 
 from transformers import set_seed
@@ -35,7 +34,6 @@
     import torch
 
     from transformers import (
-        AutoConfig,
         AutoModelForCausalLM,
         AutoTokenizer,
         DynamicCache,
@@ -44,7 +42,9 @@
         LlamaConfig,
         SinkCache,
         StaticCache,
+        convert_and_export_with_cache,
     )
+    from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
 
 
 @require_torch
@@ -175,61 +175,54 @@ def test_static_cache_exportability(self):
         """
         Tests that static cache works with `torch.export()`
         """
-        import torch
-
-        if version.parse(torch.__version__) < version.parse("2.3"):
+        if not is_torch_greater_or_equal_than_2_3:
             self.skipTest(reason="This test requires torch >= 2.3 to run.")
 
+        set_seed(0)
         device = "cpu"
         dtype = torch.float32
+        cache_implementation = "static"
+        attn_implementation = "sdpa"  # Export and ExecuTorch only work for SdpaAttention
         batch_size = 1
-
-        config = AutoConfig.from_pretrained(
+        max_cache_len = 1234
+        model = AutoModelForCausalLM.from_pretrained(
             "google/gemma-2b",
+            device_map=device,
             torch_dtype=dtype,
-            use_cache=True,
+            attn_implementation=attn_implementation,
+            generation_config=GenerationConfig(
+                use_cache=True,
+                cache_implementation=cache_implementation,
+                max_length=max_cache_len,
+                cache_config={
+                    "batch_size": batch_size,
+                    "max_cache_len": max_cache_len,
+                },
+            ),
         )
-        m = AutoModelForCausalLM.from_pretrained(
-            "google/gemma-2b",
-            config=config,
-            torch_dtype=dtype,
-            attn_implementation="sdpa",  # Export and ExecuTorch only works for SdpaAttention
-        ).to(device)
-        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
-        inputs = tokenizer(["The best color is"], return_tensors="pt").to(device)["input_ids"]
-
-        class ExportatibleModelWithStaticCache(torch.nn.Module):
-            def __init__(self, config, model):
-                super().__init__()
-                self.config = config
-                self.model = model
-                self.static_cache = StaticCache(
-                    config=config, batch_size=batch_size, max_cache_len=config.max_length, device=device
-                )
-
-            def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
-                outs = self.model(
-                    input_ids=tokens,
-                    attention_mask=None,
-                    position_ids=input_pos.unsqueeze(0),
-                    cache_position=input_pos,
-                    past_key_values=self.static_cache,
-                    use_cache=True,
-                )
-                return outs.logits
-
-        set_seed(0)
-        with torch.no_grad():
-            import torch.export._trace
-            from torch.export import ExportedProgram
-
-            model = ExportatibleModelWithStaticCache(config, m)
-            # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
-            # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.4.1+ release.
-            exported_program = torch.export._trace._export(
-                model, args=(inputs,), kwargs={"input_pos": torch.arange(1)}, pre_dispatch=False, strict=True
-            )
-            self.assertTrue(isinstance(exported_program, ExportedProgram))
+        # Check if cache config is passed through correctly
+        self.assertEqual(model.generation_config.use_cache, True)
+        self.assertEqual(model.generation_config.cache_implementation, cache_implementation)
+        self.assertEqual(model.generation_config.max_length, max_cache_len)
+        self.assertTrue(model.generation_config.cache_config is not None)
+        self.assertEqual(model.generation_config.cache_config.batch_size, batch_size)
+        self.assertEqual(model.generation_config.cache_config.max_cache_len, max_cache_len)
+
+        exported_program = convert_and_export_with_cache(model)
+
+        # Check if the exported model is configured with the `StaticCache` correctly
+        n_static_key_caches = n_static_value_caches = 0
+        for buffer_name, buffer in exported_program.named_buffers():
+            if buffer_name.startswith("static_cache.key_cache"):
+                self.assertTrue(buffer.shape[0] == batch_size)
+                self.assertTrue(buffer.shape[2] == max_cache_len)
+                n_static_key_caches = n_static_key_caches + 1
+            if buffer_name.startswith("static_cache.value_cache"):
+                self.assertTrue(buffer.shape[0] == batch_size)
+                self.assertTrue(buffer.shape[2] == max_cache_len)
+                n_static_value_caches = n_static_value_caches + 1
+        self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
+        self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
 
 
 @require_torch_gpu