From dc53695963c26b0601f726829dc06afa7bfe19f7 Mon Sep 17 00:00:00 2001 From: calpt Date: Fri, 12 Jul 2024 14:57:49 +0200 Subject: [PATCH] Upgrade Transformers to v4.41.x (#712) Changes needed for sync: - BERT/ ViT: Copy & adapt new sdpa attention classes - Update copied `_prepare_encoder_decoder_kwargs_for_generation` in model mixin - Adjust 2dim attention masks for prompt tuning --- hf_transformers | 2 +- setup.py | 2 +- src/adapters/methods/prefix_tuning.py | 4 +- src/adapters/model_mixin.py | 21 ++++- src/adapters/models/bert/modeling_bert.py | 107 +++++++++++++++++++++- src/adapters/models/vit/modeling_vit.py | 34 ++++++- src/adapters/utils.py | 29 +++--- 7 files changed, 177 insertions(+), 22 deletions(-) diff --git a/hf_transformers b/hf_transformers index 4fdf58afb7..ab0f050b42 160000 --- a/hf_transformers +++ b/hf_transformers @@ -1 +1 @@ -Subproject commit 4fdf58afb72b0754da30037fc800b6044e7d9c99 +Subproject commit ab0f050b42d903f34d6eb97f3f8c0c07f0517ad2 diff --git a/setup.py b/setup.py index 8993926c6d..2f3157ad2d 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,7 @@ "sphinx-multiversion==0.2.4", "timeout-decorator", "torch>=1.10,!=1.12.0", - "transformers~=4.40.2", + "transformers~=4.41.2", ] diff --git a/src/adapters/methods/prefix_tuning.py b/src/adapters/methods/prefix_tuning.py index 5e98ca266c..5303760bd9 100644 --- a/src/adapters/methods/prefix_tuning.py +++ b/src/adapters/methods/prefix_tuning.py @@ -430,10 +430,8 @@ def pad_and_concat(self, states: List[PrefixTuningState]) -> PrefixTuningState: value_states = F.pad(value_states, pad_size, "constant", self.model_config.pad_token_id) # pad attention mask - if pad_length > 0: + if pad_length > 0 and attention_mask is not None: # Masking the padded tokens only works correctly if attention_mask is set - # We assume this to be the case at this point - assert attention_mask is not None, "Attention mask must be set for prefix tuning" attention_mask = F.pad( attention_mask, (max_prefix_length - attention_mask.shape[-1], 0), diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 8988af61a4..50d846075f 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -9,7 +9,9 @@ import torch from torch import nn +from transformers import GenerationConfig from transformers.modeling_outputs import ModelOutput +from transformers.utils import is_accelerate_available from .composition import AdapterCompositionBlock, Fuse, Stack, parse_composition from .configuration import ADAPTER_CONFIG_MAP, AdapterConfig, AdapterFusionConfig, BnConfig @@ -29,6 +31,9 @@ logger = logging.getLogger(__name__) +if is_accelerate_available(): + from accelerate.hooks import AlignDevicesHook, add_hook_to_module + class InvertibleAdaptersMixin: """Mixin for Transformer models adding invertible adapters.""" @@ -1263,10 +1268,21 @@ def reset_adapter(self): # HACK Copied from transformers/generation/utils.py def _prepare_encoder_decoder_kwargs_for_generation( - self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None + self, + inputs_tensor: torch.Tensor, + model_kwargs, + model_input_name: Optional[str], + generation_config: GenerationConfig, ) -> Dict[str, Any]: # 1. get encoder encoder = self.get_encoder() + # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device + # as the inputs. 
+ if hasattr(self, "hf_device_map"): + if hasattr(encoder, "_hf_hook"): + encoder._hf_hook.io_same_device = True + else: + add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True)) # 2. prepare encoder args and encoder kwargs from model kwargs irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] @@ -1275,7 +1291,6 @@ def _prepare_encoder_decoder_kwargs_for_generation( for argument, value in model_kwargs.items() if not any(argument.startswith(p) for p in irrelevant_prefix) } - encoder_signature = set(inspect.signature(encoder.forward).parameters) encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature if not encoder_accepts_wildcard: @@ -1284,6 +1299,8 @@ def _prepare_encoder_decoder_kwargs_for_generation( for argument, value in encoder_kwargs.items() if argument in encoder_signature or argument == "adapter_input_parallelized" } + encoder_kwargs["output_attentions"] = generation_config.output_attentions + encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states # 3. make sure that encoder returns `ModelOutput` model_input_name = model_input_name if model_input_name is not None else self.main_input_name diff --git a/src/adapters/models/bert/modeling_bert.py b/src/adapters/models/bert/modeling_bert.py index ea60b6f5dc..de860151e4 100644 --- a/src/adapters/models/bert/modeling_bert.py +++ b/src/adapters/models/bert/modeling_bert.py @@ -23,13 +23,17 @@ import torch.utils.checkpoint from torch import nn -from transformers.models.bert.modeling_bert import BertOutput, BertSelfAttention, BertSelfOutput +from transformers.models.bert.modeling_bert import BertOutput, BertSdpaSelfAttention, BertSelfAttention, BertSelfOutput +from transformers.utils import logging from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel from ...utils import prefix_attention_mask from .mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin +logger = logging.get_logger(__name__) + + class BertSelfAttentionWithAdapters(BertSelfAttentionAdaptersMixin, BertSelfAttention): def forward( self, @@ -142,6 +146,107 @@ def forward( return outputs +class BertSdpaSelfAttentionWithAdapters(BertSelfAttentionAdaptersMixin, BertSdpaSelfAttention): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + attention_mask = prefix_attention_mask(attention_mask, [2, 3]) # type: ignore + + if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. + logger.warning_once( + "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support" + " non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to" + " the manual attention implementation, but specifying the manual implementation will be required from" + " Transformers version v5.0.0 onwards. This warning can be removed using the argument" + ' `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + bsz, tgt_len, _ = hidden_states.size() + + # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention + # mask needs to be such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning + if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + key_layer, value_layer = past_key_value + else: + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + if past_key_value is not None and not is_cross_attention: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer) + (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + key_layer, value_layer, attention_mask = self.prefix_tuning( + key_layer, value_layer, hidden_states, attention_mask + ) + (query_layer,) = adjust_tensors_for_parallel(key_layer, query_layer) + bsz = query_layer.size(0) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal + # mask in case tgt_len == 1. 
+ is_causal = self.is_decoder and attention_mask is None and tgt_len > 1 + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + outputs = (attn_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + class BertSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, BertSelfOutput): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) diff --git a/src/adapters/models/vit/modeling_vit.py b/src/adapters/models/vit/modeling_vit.py index f8c02bd931..3a4460a658 100644 --- a/src/adapters/models/vit/modeling_vit.py +++ b/src/adapters/models/vit/modeling_vit.py @@ -23,7 +23,7 @@ from torch import nn from adapters.composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel -from transformers.models.vit.modeling_vit import ViTLayer, ViTOutput, ViTSelfAttention +from transformers.models.vit.modeling_vit import ViTLayer, ViTOutput, ViTSdpaSelfAttention, ViTSelfAttention from .mixin_vit import ViTLayerAdaptersMixin, ViTOutputAdaptersMixin, ViTSelfAttentionAdaptersMixin @@ -70,6 +70,38 @@ def forward( return outputs +class ViTSdpaSelfAttentionWithAdapters(ViTSelfAttentionAdaptersMixin, ViTSdpaSelfAttention): + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer) + + key_layer, value_layer, _ = self.prefix_tuning(key_layer, value_layer, hidden_states) + (query_layer,) = adjust_tensors_for_parallel(key_layer, query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + class ViTOutputWithAdapters(ViTOutputAdaptersMixin, ViTOutput): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) diff --git a/src/adapters/utils.py b/src/adapters/utils.py index 7338f4c3ac..bbcbb509a2 100644 --- a/src/adapters/utils.py +++ b/src/adapters/utils.py @@ -865,7 +865,7 @@ def get_adapter_info(adapter_id: str, source: str = "ah") -> Optional[AdapterInf raise ValueError("Please specify either 'ah' or 'hf' as source.") -def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0): +def prefix_attention_mask(attention_mask, dim: Union[int, List[int]] = 3, prefix_value: int = 0): """ Adds a prefix to an attention mask. The length of the prefix is determined by the `prefix_attention_mask_length` attribute in the ForwardContext. 
@@ -890,18 +890,21 @@ def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0):
         and forward_context is not None
         and getattr(forward_context, "prompt_tokens_length", None) is not None
     ):
-        # Create a tensor of ones with the desired shape
-        ones_shape = list(attention_mask.shape)
-        ones_shape[dim] = forward_context.prompt_tokens_length
-
-        prefix_attention_mask = torch.full(
-            ones_shape,
-            prefix_value,
-            dtype=attention_mask.dtype,
-        ).to(attention_mask.device)
-
-        # Concatenate the prefix_attention_mask along the specified dimension
-        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=dim)
+        if isinstance(dim, int):
+            dim = [dim]
+        for d in dim:
+            # Create a tensor filled with `prefix_value` with the desired shape
+            ones_shape = list(attention_mask.shape)
+            ones_shape[d] = forward_context.prompt_tokens_length
+
+            prefix_attention_mask = torch.full(
+                ones_shape,
+                prefix_value,
+                dtype=attention_mask.dtype,
+            ).to(attention_mask.device)
+
+            # Concatenate the prefix_attention_mask along the specified dimension
+            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=d)
     return attention_mask
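Note on the `prefix_attention_mask` change: the SDPA attention classes added above call it with a list of dimensions (e.g. `prefix_attention_mask(attention_mask, [2, 3])` in modeling_bert.py), so the 4-dim mask is extended along both its target- and source-length axes when prompt tokens are prepended. The following is a minimal standalone sketch of that behaviour, mirroring the loop in the utils.py hunk; the helper name `prefix_mask_sketch` and the toy shapes are illustrative assumptions, not part of the library API.

# Standalone sketch (not part of the patch) of the multi-dim prefix logic.
from typing import List, Union

import torch


def prefix_mask_sketch(
    mask: torch.Tensor,
    dims: Union[int, List[int]],
    prompt_len: int,
    prefix_value: int = 0,
) -> torch.Tensor:
    if isinstance(dims, int):
        dims = [dims]
    for d in dims:
        # Build a prefix block of length `prompt_len` along dimension `d`,
        # filled with `prefix_value`, and prepend it to the mask.
        shape = list(mask.shape)
        shape[d] = prompt_len
        prefix = torch.full(shape, prefix_value, dtype=mask.dtype, device=mask.device)
        mask = torch.cat((prefix, mask), dim=d)
    return mask


# A 4-dim self-attention mask of shape (batch, heads, tgt_len, src_len), as used by
# the SDPA path, grows along both dims 2 and 3 when 5 prompt tokens are prepended.
mask = torch.ones(2, 1, 4, 4)
print(prefix_mask_sketch(mask, [2, 3], prompt_len=5).shape)  # torch.Size([2, 1, 9, 9])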