
Commit 4e6e459

ArthurZucker, vasqu, and molbap authored and committed
Refactor the way we handle outputs for new llamas and new models (huggingface#39120)
* just update 2 files * update other models as well just making fix-copies * also add the changes needed to modeling utils * put this on the pretrained model instead * nits and fixes * update generic, fix to use config value * update other modelings * use transformers kwargs instead * update * update * update other models * update * updates * update * update * update * fix * finally * very small nits * this fixes more tests * fix other models as well! * update modularqwen2 * update models based on qwen2 * update * update * remove the **flash stuff in favor of noraml kwargs * update * propagate gemma? * remove output attentions * propagate * support cross attention edge case * same * test this * fixes * more fix * update * update * fix conflicts * update * fix emu3 * fix emu3 * move the fix a bit * quel enfer * some fixes, loss_kwargs should never had been * finish fixing gemma3n * fix small lm3 * fix another one * fix csm now * fux csm and mistral * fix mistral now * small fixes * fix janusss * only for some models * fixup * phix phi3 * more fixes? * dose this fix it? * update * holy shit it was just graph breaks * protect torch * updates * fix samhq? * fix moonshine * more moonshine fixes, 3 failures left! * nits * generic needs to support more * more fixes to moonshine! * fix cross attention outputs! * fix csm! * nits * fix stupid kosmos2 * current updates * fixes * use output recorder? * nicer! * a little bit of magic * update * fix protect * fix * small fixes * protect import * fix a bunch of more models * fix fixups * fix some of the last ones * nit * partly fix phi * update * fix import path * make something that is fullgraph compatible just to be sure * typing was wrong on llama so the rest was wrong as well * fucking ugly but at least it is still exportable * syle * supposed to fix moonshine, it still breaks * fix some default * fix the last bits of sam * update samhq * more fixes to am hq * nit * fix all output+hidden states and output_attentions! * fix? * fix diffllama * updates to fix initialization on the sam pips * ups there was a bug * fix the last sam hq test * fix gotocr * fix gotocr2! * fixes * skip stupid tests * there was one left :) * fixup * fix fix copies issues with this test file * fix copies for sam_hq * rm some comments * skip 2 more failing tests * fix * fix everything * Apply suggestions from code review Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * add more doc! * fix public init * fix modular qwen3 --------- Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
1 parent ef7f972 commit 4e6e459
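
The heart of this refactor is the new `_can_record_outputs` / `OutputRecorder` mechanism added to `PreTrainedModel` in `src/transformers/modeling_utils.py` (see the diff further down). As a hedged preview that mirrors the docstring added there (the import paths are assumptions inferred from this commit, not verified against this exact revision):

```python
# Preview of the declarations documented in the modeling_utils.py diff below.
# Import paths are assumptions inferred from this commit, not verified here.
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer
from transformers.utils.generic import OutputRecorder

# A bare module class relies on the default tuple-index conventions
# (index 0 -> "hidden_states", index 1 -> "attentions") ...
_can_record_outputs = {
    "attentions": LlamaAttention,
    "hidden_states": LlamaDecoderLayer,
}

# ... which is equivalent to spelling the indices out with OutputRecorder.
_can_record_outputs = {
    "attentions": OutputRecorder(LlamaAttention, index=1),
    "hidden_states": OutputRecorder(LlamaDecoderLayer, index=0),
}
```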

File tree

145 files changed (+1994 / -5892 lines)


examples/modular-transformers/modeling_dummy.py

Lines changed: 0 additions & 446 deletions
This file was deleted.

examples/modular-transformers/modeling_multimodal1.py

Lines changed: 0 additions & 446 deletions
This file was deleted.

examples/modular-transformers/modular_dummy.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

examples/modular-transformers/modular_multimodal1.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

src/transformers/modeling_utils.py

Lines changed: 47 additions & 2 deletions
@@ -123,7 +123,7 @@
     logging,
     strtobool,
 )
-from .utils.generic import GeneralInterface
+from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
 from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
 from .utils.import_utils import (
     ENV_VARS_TRUE_VALUES,
@@ -1925,7 +1925,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
         - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
           models, `pixel_values` for vision models and `input_values` for speech models).
-    """
+        - **can_record_outputs** (dict):"""
 
     config_class = None
     base_model_prefix = ""
@@ -2006,6 +2006,50 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
     # In practice, it means that they support attention interface functions, fully pass the kwargs
     # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
     _supports_attention_backend = False
+    _can_record_outputs = None
+
+    @property
+    @torch._dynamo.allow_in_graph
+    def can_record_outputs(self) -> dict[str, OutputRecorder]:
+        """
+        Maps output names (e.g., "attentions", "hidden_states")
+        to either:
+            - A module class (e.g., `LlamaDecoderLayer`), using default index conventions:
+                * index=0 for "hidden_states"
+                * index=1 for "attentions"
+            - Or an `OutputRecorder(...)` with `target_class`, optional `index`, and `layer_name`.
+
+        Examples:
+            These two are equivalent:
+
+            ```python
+            _can_record_outputs = {
+                "attentions": LlamaAttention,
+                "hidden_states": LlamaDecoderLayer
+            }
+
+            _can_record_outputs = {
+                "attentions": OutputRecorder(LlamaAttention, index=1),
+                "hidden_states": OutputRecorder(LlamaDecoderLayer, index=0)
+            }
+            ```
+
+        This means you can record outputs from the same class by specifying a layer name; before
+        collecting outputs, we check that they come from that layer.
+
+        If your cross attention comes from `LlamaAttention` and your self attention also comes
+        from `LlamaAttention` but lives in `self_attn`, you can do this:
+
+            ```python
+            class LlamaModel(PreTrainedModel):
+                _can_record_outputs = {
+                    "attentions": OutputRecorder(LlamaAttention, index=1, layer_name="self_attn"),
+                    "cross_attentions": OutputRecorder(LlamaAttention, index=1, layer_name="cross_attn")
+                }
+
+            ```
+        """
+        return self._can_record_outputs or {}
 
     @property
     def dummy_inputs(self) -> dict[str, torch.Tensor]:
@@ -2056,6 +2100,7 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
         self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
 
         self._no_split_modules = self._no_split_modules or []
+        _CAN_RECORD_REGISTRY[self] = self._can_record_outputs  # added for executorch support only
 
     def post_init(self):
         """
