diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 8f05e71a460d..25dfc817faf9 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -19,6 +19,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F import torch.utils.checkpoint from ...activations import ACT2FN @@ -32,6 +33,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ..clip.cross_attention import AttnProcessor, AttnProcessor2_0, CrossAttnProcessor from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig @@ -739,9 +741,25 @@ def __init__(self, config): self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor() + self.set_processor(processor) + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + def set_processor(self, processor: "AttnProcessor"): + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/clip/cross_attention.py b/src/transformers/models/clip/cross_attention.py new file mode 100644 index 000000000000..609687c28884 --- /dev/null +++ b/src/transformers/models/clip/cross_attention.py @@ -0,0 +1,627 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib.util +import sys +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +from packaging import version +from torch import nn + +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# The package importlib_metadata is in a different place, depending on the python version. +if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + + +_xformers_available = importlib.util.find_spec("xformers") is not None +try: + _xformers_version = importlib_metadata.version("xformers") + if version.Version(torch.__version__) < version.Version("1.12"): + raise ValueError("PyTorch should be >= 1.12") + logger.debug(f"Successfully imported xformers version {_xformers_version}") +except importlib_metadata.PackageNotFoundError: + _xformers_available = False + + +def is_xformers_available(): + return _xformers_available + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class CrossAttention(nn.Module): + r""" + Parameters: + A cross attention layer. + query_dim (`int`): The number of channels in the query. cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. dim_head (`int`, + *optional*, defaults to 64): The number of channels in each head. dropout (`float`, *optional*, defaults to + 0.0): The dropout probability to use. bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + processor: Optional["AttnProcessor"] = None, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.cross_attention_norm = cross_attention_norm + + self.scale = dim_head**-0.5 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + if cross_attention_norm: + self.norm_cross = nn.LayerNorm(cross_attention_dim) + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + # set attention processor + # We use the AttnProcessor2_0 by default when torch2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + if processor is None: + processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor() + self.set_processor(processor) + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = CrossAttnAddedKVProcessor() + else: + processor = CrossAttnProcessor() + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor"): + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): + # The `CrossAttention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def get_attention_scores(self, query, key, attention_mask=None): + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask(self, attention_mask, target_length, batch_size=None): + if batch_size is None: + batch_size = 1 + + head_size = self.heads + if attention_mask is None: + return attention_mask + + if attention_mask.shape[-1] != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + return attention_mask + + +class CrossAttnProcessor: + def __call__( + self, + attn: CrossAttention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4): + super().__init__() + + if rank > min(in_features, out_features): + raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") + + self.down = nn.Linear(in_features, rank, bias=False) + self.up = nn.Linear(rank, out_features, bias=False) + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + return up_hidden_states.to(orig_dtype) + + +class LoRACrossAttnProcessor(nn.Module): + def __init__(self, hidden_size, cross_attention_dim=None, rank=4): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + + def __call__( + self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0 + ): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) + + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class CrossAttnAddedKVProcessor: + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + +class XFormersCrossAttnProcessor: + def __init__(self, attention_op: Optional[Callable] = None): + self.attention_op = attention_op + + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class AttnProcessor2_0: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, inner_dim = hidden_states.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class LoRAXFormersCrossAttnProcessor(nn.Module): + def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + self.attention_op = attention_op + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + + def __call__( + self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0 + ): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + query = attn.head_to_batch_dim(query).contiguous() + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) + + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op + ) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class SlicedAttnProcessor: + def __init__(self, slice_size): + self.slice_size = slice_size + + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype + ) + + for i in range(hidden_states.shape[0] // self.slice_size): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class SlicedAttnAddedKVProcessor: + def __init__(self, slice_size): + self.slice_size = slice_size + + def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype + ) + + for i in range(hidden_states.shape[0] // self.slice_size): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + +AttnProcessor = Union[ + CrossAttnProcessor, + XFormersCrossAttnProcessor, + SlicedAttnProcessor, + CrossAttnAddedKVProcessor, + SlicedAttnAddedKVProcessor, + LoRACrossAttnProcessor, + LoRAXFormersCrossAttnProcessor, +] diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index b59a3d244d01..e1e66a80b25f 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -16,9 +16,10 @@ from dataclasses import dataclass -from typing import Any, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn @@ -33,6 +34,7 @@ replace_return_docstrings, ) from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig +from .cross_attention import AttnProcessor, AttnProcessor2_0, CrossAttnProcessor logger = logging.get_logger(__name__) @@ -254,9 +256,25 @@ def __init__(self, config): self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor() + self.set_processor(processor) + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + def set_processor(self, processor: "AttnProcessor"): + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + def forward( self, hidden_states: torch.Tensor, @@ -777,6 +795,59 @@ def __init__(self, config: CLIPTextConfig): # Initialize weights and apply final processing self.post_init() + @property + def attn_processors(self) -> Dict[str, AttnProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]): + r""" + Parameters: + `processor (`dict` of `AttnProcessor` or `AttnProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + of **all** `CrossAttention` layers. + In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + def get_input_embeddings(self) -> nn.Module: return self.text_model.embeddings.token_embedding diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index 3ec81b33fb76..7e2afbba81fc 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -20,6 +20,7 @@ from typing import Any, Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn @@ -33,6 +34,7 @@ logging, replace_return_docstrings, ) +from ..clip.cross_attention import AttnProcessor, AttnProcessor2_0, CrossAttnProcessor from .configuration_clipseg import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig @@ -275,9 +277,25 @@ def __init__(self, config): self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor() + self.set_processor(processor) + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + def set_processor(self, processor: "AttnProcessor"): + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 9d75451292c0..1454f604ffa2 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -21,6 +21,7 @@ from typing import List, Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss @@ -36,6 +37,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ..clip.cross_attention import AttnProcessor, AttnProcessor2_0, CrossAttnProcessor from .configuration_git import GitConfig, GitVisionConfig @@ -673,9 +675,25 @@ def __init__(self, config): self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor() + self.set_processor(processor) + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + def set_processor(self, processor: "AttnProcessor"): + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index 69289ced40a3..3f027249c593 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -20,6 +20,7 @@ from typing import Any, Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn @@ -33,6 +34,7 @@ logging, replace_return_docstrings, ) +from ..clip.cross_attention import AttnProcessor, AttnProcessor2_0, CrossAttnProcessor from .configuration_x_clip import XCLIPConfig, XCLIPTextConfig, XCLIPVisionConfig @@ -207,9 +209,25 @@ def __init__(self, config): self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor() + self.set_processor(processor) + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + def set_processor(self, processor: "AttnProcessor"): + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + def forward( self, hidden_states: torch.Tensor,