
Commit 9167fad

Introduce GradientCheckpointingLayer (#37223)
* GradientCheckpointingLayer
* trigger
* Move GC layer to a separate file
* Update import
* Expose and document GC layer
* Fix dummy
* Apply to llama-based models
* Update modulars
* Update a few more models for consistency
* Update glm4
* Update Janus
1 parent 413f9bb commit 9167fad

35 files changed: 435 additions & 761 deletions

docs/source/en/internal/modeling_utils.md

Lines changed: 4 additions & 0 deletions
@@ -20,6 +20,10 @@ This page lists all the custom layers used by the library, as well as the utilit
 
 Most of those are only useful if you are studying the code of the models in the library.
 
+## Layers
+
+[[autodoc]] GradientCheckpointingLayer
+
 ## Attention Functions
 
 [[autodoc]] AttentionInterface

src/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -438,6 +438,7 @@
 ]
 
 _import_structure["modeling_flash_attention_utils"] = []
+_import_structure["modeling_layers"] = ["GradientCheckpointingLayer"]
 _import_structure["modeling_outputs"] = []
 _import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update"]
 _import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"]
@@ -911,6 +912,7 @@
 from .model_debugging_utils import (
     model_addition_debugger_context,
 )
+from .modeling_layers import GradientCheckpointingLayer
 from .modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from .modeling_utils import AttentionInterface, PreTrainedModel

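With the entry added to `_import_structure` and the matching `TYPE_CHECKING` import, the new class is exposed at the top level of the package. A quick sanity check, as a sketch assuming a `transformers` build that includes this commit:

```python
# Both import paths should resolve to the same class once this commit is installed.
from transformers import GradientCheckpointingLayer
from transformers.modeling_layers import GradientCheckpointingLayer as _GCL

assert GradientCheckpointingLayer is _GCL
print(GradientCheckpointingLayer.gradient_checkpointing)  # False by default
```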
src/transformers/modeling_layers.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+
+import torch.nn as nn
+
+
+class GradientCheckpointingLayer(nn.Module):
+    """Base class for layers with gradient checkpointing.
+
+    This class enables gradient checkpointing functionality for a layer. By default, gradient checkpointing is disabled
+    (`gradient_checkpointing = False`). When `model.set_gradient_checkpointing()` is called, gradient checkpointing is
+    enabled by setting `gradient_checkpointing = True` and assigning a checkpointing function to `_gradient_checkpointing_func`.
+
+    Important:
+
+        When using gradient checkpointing with `use_reentrant=True`, inputs that require gradients (e.g. hidden states)
+        must be passed as positional arguments (`*args`) rather than keyword arguments to properly propagate gradients.
+
+        Example:
+
+        ```python
+        >>> # Correct - hidden_states passed as positional arg
+        >>> out = self.layer(hidden_states, attention_mask=attention_mask)
+
+        >>> # Incorrect - hidden_states passed as keyword arg
+        >>> out = self.layer(hidden_states=hidden_states, attention_mask=attention_mask)
+        ```
+    """
+
+    gradient_checkpointing = False
+
+    def __call__(self, *args, **kwargs):
+        if self.gradient_checkpointing and self.training:
+            return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
+        return super().__call__(*args, **kwargs)
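The base class only overrides `__call__`: when `gradient_checkpointing` is enabled during training, it wraps `nn.Module.__call__` in the assigned checkpointing function, binding keyword arguments with `partial` and forwarding positional arguments through the checkpoint. A minimal usage sketch, not part of the commit: `DemoLayer` is hypothetical, and `_gradient_checkpointing_func` is wired by hand here, roughly the way `PreTrainedModel`'s gradient-checkpointing setup would do it.

```python
from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from transformers import GradientCheckpointingLayer  # requires a build with this commit


class DemoLayer(GradientCheckpointingLayer):  # hypothetical layer, for illustration only
    def __init__(self, hidden_size):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None):
        out = self.proj(hidden_states)
        if attention_mask is not None:
            out = out * attention_mask
        return out


layer = DemoLayer(16).train()
# Enabling checkpointing amounts to flipping the flag and assigning the checkpoint function.
layer.gradient_checkpointing = True
layer._gradient_checkpointing_func = partial(checkpoint, use_reentrant=False)

x = torch.randn(2, 4, 16, requires_grad=True)
mask = torch.ones(2, 4, 1)
# hidden_states is passed positionally; attention_mask can stay a keyword argument.
out = layer(x, attention_mask=mask)
out.sum().backward()
print(x.grad.shape)  # torch.Size([2, 4, 16])
```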

src/transformers/models/aria/modeling_aria.py

Lines changed: 13 additions & 26 deletions
@@ -19,7 +19,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 
 from ...activations import ACT2FN
@@ -28,6 +27,7 @@
 from ...integrations import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -590,7 +590,7 @@ def forward(
         return attn_output, attn_weights
 
 
-class AriaTextDecoderLayer(nn.Module):
+class AriaTextDecoderLayer(GradientCheckpointingLayer):
     """
     Aria Text Decoder Layer.
@@ -940,30 +940,17 @@ def forward(
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **flash_attn_kwargs),
-                    hidden_states,
-                    causal_mask,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                    position_embeddings,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=causal_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    position_embeddings=position_embeddings,
-                    **flash_attn_kwargs,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **flash_attn_kwargs,
+            )
 
             hidden_states = layer_outputs[0]
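With `AriaTextDecoderLayer` inheriting from `GradientCheckpointingLayer`, the per-model `if self.gradient_checkpointing and self.training:` branch becomes redundant: the base class's `__call__` performs the same rerouting. The old code folded keyword arguments into `partial(...)` and passed `hidden_states` positionally because reentrant checkpointing only accepts positional arguments and only re-attaches gradients to positional tensors. A small illustrative sketch, not library code; `block` is a stand-in for a decoder layer:

```python
from functools import partial

import torch
from torch.utils.checkpoint import checkpoint


def block(hidden_states, scale=1.0):
    # Stand-in for a decoder layer's forward pass.
    return torch.tanh(hidden_states) * scale


hidden_states = torch.randn(2, 3, requires_grad=True)

# The pattern used by GradientCheckpointingLayer.__call__: bind keyword arguments with
# partial(), then pass the gradient-carrying tensor positionally.
out = checkpoint(partial(block, scale=2.0), hidden_states, use_reentrant=True)
out.sum().backward()
print(hidden_states.grad is not None)  # True

# By contrast, `checkpoint(block, hidden_states=hidden_states, use_reentrant=True)` is
# rejected by recent PyTorch versions, which is why the docstring above warns against
# passing hidden_states as a keyword argument.
```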

src/transformers/models/cohere/modeling_cohere.py

Lines changed: 13 additions & 26 deletions
@@ -27,7 +27,6 @@
 # This file is based on the LLama model definition file in transformers
 
 
-from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
@@ -38,6 +37,7 @@
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -301,7 +301,7 @@ def forward(
         return attn_output, attn_weights
 
 
-class CohereDecoderLayer(nn.Module):
+class CohereDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: CohereConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -589,30 +589,17 @@ def forward(
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **flash_attn_kwargs),
-                    hidden_states,
-                    causal_mask,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                    position_embeddings,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=causal_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    position_embeddings=position_embeddings,
-                    **flash_attn_kwargs,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **flash_attn_kwargs,
+            )
 
             hidden_states = layer_outputs[0]

src/transformers/models/cohere/modular_cohere.py

Lines changed: 2 additions & 1 deletion
@@ -30,6 +30,7 @@
 
 from ...cache_utils import Cache
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
@@ -209,7 +210,7 @@ def forward(
         return attn_output, attn_weights
 
 
-class CohereDecoderLayer(nn.Module):
+class CohereDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: CohereConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size

src/transformers/models/cohere2/modeling_cohere2.py

Lines changed: 12 additions & 24 deletions
@@ -19,7 +19,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
@@ -29,6 +28,7 @@
 from ...cache_utils import Cache, HybridCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -290,7 +290,7 @@ def forward(self, x):
         return down_proj
 
 
-class Cohere2DecoderLayer(nn.Module):
+class Cohere2DecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: Cohere2Config, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -612,28 +612,16 @@ def forward(
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **flash_attn_kwargs),
-                    hidden_states,
-                    position_embeddings,
-                    causal_mask,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    position_embeddings=position_embeddings,
-                    attention_mask=causal_mask,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    **flash_attn_kwargs,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                position_embeddings=position_embeddings,
+                attention_mask=causal_mask,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                **flash_attn_kwargs,
+            )
 
             hidden_states = layer_outputs[0]

src/transformers/models/cohere2/modular_cohere2.py

Lines changed: 10 additions & 23 deletions
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from functools import partial
 from typing import Callable, Optional, Tuple
 
 import torch
@@ -526,28 +525,16 @@ def forward(
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **flash_attn_kwargs),
-                    hidden_states,
-                    position_embeddings,
-                    causal_mask,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    position_embeddings=position_embeddings,
-                    attention_mask=causal_mask,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    **flash_attn_kwargs,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                position_embeddings=position_embeddings,
+                attention_mask=causal_mask,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                **flash_attn_kwargs,
+            )
 
             hidden_states = layer_outputs[0]
