@@ -19,7 +19,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from functools import partial
 from typing import Callable, List, Optional, Tuple, Union

 from ...activations import ACT2FN
@@ -28,6 +27,7 @@
 from ...integrations import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -590,7 +590,7 @@ def forward(
         return attn_output, attn_weights


-class AriaTextDecoderLayer(nn.Module):
+class AriaTextDecoderLayer(GradientCheckpointingLayer):
     """
     Aria Text Decoder Layer.

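Note: with `AriaTextDecoderLayer` subclassing `GradientCheckpointingLayer`, the decision to checkpoint is made inside the layer's own `__call__` rather than at the call site in the model's `forward`. The sketch below illustrates the general pattern such a base class can follow; it is a minimal illustration built on `torch.utils.checkpoint`, not the actual implementation in `modeling_layers.py`, and the names `CheckpointingLayer`, `ToyLayer`, and the `gradient_checkpointing` attribute are assumptions made for the example.

```python
# Minimal sketch (assumption) of a base class that checkpoints its own forward pass.
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointingLayer(nn.Module):
    """Routes __call__ through torch.utils.checkpoint when enabled and in training mode."""

    gradient_checkpointing = False  # typically toggled by the owning model

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # Recompute activations in the backward pass instead of storing them.
            return checkpoint(
                lambda *inner: super(CheckpointingLayer, self).__call__(*inner, **kwargs),
                *args,
                use_reentrant=False,
            )
        return super().__call__(*args, **kwargs)


class ToyLayer(CheckpointingLayer):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states):
        return torch.relu(self.proj(hidden_states))


layer = ToyLayer(16).train()
layer.gradient_checkpointing = True
out = layer(torch.randn(2, 16, requires_grad=True))
out.sum().backward()  # the layer's activations are recomputed here
```

Because the layer itself decides when to checkpoint, call sites no longer need the `functools.partial` wrapper or the `if self.gradient_checkpointing and self.training` branch, which is what the next hunk removes.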
@@ -940,30 +940,17 @@ def forward(
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **flash_attn_kwargs),
-                    hidden_states,
-                    causal_mask,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                    position_embeddings,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=causal_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    position_embeddings=position_embeddings,
-                    **flash_attn_kwargs,
-                )
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **flash_attn_kwargs,
+            )

             hidden_states = layer_outputs[0]

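From the caller's side, enabling checkpointing presumably stays the same; the existing `PreTrainedModel.gradient_checkpointing_enable()` entry point would flip the flag that `GradientCheckpointingLayer` checks internally. A hedged usage sketch (the checkpoint id below is a hypothetical placeholder, not a real repository):

```python
from transformers import AutoModelForCausalLM

# "org/aria-style-model" is a placeholder id used only for illustration.
model = AutoModelForCausalLM.from_pretrained("org/aria-style-model")
model.gradient_checkpointing_enable()  # existing public API; layers now checkpoint themselves
model.train()
```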