diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 79fa5179a104..30d497892ff5 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -163,6 +163,7 @@
[
"AllegroTransformer3DModel",
"AsymmetricAutoencoderKL",
+ "AttentionBackendName",
"AuraFlowTransformer2DModel",
"AutoencoderDC",
"AutoencoderKL",
@@ -238,6 +239,7 @@
"VQModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
+ "attention_backend",
]
)
_import_structure["modular_pipelines"].extend(
@@ -815,6 +817,7 @@
from .models import (
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
+ AttentionBackendName,
AuraFlowTransformer2DModel,
AutoencoderDC,
AutoencoderKL,
@@ -889,6 +892,7 @@
VQModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
+ attention_backend,
)
from .modular_pipelines import (
ComponentsManager,
diff --git a/src/diffusers/hooks/faster_cache.py b/src/diffusers/hooks/faster_cache.py
index 1be5e1436294..a6c250b50ca4 100644
--- a/src/diffusers/hooks/faster_cache.py
+++ b/src/diffusers/hooks/faster_cache.py
@@ -18,6 +18,7 @@
import torch
+from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..models.modeling_outputs import Transformer2DModelOutput
from ..utils import logging
@@ -567,7 +568,7 @@ def high_frequency_weight_callback(module: torch.nn.Module) -> float:
_apply_faster_cache_on_denoiser(module, config)
for name, submodule in module.named_modules():
- if not isinstance(submodule, _ATTENTION_CLASSES):
+ if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
continue
if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
_apply_faster_cache_on_attention_class(name, submodule, config)
diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py
index bbdd1c3f68d4..1c8787194196 100644
--- a/src/diffusers/hooks/pyramid_attention_broadcast.py
+++ b/src/diffusers/hooks/pyramid_attention_broadcast.py
@@ -18,6 +18,7 @@
import torch
+from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
from .hooks import HookRegistry, ModelHook
@@ -227,7 +228,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAt
config.spatial_attention_block_skip_range = 2
for name, submodule in module.named_modules():
- if not isinstance(submodule, _ATTENTION_CLASSES):
+ if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
# PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
# cannot be applied to this layer. For custom layers, users can extend this functionality and implement
# their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index e05d53687a24..dca4758ba038 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -40,8 +40,6 @@
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
- FluxAttnProcessor2_0,
- FluxIPAdapterJointAttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
@@ -867,6 +865,9 @@ def unload_ip_adapter(self):
>>> ...
```
"""
+ # TODO: once the 1.0.0 deprecations are in, we can move the imports to top-level
+ from ..models.transformers.transformer_flux import FluxAttnProcessor, FluxIPAdapterAttnProcessor
+
# remove CLIP image encoder
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
self.image_encoder = None
@@ -886,9 +887,9 @@ def unload_ip_adapter(self):
# restore original Transformer attention processors layers
attn_procs = {}
for name, value in self.transformer.attn_processors.items():
- attn_processor_class = FluxAttnProcessor2_0()
+ attn_processor_class = FluxAttnProcessor()
attn_procs[name] = (
- attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
+ attn_processor_class if isinstance(value, FluxIPAdapterAttnProcessor) else value.__class__()
)
self.transformer.set_attn_processor(attn_procs)
diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py
index 0de809594864..ced81960fae5 100644
--- a/src/diffusers/loaders/transformer_flux.py
+++ b/src/diffusers/loaders/transformer_flux.py
@@ -86,9 +86,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
return image_projection
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
- from ..models.attention_processor import (
- FluxIPAdapterJointAttnProcessor2_0,
- )
+ from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
if low_cpu_mem_usage:
if is_accelerate_available():
@@ -120,7 +118,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_
else:
cross_attention_dim = self.config.joint_attention_dim
hidden_size = self.inner_dim
- attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
+ attn_processor_class = FluxIPAdapterAttnProcessor
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 7c09df92493e..cd1df3667a18 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -26,6 +26,7 @@
if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
+ _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
_import_structure["auto_model"] = ["AutoModel"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
@@ -112,6 +113,7 @@
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .adapter import MultiAdapter, T2IAdapter
+ from .attention_dispatch import AttentionBackendName, attention_backend
from .auto_model import AutoModel
from .autoencoders import (
AsymmetricAutoencoderKL,
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index ae51d3ab1349..c720b379551f 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -11,23 +11,504 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
+import torch.nn as nn
import torch.nn.functional as F
-from torch import nn
from ..utils import deprecate, logging
+from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
-from .attention_processor import Attention, JointAttnProcessor2_0
+from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
+if is_xformers_available():
+ import xformers as xops
+else:
+ xops = None
+
+
logger = logging.get_logger(__name__)
+class AttentionMixin:
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+ """
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ for module in self.modules():
+ if isinstance(module, AttentionModuleMixin):
+ module.fuse_projections()
+
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+ """
+ for module in self.modules():
+ if isinstance(module, AttentionModuleMixin):
+ module.unfuse_projections()
+
+
+class AttentionModuleMixin:
+ _default_processor_cls = None
+ _available_processors = []
+ fused_projections = False
+
+ def set_processor(self, processor: AttentionProcessor) -> None:
+ """
+ Set the attention processor to use.
+
+ Args:
+ processor (`AttnProcessor`):
+ The attention processor to use.
+ """
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
+ # pop `processor` from `self._modules`
+ if (
+ hasattr(self, "processor")
+ and isinstance(self.processor, torch.nn.Module)
+ and not isinstance(processor, torch.nn.Module)
+ ):
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+ self._modules.pop("processor")
+
+ self.processor = processor
+
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
+ """
+ Get the attention processor in use.
+
+ Args:
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
+ Set to `True` to return the deprecated LoRA attention processor.
+
+ Returns:
+ "AttentionProcessor": The attention processor in use.
+ """
+ if not return_deprecated_lora:
+ return self.processor
+
+ def set_attention_backend(self, backend: str):
+ from .attention_dispatch import AttentionBackendName
+
+ available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+ if backend not in available_backends:
+ raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+
+ backend = AttentionBackendName(backend.lower())
+ self.processor._attention_backend = backend
+
+ def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
+ """
+ Set whether to use NPU flash attention from `torch_npu` or not.
+
+ Args:
+ use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
+ """
+
+ if use_npu_flash_attention:
+ if not is_torch_npu_available():
+ raise ImportError("torch_npu is not available")
+
+ self.set_attention_backend("_native_npu")
+
+ def set_use_xla_flash_attention(
+ self,
+ use_xla_flash_attention: bool,
+ partition_spec: Optional[Tuple[Optional[str], ...]] = None,
+ is_flux=False,
+ ) -> None:
+ """
+ Set whether to use XLA flash attention from `torch_xla` or not.
+
+ Args:
+ use_xla_flash_attention (`bool`):
+ Whether to use pallas flash attention kernel from `torch_xla` or not.
+ partition_spec (`Tuple[]`, *optional*):
+ Specify the partition specification if using SPMD. Otherwise None.
+ is_flux (`bool`, *optional*, defaults to `False`):
+ Whether the model is a Flux model.
+ """
+ if use_xla_flash_attention:
+ if not is_torch_xla_available():
+ raise ImportError("torch_xla is not available")
+
+ self.set_attention_backend("_native_xla")
+
+ def set_use_memory_efficient_attention_xformers(
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+ ) -> None:
+ """
+ Set whether to use memory efficient attention from `xformers` or not.
+
+ Args:
+ use_memory_efficient_attention_xformers (`bool`):
+ Whether to use memory efficient attention from `xformers` or not.
+ attention_op (`Callable`, *optional*):
+ The attention operation to use. Defaults to `None` which uses the default attention operation from
+ `xformers`.
+ """
+ if use_memory_efficient_attention_xformers:
+ if not is_xformers_available():
+ raise ModuleNotFoundError(
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+ " only available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ if is_xformers_available():
+ dtype = None
+ if attention_op is not None:
+ op_fw, op_bw = attention_op
+ dtype, *_ = op_fw.SUPPORTED_DTYPES
+ q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+ _ = xops.memory_efficient_attention(q, q, q)
+ except Exception as e:
+ raise e
+
+ self.set_attention_backend("xformers")
+
+ @torch.no_grad()
+ def fuse_projections(self):
+ """
+ Fuse the query, key, and value projections into a single projection for efficiency.
+ """
+ # Skip if already fused
+ if getattr(self, "fused_projections", False):
+ return
+
+ device = self.to_q.weight.data.device
+ dtype = self.to_q.weight.data.dtype
+
+ if hasattr(self, "is_cross_attention") and self.is_cross_attention:
+ # Fuse cross-attention key-value projections
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+ self.to_kv.weight.copy_(concatenated_weights)
+ if hasattr(self, "use_bias") and self.use_bias:
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ self.to_kv.bias.copy_(concatenated_bias)
+ else:
+ # Fuse self-attention projections
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+ self.to_qkv.weight.copy_(concatenated_weights)
+ if hasattr(self, "use_bias") and self.use_bias:
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ self.to_qkv.bias.copy_(concatenated_bias)
+
+ # Handle added projections for models like SD3, Flux, etc.
+ if (
+ getattr(self, "add_q_proj", None) is not None
+ and getattr(self, "add_k_proj", None) is not None
+ and getattr(self, "add_v_proj", None) is not None
+ ):
+ concatenated_weights = torch.cat(
+ [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
+ )
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_added_qkv = nn.Linear(
+ in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
+ )
+ self.to_added_qkv.weight.copy_(concatenated_weights)
+ if self.added_proj_bias:
+ concatenated_bias = torch.cat(
+ [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+ )
+ self.to_added_qkv.bias.copy_(concatenated_bias)
+
+ self.fused_projections = True
+
+ @torch.no_grad()
+ def unfuse_projections(self):
+ """
+ Unfuse the query, key, and value projections back to separate projections.
+ """
+ # Skip if not fused
+ if not getattr(self, "fused_projections", False):
+ return
+
+ # Remove fused projection layers
+ if hasattr(self, "to_qkv"):
+ delattr(self, "to_qkv")
+
+ if hasattr(self, "to_kv"):
+ delattr(self, "to_kv")
+
+ if hasattr(self, "to_added_qkv"):
+ delattr(self, "to_added_qkv")
+
+ self.fused_projections = False
+
+ def set_attention_slice(self, slice_size: int) -> None:
+ """
+ Set the slice size for attention computation.
+
+ Args:
+ slice_size (`int`):
+ The slice size for attention computation.
+ """
+ if hasattr(self, "sliceable_head_dim") and slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ processor = None
+
+ # Try to get a compatible processor for sliced attention
+ if slice_size is not None:
+ processor = self._get_compatible_processor("sliced")
+
+ # If no processor was found or slice_size is None, use default processor
+ if processor is None:
+ processor = self.default_processor_cls()
+
+ self.set_processor(processor)
+
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+ """
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
+ """
+ head_size = self.heads
+ batch_size, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+ """
+ Reshape the tensor for multi-head attention processing.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
+ """
+ head_size = self.heads
+ if tensor.ndim == 3:
+ batch_size, seq_len, dim = tensor.shape
+ extra_dim = 1
+ else:
+ batch_size, extra_dim, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3)
+
+ if out_dim == 3:
+ tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
+
+ return tensor
+
+ def get_attention_scores(
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """
+ Compute the attention scores.
+
+ Args:
+ query (`torch.Tensor`): The query tensor.
+ key (`torch.Tensor`): The key tensor.
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use.
+
+ Returns:
+ `torch.Tensor`: The attention probabilities/scores.
+ """
+ dtype = query.dtype
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ if attention_mask is None:
+ baddbmm_input = torch.empty(
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+ )
+ beta = 0
+ else:
+ baddbmm_input = attention_mask
+ beta = 1
+
+ attention_scores = torch.baddbmm(
+ baddbmm_input,
+ query,
+ key.transpose(-1, -2),
+ beta=beta,
+ alpha=self.scale,
+ )
+ del baddbmm_input
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+ del attention_scores
+
+ attention_probs = attention_probs.to(dtype)
+
+ return attention_probs
+
+ def prepare_attention_mask(
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
+ ) -> torch.Tensor:
+ """
+ Prepare the attention mask for the attention computation.
+
+ Args:
+ attention_mask (`torch.Tensor`): The attention mask to prepare.
+ target_length (`int`): The target length of the attention mask.
+ batch_size (`int`): The batch size for repeating the attention mask.
+ out_dim (`int`, *optional*, defaults to `3`): Output dimension.
+
+ Returns:
+ `torch.Tensor`: The prepared attention mask.
+ """
+ head_size = self.heads
+ if attention_mask is None:
+ return attention_mask
+
+ current_length: int = attention_mask.shape[-1]
+ if current_length != target_length:
+ if attention_mask.device.type == "mps":
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+ # Instead, we can manually construct the padding tensor.
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
+ else:
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
+ # remaining_length: int = target_length - current_length
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+ if out_dim == 3:
+ if attention_mask.shape[0] < batch_size * head_size:
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+ elif out_dim == 4:
+ attention_mask = attention_mask.unsqueeze(1)
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+ return attention_mask
+
+ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ """
+ Normalize the encoder hidden states.
+
+ Args:
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
+
+ Returns:
+ `torch.Tensor`: The normalized encoder hidden states.
+ """
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+ if isinstance(self.norm_cross, nn.LayerNorm):
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ elif isinstance(self.norm_cross, nn.GroupNorm):
+ # Group norm norms along the channels dimension and expects
+ # input to be in the shape of (N, C, *). In this case, we want
+ # to norm along the hidden dimension, so we need to move
+ # (batch_size, sequence_length, hidden_size) ->
+ # (batch_size, hidden_size, sequence_length)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ else:
+ assert False
+
+ return encoder_hidden_states
+
+
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
new file mode 100644
index 000000000000..141a7fee858b
--- /dev/null
+++ b/src/diffusers/models/attention_dispatch.py
@@ -0,0 +1,1155 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import functools
+import inspect
+import math
+from enum import Enum
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+
+import torch
+
+from ..utils import (
+ get_logger,
+ is_flash_attn_3_available,
+ is_flash_attn_available,
+ is_flash_attn_version,
+ is_sageattention_available,
+ is_sageattention_version,
+ is_torch_npu_available,
+ is_torch_version,
+ is_torch_xla_available,
+ is_torch_xla_version,
+ is_xformers_available,
+ is_xformers_version,
+)
+from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+
+if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"):
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+else:
+ logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.")
+ flash_attn_func = None
+ flash_attn_varlen_func = None
+
+
+if is_flash_attn_3_available():
+ from flash_attn_interface import flash_attn_func as flash_attn_3_func
+ from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+else:
+ flash_attn_3_func = None
+ flash_attn_3_varlen_func = None
+
+
+if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"):
+ from sageattention import (
+ sageattn,
+ sageattn_qk_int8_pv_fp8_cuda,
+ sageattn_qk_int8_pv_fp8_cuda_sm90,
+ sageattn_qk_int8_pv_fp16_cuda,
+ sageattn_qk_int8_pv_fp16_triton,
+ sageattn_varlen,
+ )
+else:
+ logger.warning(
+ "`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`."
+ )
+ sageattn = None
+ sageattn_qk_int8_pv_fp16_cuda = None
+ sageattn_qk_int8_pv_fp16_triton = None
+ sageattn_qk_int8_pv_fp8_cuda = None
+ sageattn_qk_int8_pv_fp8_cuda_sm90 = None
+ sageattn_varlen = None
+
+
+if is_torch_version(">=", "2.5.0"):
+ # We cannot import the flex_attention function from the package directly because it is expected (from the
+ # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
+ # compiled function.
+ import torch.nn.attention.flex_attention as flex_attention
+
+
+if is_torch_npu_available():
+ from torch_npu import npu_fusion_attention
+else:
+ npu_fusion_attention = None
+
+
+if is_torch_xla_available() and is_torch_xla_version(">", "2.2"):
+ from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
+else:
+ xla_flash_attention = None
+
+
+if is_xformers_available() and is_xformers_version(">=", "0.0.29"):
+ import xformers.ops as xops
+else:
+ logger.warning("`xformers` is not available or the version is too old. Please install `xformers>=0.0.29`.")
+ xops = None
+
+
+# TODO(aryan): Add support for the following:
+# - Sage Attention++
+# - block sparse, radial and other attention methods
+# - CP with sage attention, flex, xformers, other missing backends
+# - Add support for normal and CP training with backends that don't support it yet
+
+
+_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
+_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
+_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
+
+
+class AttentionBackendName(str, Enum):
+ # EAGER = "eager"
+
+ # `flash-attn`
+ FLASH = "flash"
+ FLASH_VARLEN = "flash_varlen"
+ _FLASH_3 = "_flash_3"
+ _FLASH_VARLEN_3 = "_flash_varlen_3"
+
+ # PyTorch native
+ FLEX = "flex"
+ NATIVE = "native"
+ _NATIVE_CUDNN = "_native_cudnn"
+ _NATIVE_EFFICIENT = "_native_efficient"
+ _NATIVE_FLASH = "_native_flash"
+ _NATIVE_MATH = "_native_math"
+ _NATIVE_NPU = "_native_npu"
+ _NATIVE_XLA = "_native_xla"
+
+ # `sageattention`
+ SAGE = "sage"
+ SAGE_VARLEN = "sage_varlen"
+ _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
+ _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
+ _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda"
+ _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton"
+ # TODO: let's not add support for Sparge Attention now because it requires tuning per model
+ # We can look into supporting something "autotune"-ing in the future
+ # SPARGE = "sparge"
+
+ # `xformers`
+ XFORMERS = "xformers"
+
+
+class _AttentionBackendRegistry:
+ _backends = {}
+ _constraints = {}
+ _supported_arg_names = {}
+ _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
+ _checks_enabled = DIFFUSERS_ATTN_CHECKS
+
+ @classmethod
+ def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None):
+ logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}")
+
+ def decorator(func):
+ cls._backends[backend] = func
+ cls._constraints[backend] = constraints or []
+ cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
+ return func
+
+ return decorator
+
+ @classmethod
+ def get_active_backend(cls):
+ return cls._active_backend, cls._backends[cls._active_backend]
+
+ @classmethod
+ def list_backends(cls):
+ return list(cls._backends.keys())
+
+
+@contextlib.contextmanager
+def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE):
+ """
+ Context manager to set the active attention backend.
+ """
+ if backend not in _AttentionBackendRegistry._backends:
+ raise ValueError(f"Backend {backend} is not registered.")
+
+ old_backend = _AttentionBackendRegistry._active_backend
+ _AttentionBackendRegistry._active_backend = backend
+
+ try:
+ yield
+ finally:
+ _AttentionBackendRegistry._active_backend = old_backend
+
+
+def dispatch_attention_fn(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ *,
+ backend: Optional[AttentionBackendName] = None,
+) -> torch.Tensor:
+ attention_kwargs = attention_kwargs or {}
+
+ if backend is None:
+ # If no backend is specified, we either use the default backend (set via the DIFFUSERS_ATTN_BACKEND environment
+ # variable), or we use a custom backend based on whether user is using the `attention_backend` context manager
+ backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend()
+ else:
+ backend_name = AttentionBackendName(backend)
+ backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
+
+ kwargs = {
+ "query": query,
+ "key": key,
+ "value": value,
+ "attn_mask": attn_mask,
+ "dropout_p": dropout_p,
+ "is_causal": is_causal,
+ "scale": scale,
+ "enable_gqa": enable_gqa,
+ **attention_kwargs,
+ }
+
+ if _AttentionBackendRegistry._checks_enabled:
+ removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name])
+ if removed_kwargs:
+ logger.warning(f"Removing unsupported arguments for attention backend {backend_name}: {removed_kwargs}.")
+ for check in _AttentionBackendRegistry._constraints.get(backend_name):
+ check(**kwargs)
+
+ kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
+ return backend_fn(**kwargs)
+
+
+# ===== Checks =====
+# A list of very simple functions to catch common errors quickly when debugging.
+
+
+def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None:
+ if attn_mask is not None and is_causal:
+ raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.")
+
+
+def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ if query.device != key.device or query.device != value.device:
+ raise ValueError("Query, key, and value must be on the same device.")
+ if query.dtype != key.dtype or query.dtype != value.dtype:
+ raise ValueError("Query, key, and value must have the same dtype.")
+
+
+def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ _check_device(query, key, value)
+ if query.device.type != "cuda":
+ raise ValueError("Query, key, and value must be on a CUDA device.")
+
+
+def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable:
+ def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ _check_device_cuda(query, key, value)
+ if torch.cuda.get_device_capability(query.device) < (major, minor):
+ raise ValueError(
+ f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}."
+ )
+
+ return check_device_cuda
+
+
+def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ if query.dtype != key.dtype:
+ raise ValueError("Query and key must have the same dtype.")
+ if query.dtype != value.dtype:
+ raise ValueError("Query and value must have the same dtype.")
+
+
+def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ _check_qkv_dtype_match(query, key, value)
+ if query.dtype not in (torch.bfloat16, torch.float16):
+ raise ValueError("Query, key, and value must be either bfloat16 or float16.")
+
+
+def _check_shape(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+) -> None:
+ if query.shape[-1] != key.shape[-1]:
+ raise ValueError("Query and key must have the same last dimension.")
+ if query.shape[-2] != value.shape[-2]:
+ raise ValueError("Query and value must have the same second to last dimension.")
+ if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
+ raise ValueError("Attention mask must match the key's second to last dimension.")
+
+
+# ===== Helper functions =====
+
+
+@functools.lru_cache(maxsize=128)
+def _prepare_for_flash_attn_or_sage_varlen_without_mask(
+ batch_size: int,
+ seq_len_q: int,
+ seq_len_kv: int,
+ device: Optional[torch.device] = None,
+):
+ seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
+ seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
+ cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+ cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+ cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+ cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+ max_seqlen_q = seqlens_q.max().item()
+ max_seqlen_k = seqlens_k.max().item()
+ return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+
+
+def _prepare_for_flash_attn_or_sage_varlen_with_mask(
+ batch_size: int,
+ seq_len_q: int,
+ attn_mask: torch.Tensor,
+ device: Optional[torch.device] = None,
+):
+ seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
+ seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
+ cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+ cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+ cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+ cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+ max_seqlen_q = seqlens_q.max().item()
+ max_seqlen_k = seqlens_k.max().item()
+ return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+
+
+def _prepare_for_flash_attn_or_sage_varlen(
+ batch_size: int,
+ seq_len_q: int,
+ seq_len_kv: int,
+ attn_mask: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+) -> None:
+ if attn_mask is None:
+ return _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device)
+ return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device)
+
+
+def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
+ """
+ Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in
+ FlashAttention/Sage varlen.
+
+ Supports 1D to 4D shapes and common broadcasting patterns.
+ """
+ if attn_mask.dtype != torch.bool:
+ raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.")
+
+ if attn_mask.ndim == 1:
+ # [seq_len_k] -> broadcast across batch
+ attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k)
+
+ elif attn_mask.ndim == 2:
+ # [batch_size, seq_len_k]. Maybe broadcast across batch
+ if attn_mask.size(0) not in [1, batch_size]:
+ raise ValueError(
+ f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask."
+ )
+ attn_mask = attn_mask.expand(batch_size, seq_len_k)
+
+ elif attn_mask.ndim == 3:
+ # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension
+ # We do this reduction because we know that arbitrary QK masks is not supported in Flash/Sage varlen.
+ if attn_mask.size(0) not in [1, batch_size]:
+ raise ValueError(
+ f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask."
+ )
+ attn_mask = attn_mask.any(dim=1)
+ attn_mask = attn_mask.expand(batch_size, seq_len_k)
+
+ elif attn_mask.ndim == 4:
+ # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions
+ if attn_mask.size(0) not in [1, batch_size]:
+ raise ValueError(
+ f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask."
+ )
+ attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K]
+ attn_mask = attn_mask.any(dim=(1, 2)) # [B, K]
+
+ else:
+ raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}")
+
+ if attn_mask.shape != (batch_size, seq_len_k):
+ raise ValueError(
+ f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})"
+ )
+
+ return attn_mask
+
+
+def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+ return q_idx >= kv_idx
+
+
+# ===== torch op registrations =====
+# Registrations are required for fullgraph tracing compatibility
+
+
+# TODO: library.custom_op and register_fake probably need version guards?
+# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
+# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
+@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+def _wrapped_flash_attn_3_original(
+ query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ out, lse = flash_attn_3_func(query, key, value)
+ lse = lse.permute(0, 2, 1)
+ return out, lse
+
+
+@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
+def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ batch_size, seq_len, num_heads, head_dim = query.shape
+ lse_shape = (batch_size, seq_len, num_heads)
+ return torch.empty_like(query), query.new_empty(lse_shape)
+
+
+# ===== Attention backends =====
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.FLASH,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ dropout_p: float = 0.0,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ window_size: Tuple[int, int] = (-1, -1),
+ softcap: float = 0.0,
+ alibi_slopes: Optional[torch.Tensor] = None,
+ deterministic: bool = False,
+ return_attn_probs: bool = False,
+) -> torch.Tensor:
+ out = flash_attn_func(
+ q=query,
+ k=key,
+ v=value,
+ dropout_p=dropout_p,
+ softmax_scale=scale,
+ causal=is_causal,
+ window_size=window_size,
+ softcap=softcap,
+ alibi_slopes=alibi_slopes,
+ deterministic=deterministic,
+ return_attn_probs=return_attn_probs,
+ )
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.FLASH_VARLEN,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_varlen_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ cu_seqlens_q: Optional[torch.Tensor] = None,
+ cu_seqlens_k: Optional[torch.Tensor] = None,
+ max_seqlen_q: Optional[int] = None,
+ max_seqlen_k: Optional[int] = None,
+ dropout_p: float = 0.0,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ window_size: Tuple[int, int] = (-1, -1),
+ softcap: float = 0.0,
+ alibi_slopes: Optional[torch.Tensor] = None,
+ deterministic: bool = False,
+ return_attn_probs: bool = False,
+ attn_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ batch_size, seq_len_q, _, _ = query.shape
+ _, seq_len_kv, _, _ = key.shape
+
+ if attn_mask is not None:
+ attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+ if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+ )
+ )
+ else:
+ seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+ cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+ cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+ key_valid, value_valid = [], []
+ for b in range(batch_size):
+ valid_len = seqlens_k[b]
+ key_valid.append(key[b, :valid_len])
+ value_valid.append(value[b, :valid_len])
+
+ query_packed = query.flatten(0, 1)
+ key_packed = torch.cat(key_valid, dim=0)
+ value_packed = torch.cat(value_valid, dim=0)
+
+ out = flash_attn_varlen_func(
+ q=query_packed,
+ k=key_packed,
+ v=value_packed,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ dropout_p=dropout_p,
+ softmax_scale=scale,
+ causal=is_causal,
+ window_size=window_size,
+ softcap=softcap,
+ alibi_slopes=alibi_slopes,
+ deterministic=deterministic,
+ return_attn_probs=return_attn_probs,
+ )
+ out = out.unflatten(0, (batch_size, -1))
+
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._FLASH_3,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_3(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ window_size: Tuple[int, int] = (-1, -1),
+ softcap: float = 0.0,
+ deterministic: bool = False,
+ return_attn_probs: bool = False,
+) -> torch.Tensor:
+ out, lse, *_ = flash_attn_3_func(
+ q=query,
+ k=key,
+ v=value,
+ softmax_scale=scale,
+ causal=is_causal,
+ qv=None,
+ q_descale=None,
+ k_descale=None,
+ v_descale=None,
+ window_size=window_size,
+ attention_chunk=0,
+ softcap=softcap,
+ num_splits=1,
+ pack_gqa=None,
+ deterministic=deterministic,
+ sm_margin=0,
+ )
+ return (out, lse) if return_attn_probs else out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._FLASH_VARLEN_3,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_varlen_attention_3(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ cu_seqlens_q: Optional[torch.Tensor] = None,
+ cu_seqlens_k: Optional[torch.Tensor] = None,
+ max_seqlen_q: Optional[int] = None,
+ max_seqlen_k: Optional[int] = None,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ window_size: Tuple[int, int] = (-1, -1),
+ softcap: float = 0.0,
+ deterministic: bool = False,
+ return_attn_probs: bool = False,
+ attn_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ batch_size, seq_len_q, _, _ = query.shape
+ _, seq_len_kv, _, _ = key.shape
+
+ if attn_mask is not None:
+ attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+ if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+ )
+ )
+ else:
+ seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+ cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+ cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+ key_valid, value_valid = [], []
+ for b in range(batch_size):
+ valid_len = seqlens_k[b]
+ key_valid.append(key[b, :valid_len])
+ value_valid.append(value[b, :valid_len])
+
+ query_packed = query.flatten(0, 1)
+ key_packed = torch.cat(key_valid, dim=0)
+ value_packed = torch.cat(value_valid, dim=0)
+
+ out, lse, *_ = flash_attn_3_varlen_func(
+ q=query_packed,
+ k=key_packed,
+ v=value_packed,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ seqused_q=None,
+ seqused_k=None,
+ softmax_scale=scale,
+ causal=is_causal,
+ qv=None,
+ q_descale=None,
+ k_descale=None,
+ v_descale=None,
+ window_size=window_size,
+ softcap=softcap,
+ num_splits=1,
+ pack_gqa=None,
+ deterministic=deterministic,
+ sm_margin=0,
+ )
+ out = out.unflatten(0, (batch_size, -1))
+
+ return (out, lse) if return_attn_probs else out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.FLEX,
+ constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
+)
+def _native_flex_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ kernel_options: Optional[Dict[str, Any]] = None,
+) -> torch.Tensor:
+ # TODO: should we LRU cache the block mask creation?
+ score_mod = None
+ block_mask = None
+ batch_size, seq_len_q, num_heads, _ = query.shape
+ _, seq_len_kv, _, _ = key.shape
+
+ if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask):
+ block_mask = attn_mask
+ elif is_causal:
+ block_mask = flex_attention.create_block_mask(
+ _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device
+ )
+ elif torch.is_tensor(attn_mask):
+ if attn_mask.ndim == 2:
+ attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
+
+ attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv)
+
+ if attn_mask.dtype == torch.bool:
+ # TODO: this probably does not work but verify!
+ def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+ return attn_mask[batch_idx, head_idx, q_idx, kv_idx]
+
+ block_mask = flex_attention.create_block_mask(
+ mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device
+ )
+ else:
+
+ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
+ return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]
+ else:
+ raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.")
+
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ out = flex_attention.flex_attention(
+ query=query,
+ key=key,
+ value=value,
+ score_mod=score_mod,
+ block_mask=block_mask,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ return_lse=return_lse,
+ kernel_options=kernel_options,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.NATIVE,
+ constraints=[_check_device, _check_shape],
+)
+def _native_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_CUDNN,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _native_cudnn_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_EFFICIENT,
+ constraints=[_check_device, _check_shape],
+)
+def _native_efficient_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_FLASH,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _native_flash_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=None, # not supported
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_MATH,
+ constraints=[_check_device, _check_shape],
+)
+def _native_math_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_NPU,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _native_npu_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ dropout_p: float = 0.0,
+ scale: Optional[float] = None,
+) -> torch.Tensor:
+ return npu_fusion_attention(
+ query,
+ key,
+ value,
+ query.size(2), # num_heads
+ input_layout="BSND",
+ pse=None,
+ scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+ pre_tockens=65536,
+ next_tokens=65536,
+ keep_prob=1.0 - dropout_p,
+ sync=False,
+ inner_precise=0,
+ )[0]
+
+
+# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_XLA,
+ constraints=[_check_device, _check_shape],
+)
+def _native_xla_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+) -> torch.Tensor:
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ query = query / math.sqrt(query.shape[-1])
+ out = xla_flash_attention(
+ q=query,
+ k=key,
+ v=value,
+ causal=is_causal,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.SAGE,
+ constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _sage_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ return sageattn(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ is_causal=is_causal,
+ sm_scale=scale,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.SAGE_VARLEN,
+ constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _sage_varlen_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ cu_seqlens_q: Optional[torch.Tensor] = None,
+ cu_seqlens_k: Optional[torch.Tensor] = None,
+ max_seqlen_q: Optional[int] = None,
+ max_seqlen_k: Optional[int] = None,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ smooth_k: bool = True,
+ attn_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ batch_size, seq_len_q, _, _ = query.shape
+ _, seq_len_kv, _, _ = key.shape
+
+ if attn_mask is not None:
+ attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+ if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+ )
+ )
+ else:
+ seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+ cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+ cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+ key_valid, value_valid = [], []
+ for b in range(batch_size):
+ valid_len = seqlens_k[b]
+ key_valid.append(key[b, :valid_len])
+ value_valid.append(value[b, :valid_len])
+
+ query_packed = query.flatten(0, 1)
+ key_packed = torch.cat(key_valid, dim=0)
+ value_packed = torch.cat(value_valid, dim=0)
+
+ out = sageattn_varlen(
+ q=query_packed,
+ k=key_packed,
+ v=value_packed,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ is_causal=is_causal,
+ sm_scale=scale,
+ smooth_k=smooth_k,
+ )
+ out = out.unflatten(0, (batch_size, -1))
+
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
+ constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp8_cuda_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
+ pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
+ smooth_k: bool = True,
+ smooth_v: bool = False,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp8_cuda(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ is_causal=is_causal,
+ qk_quant_gran=qk_quant_gran,
+ sm_scale=scale,
+ pv_accum_dtype=pv_accum_dtype,
+ smooth_k=smooth_k,
+ smooth_v=smooth_v,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
+ constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
+ pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
+ smooth_k: bool = True,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp8_cuda_sm90(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ is_causal=is_causal,
+ qk_quant_gran=qk_quant_gran,
+ sm_scale=scale,
+ pv_accum_dtype=pv_accum_dtype,
+ smooth_k=smooth_k,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
+ constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp16_cuda_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
+ pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32",
+ smooth_k: bool = True,
+ smooth_v: bool = False,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp16_cuda(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ is_causal=is_causal,
+ qk_quant_gran=qk_quant_gran,
+ sm_scale=scale,
+ pv_accum_dtype=pv_accum_dtype,
+ smooth_k=smooth_k,
+ smooth_v=smooth_v,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
+ constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp16_triton_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton",
+ smooth_k: bool = True,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp16_triton(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ quantization_backend=quantization_backend,
+ is_causal=is_causal,
+ sm_scale=scale,
+ smooth_k=smooth_k,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.XFORMERS,
+ constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
+)
+def _xformers_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ batch_size, seq_len_q, num_heads_q, _ = query.shape
+ _, seq_len_kv, num_heads_kv, _ = key.shape
+
+ if is_causal:
+ attn_mask = xops.LowerTriangularMask()
+ elif attn_mask is not None:
+ if attn_mask.ndim == 2:
+ attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
+ elif attn_mask.ndim != 4:
+ raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
+ attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
+
+ if enable_gqa:
+ if num_heads_q % num_heads_kv != 0:
+ raise ValueError("Number of heads in query must be divisible by number of heads in key/value.")
+ num_heads_per_group = num_heads_q // num_heads_kv
+ query = query.unflatten(2, (num_heads_kv, -1))
+ key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
+ value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
+
+ out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale)
+
+ if enable_gqa:
+ out = out.flatten(2, 3)
+
+ return out
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 4760cfd40b3c..990245de1742 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2272,558 +2272,6 @@ def __call__(
return hidden_states
-class FluxAttnProcessor2_0:
- """Attention processor used typically in processing the SD3-like self-attention projections."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- query = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- if encoder_hidden_states is not None:
- # `context` projections.
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
-class FluxAttnProcessor2_0_NPU:
- """Attention processor used typically in processing the SD3-like self-attention projections."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- query = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- if encoder_hidden_states is not None:
- # `context` projections.
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- if query.dtype in (torch.float16, torch.bfloat16):
- hidden_states = torch_npu.npu_fusion_attention(
- query,
- key,
- value,
- attn.heads,
- input_layout="BNSD",
- pse=None,
- scale=1.0 / math.sqrt(query.shape[-1]),
- pre_tockens=65536,
- next_tockens=65536,
- keep_prob=1.0,
- sync=False,
- inner_precise=0,
- )[0]
- else:
- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
-class FusedFluxAttnProcessor2_0:
- """Attention processor used typically in processing the SD3-like self-attention projections."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- qkv = attn.to_qkv(hidden_states)
- split_size = qkv.shape[-1] // 3
- query, key, value = torch.split(qkv, split_size, dim=-1)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- # `context` projections.
- if encoder_hidden_states is not None:
- encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
- split_size = encoder_qkv.shape[-1] // 3
- (
- encoder_hidden_states_query_proj,
- encoder_hidden_states_key_proj,
- encoder_hidden_states_value_proj,
- ) = torch.split(encoder_qkv, split_size, dim=-1)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
-class FusedFluxAttnProcessor2_0_NPU:
- """Attention processor used typically in processing the SD3-like self-attention projections."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- qkv = attn.to_qkv(hidden_states)
- split_size = qkv.shape[-1] // 3
- query, key, value = torch.split(qkv, split_size, dim=-1)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- # `context` projections.
- if encoder_hidden_states is not None:
- encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
- split_size = encoder_qkv.shape[-1] // 3
- (
- encoder_hidden_states_query_proj,
- encoder_hidden_states_key_proj,
- encoder_hidden_states_value_proj,
- ) = torch.split(encoder_qkv, split_size, dim=-1)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- if query.dtype in (torch.float16, torch.bfloat16):
- hidden_states = torch_npu.npu_fusion_attention(
- query,
- key,
- value,
- attn.heads,
- input_layout="BNSD",
- pse=None,
- scale=1.0 / math.sqrt(query.shape[-1]),
- pre_tockens=65536,
- next_tockens=65536,
- keep_prob=1.0,
- sync=False,
- inner_precise=0,
- )[0]
- else:
- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
-class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
- """Flux Attention processor for IP-Adapter."""
-
- def __init__(
- self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
- ):
- super().__init__()
-
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
- )
-
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
-
- if not isinstance(num_tokens, (tuple, list)):
- num_tokens = [num_tokens]
-
- if not isinstance(scale, list):
- scale = [scale] * len(num_tokens)
- if len(scale) != len(num_tokens):
- raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
- self.scale = scale
-
- self.to_k_ip = nn.ModuleList(
- [
- nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
- for _ in range(len(num_tokens))
- ]
- )
- self.to_v_ip = nn.ModuleList(
- [
- nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
- for _ in range(len(num_tokens))
- ]
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ip_hidden_states: Optional[List[torch.Tensor]] = None,
- ip_adapter_masks: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- hidden_states_query_proj = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- if encoder_hidden_states is not None:
- # `context` projections.
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- # IP-adapter
- ip_query = hidden_states_query_proj
- ip_attn_output = torch.zeros_like(hidden_states)
-
- for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
- ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
- ):
- ip_key = to_k_ip(current_ip_hidden_states)
- ip_value = to_v_ip(current_ip_hidden_states)
-
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- current_ip_hidden_states = F.scaled_dot_product_attention(
- ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
- current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
- batch_size, -1, attn.heads * head_dim
- )
- current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
- ip_attn_output += scale * current_ip_hidden_states
-
- return hidden_states, encoder_hidden_states, ip_attn_output
- else:
- return hidden_states
-
-
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -3453,106 +2901,6 @@ def __call__(
return hidden_states
-class XLAFluxFlashAttnProcessor2_0:
- r"""
- Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
- """
-
- def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
- )
- if is_torch_xla_version("<", "2.3"):
- raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
- if is_spmd() and is_torch_xla_version("<", "2.4"):
- raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
- self.partition_spec = partition_spec
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- query = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- if encoder_hidden_states is not None:
- # `context` projections.
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- query /= math.sqrt(head_dim)
- hidden_states = flash_attention(query, key, value, causal=False)
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
class MochiVaeAttnProcessor2_0:
r"""
Attention processor used in Mochi VAE.
@@ -5992,17 +5340,6 @@ def __init__(self):
pass
-class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
- r"""
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
- """
-
- def __init__(self):
- deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
- deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
- super().__init__()
-
-
class SanaLinearAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product linear attention.
@@ -6167,6 +5504,111 @@ def __call__(
return hidden_states
+class FluxAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
+ deprecate("FluxAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ return FluxAttnProcessor(*args, **kwargs)
+
+
+class FluxSingleAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FluxSingleAttnProcessor` is deprecated and will be removed in a future version. Please use `FluxAttnProcessorSDPA` instead."
+ deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ return FluxAttnProcessor(*args, **kwargs)
+
+
+class FusedFluxAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FusedFluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
+ deprecate("FusedFluxAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ return FluxAttnProcessor(*args, **kwargs)
+
+
+class FluxIPAdapterJointAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FluxIPAdapterJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxIPAdapterAttnProcessor`"
+ deprecate("FluxIPAdapterJointAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ from .transformers.transformer_flux import FluxIPAdapterAttnProcessor
+
+ return FluxIPAdapterAttnProcessor(*args, **kwargs)
+
+
+class FluxAttnProcessor2_0_NPU:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = (
+ "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
+ "alternative solution to use NPU Flash Attention will be provided in the future."
+ )
+ deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
+
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ processor = FluxAttnProcessor()
+ processor._attention_backend = "_native_npu"
+ return processor
+
+
+class FusedFluxAttnProcessor2_0_NPU:
+ def __new__(self):
+ deprecation_message = (
+ "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
+ "alternative solution to use NPU Flash Attention will be provided in the future."
+ )
+ deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
+
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ processor = FluxAttnProcessor()
+ processor._attention_backend = "_fused_npu"
+ return processor
+
+
+class XLAFluxFlashAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
+ """
+
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = (
+ "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An "
+ "alternative solution to using XLA Flash Attention will be provided in the future."
+ )
+ deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+
+ if is_torch_xla_version("<", "2.3"):
+ raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+ if is_spmd() and is_torch_xla_version("<", "2.4"):
+ raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ if len(args) > 0 or kwargs.get("partition_spec", None) is not None:
+ deprecation_message = (
+ "partition_spec was not used in the processor implementation when it was added. Passing it "
+ "is a no-op and support for it will be removed."
+ )
+ deprecate("partition_spec", "1.0.0", deprecation_message)
+
+ processor = FluxAttnProcessor(*args, **kwargs)
+ processor._attention_backend = "_native_xla"
+ return processor
+
+
ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index d77aa1aaa635..b51f5d7aec25 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1181,6 +1181,7 @@ def apply_rotary_emb(
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
use_real: bool = True,
use_real_unbind_dim: int = -1,
+ sequence_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
@@ -1198,8 +1199,15 @@ def apply_rotary_emb(
"""
if use_real:
cos, sin = freqs_cis # [S, D]
- cos = cos[None, None]
- sin = sin[None, None]
+ if sequence_dim == 2:
+ cos = cos[None, None, :, :]
+ sin = sin[None, None, :, :]
+ elif sequence_dim == 1:
+ cos = cos[None, :, None, :]
+ sin = sin[None, :, None, :]
+ else:
+ raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
+
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:
@@ -1243,37 +1251,6 @@ def apply_1d_rope(tokens, pos, cos, sin):
return x
-class FluxPosEmbed(nn.Module):
- # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
- def __init__(self, theta: int, axes_dim: List[int]):
- super().__init__()
- self.theta = theta
- self.axes_dim = axes_dim
-
- def forward(self, ids: torch.Tensor) -> torch.Tensor:
- n_axes = ids.shape[-1]
- cos_out = []
- sin_out = []
- pos = ids.float()
- is_mps = ids.device.type == "mps"
- is_npu = ids.device.type == "npu"
- freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
- for i in range(n_axes):
- cos, sin = get_1d_rotary_pos_embed(
- self.axes_dim[i],
- pos[:, i],
- theta=self.theta,
- repeat_interleave_real=True,
- use_real=True,
- freqs_dtype=freqs_dtype,
- )
- cos_out.append(cos)
- sin_out.append(sin)
- freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
- freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
- return freqs_cos, freqs_sin
-
-
class TimestepEmbedding(nn.Module):
def __init__(
self,
@@ -2624,3 +2601,13 @@ def forward(self, image_embeds: List[torch.Tensor]):
projected_image_embeds.append(image_embed)
return projected_image_embeds
+
+
+class FluxPosEmbed(nn.Module):
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please import it from `diffusers.models.transformers.transformer_flux`."
+ deprecate("FluxPosEmbed", "1.0.0", deprecation_message)
+
+ from .transformers.transformer_flux import FluxPosEmbed
+
+ return FluxPosEmbed(*args, **kwargs)
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 405f1ddf845b..fb01e7e01a1e 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -610,6 +610,56 @@ def enable_group_offload(
offload_to_disk_path=offload_to_disk_path,
)
+ def set_attention_backend(self, backend: str) -> None:
+ """
+ Set the attention backend for the model.
+
+ Args:
+ backend (`str`):
+ The name of the backend to set. Must be one of the available backends defined in
+ `AttentionBackendName`. Available backends can be found in
+ `diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product
+ attention as backend.
+ """
+ from .attention import AttentionModuleMixin
+ from .attention_dispatch import AttentionBackendName
+
+ # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
+ from .attention_processor import Attention, MochiAttention
+
+ backend = backend.lower()
+ available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+ if backend not in available_backends:
+ raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+
+ backend = AttentionBackendName(backend)
+ attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+
+ for module in self.modules():
+ if not isinstance(module, attention_classes):
+ continue
+ processor = module.processor
+ if processor is None or not hasattr(processor, "_attention_backend"):
+ continue
+ processor._attention_backend = backend
+
+ def reset_attention_backend(self) -> None:
+ """
+ Resets the attention backend for the model. Following calls to `forward` will use the environment default or
+ the torch native scaled dot product attention.
+ """
+ from .attention import AttentionModuleMixin
+ from .attention_processor import Attention, MochiAttention
+
+ attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+ for module in self.modules():
+ if not isinstance(module, attention_classes):
+ continue
+ processor = module.processor
+ if processor is None or not hasattr(processor, "_attention_backend"):
+ continue
+ processor._attention_backend = None
+
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py
index 0f6dd677ac5c..5823ae9d3da6 100644
--- a/src/diffusers/models/transformers/transformer_chroma.py
+++ b/src/diffusers/models/transformers/transformer_chroma.py
@@ -24,19 +24,13 @@
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import (
- Attention,
- AttentionProcessor,
- FluxAttnProcessor2_0,
- FluxAttnProcessor2_0_NPU,
- FusedFluxAttnProcessor2_0,
-)
+from ..attention import AttentionMixin, FeedForward
from ..cache_utils import CacheMixin
from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
+from .transformer_flux import FluxAttention, FluxAttnProcessor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -223,6 +217,8 @@ def __init__(
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
if is_torch_npu_available():
+ from ..attention_processor import FluxAttnProcessor2_0_NPU
+
deprecation_message = (
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
"should be set explicitly using the `set_attn_processor` method."
@@ -230,17 +226,15 @@ def __init__(
deprecate("npu_processor", "0.34.0", deprecation_message)
processor = FluxAttnProcessor2_0_NPU()
else:
- processor = FluxAttnProcessor2_0()
+ processor = FluxAttnProcessor()
- self.attn = Attention(
+ self.attn = FluxAttention(
query_dim=dim,
- cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
- qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
@@ -292,17 +286,15 @@ def __init__(
self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)
- self.attn = Attention(
+ self.attn = FluxAttention(
query_dim=dim,
- cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
- processor=FluxAttnProcessor2_0(),
- qk_norm=qk_norm,
+ processor=FluxAttnProcessor(),
eps=eps,
)
@@ -376,7 +368,13 @@ def forward(
class ChromaTransformer2DModel(
- ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
+ ModelMixin,
+ ConfigMixin,
+ PeftAdapterMixin,
+ FromOriginalModelMixin,
+ FluxTransformer2DLoadersMixin,
+ CacheMixin,
+ AttentionMixin,
):
"""
The Transformer model introduced in Flux, modified for Chroma.
@@ -475,106 +473,6 @@ def __init__(
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
- def fuse_qkv_projections(self):
- """
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
- are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
-
- This API is 🧪 experimental.
-
-
- """
- self.original_attn_processors = None
-
- for _, attn_processor in self.attn_processors.items():
- if "Added" in str(attn_processor.__class__.__name__):
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
- self.original_attn_processors = self.attn_processors
-
- for module in self.modules():
- if isinstance(module, Attention):
- module.fuse_projections(fuse=True)
-
- self.set_attn_processor(FusedFluxAttnProcessor2_0())
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
- def unfuse_qkv_projections(self):
- """Disables the fused QKV projection if enabled.
-
-
-
- This API is 🧪 experimental.
-
-
-
- """
- if self.original_attn_processors is not None:
- self.set_attn_processor(self.original_attn_processors)
-
def forward(
self,
hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 608ea7f6e5f0..9080cd508de4 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -12,28 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from typing import Any, Dict, Optional, Tuple, Union
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
+import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import (
- Attention,
- AttentionProcessor,
- FluxAttnProcessor2_0,
- FluxAttnProcessor2_0_NPU,
- FusedFluxAttnProcessor2_0,
-)
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
-from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
+from ..embeddings import (
+ CombinedTimestepGuidanceTextProjEmbeddings,
+ CombinedTimestepTextProjEmbeddings,
+ apply_rotary_emb,
+ get_1d_rotary_pos_embed,
+)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -42,6 +42,307 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ encoder_query = encoder_key = encoder_value = None
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+
+ encoder_query = encoder_key = encoder_value = (None,)
+ if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+ encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+ if attn.fused_projections:
+ return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
+ return _get_projections(attn, hidden_states, encoder_hidden_states)
+
+
+class FluxAttnProcessor:
+ _attention_backend = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+
+ def __call__(
+ self,
+ attn: "FluxAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+ attn, hidden_states, encoder_hidden_states
+ )
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if attn.added_kv_proj_dim is not None:
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+ encoder_query = attn.norm_added_q(encoder_query)
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([encoder_query, query], dim=1)
+ key = torch.cat([encoder_key, key], dim=1)
+ value = torch.cat([encoder_value, value], dim=1)
+
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ hidden_states = dispatch_attention_fn(
+ query, key, value, attn_mask=attention_mask, backend=self._attention_backend
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ )
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class FluxIPAdapterAttnProcessor(torch.nn.Module):
+ """Flux Attention processor for IP-Adapter."""
+
+ _attention_backend = None
+
+ def __init__(
+ self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
+ ):
+ super().__init__()
+
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+
+ if not isinstance(num_tokens, (tuple, list)):
+ num_tokens = [num_tokens]
+
+ if not isinstance(scale, list):
+ scale = [scale] * len(num_tokens)
+ if len(scale) != len(num_tokens):
+ raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
+ self.scale = scale
+
+ self.to_k_ip = nn.ModuleList(
+ [
+ nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+ for _ in range(len(num_tokens))
+ ]
+ )
+ self.to_v_ip = nn.ModuleList(
+ [
+ nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+ for _ in range(len(num_tokens))
+ ]
+ )
+
+ def __call__(
+ self,
+ attn: "FluxAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ip_hidden_states: Optional[List[torch.Tensor]] = None,
+ ip_adapter_masks: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size = hidden_states.shape[0]
+
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+ attn, hidden_states, encoder_hidden_states
+ )
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+ ip_query = query
+
+ if encoder_hidden_states is not None:
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+ encoder_query = attn.norm_added_q(encoder_query)
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([encoder_query, query], dim=1)
+ key = torch.cat([encoder_key, key], dim=1)
+ value = torch.cat([encoder_value, value], dim=1)
+
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ )
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ # IP-adapter
+ ip_attn_output = torch.zeros_like(hidden_states)
+
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
+ ):
+ ip_key = to_k_ip(current_ip_hidden_states)
+ ip_value = to_v_ip(current_ip_hidden_states)
+
+ ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)
+
+ current_ip_hidden_states = dispatch_attention_fn(
+ ip_query,
+ ip_key,
+ ip_value,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ )
+ current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
+ current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
+ ip_attn_output += scale * current_ip_hidden_states
+
+ return hidden_states, encoder_hidden_states, ip_attn_output
+ else:
+ return hidden_states
+
+
+class FluxAttention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = FluxAttnProcessor
+ _available_processors = [
+ FluxAttnProcessor,
+ FluxIPAdapterAttnProcessor,
+ ]
+
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
+ out_bias: bool = True,
+ eps: float = 1e-5,
+ out_dim: int = None,
+ context_pre_only: Optional[bool] = None,
+ pre_only: bool = False,
+ elementwise_affine: bool = True,
+ processor=None,
+ ):
+ super().__init__()
+
+ self.head_dim = dim_head
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.query_dim = query_dim
+ self.use_bias = bias
+ self.dropout = dropout
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.context_pre_only = context_pre_only
+ self.pre_only = pre_only
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.added_proj_bias = added_proj_bias
+
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ if not self.pre_only:
+ self.to_out = torch.nn.ModuleList([])
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(torch.nn.Dropout(dropout))
+
+ if added_kv_proj_dim is not None:
+ self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
+
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
@@ -54,6 +355,8 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
if is_torch_npu_available():
+ from ..attention_processor import FluxAttnProcessor2_0_NPU
+
deprecation_message = (
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
"should be set explicitly using the `set_attn_processor` method."
@@ -61,17 +364,15 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
deprecate("npu_processor", "0.34.0", deprecation_message)
processor = FluxAttnProcessor2_0_NPU()
else:
- processor = FluxAttnProcessor2_0()
+ processor = FluxAttnProcessor()
- self.attn = Attention(
+ self.attn = FluxAttention(
query_dim=dim,
- cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
- qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
@@ -118,17 +419,15 @@ def __init__(
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
- self.attn = Attention(
+ self.attn = FluxAttention(
query_dim=dim,
- cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
- processor=FluxAttnProcessor2_0(),
- qk_norm=qk_norm,
+ processor=FluxAttnProcessor(),
eps=eps,
)
@@ -152,6 +451,7 @@ def forward(
encoder_hidden_states, emb=temb
)
joint_attention_kwargs = joint_attention_kwargs or {}
+
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
@@ -180,7 +480,6 @@ def forward(
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
-
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
@@ -195,8 +494,45 @@ def forward(
return encoder_hidden_states, hidden_states
+class FluxPosEmbed(nn.Module):
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+ def __init__(self, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ cos_out = []
+ sin_out = []
+ pos = ids.float()
+ is_mps = ids.device.type == "mps"
+ is_npu = ids.device.type == "npu"
+ freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+ for i in range(n_axes):
+ cos, sin = get_1d_rotary_pos_embed(
+ self.axes_dim[i],
+ pos[:, i],
+ theta=self.theta,
+ repeat_interleave_real=True,
+ use_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ cos_out.append(cos)
+ sin_out.append(sin)
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+ return freqs_cos, freqs_sin
+
+
class FluxTransformer2DModel(
- ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
+ ModelMixin,
+ ConfigMixin,
+ PeftAdapterMixin,
+ FromOriginalModelMixin,
+ FluxTransformer2DLoadersMixin,
+ CacheMixin,
+ AttentionMixin,
):
"""
The Transformer model introduced in Flux.
@@ -292,106 +628,6 @@ def __init__(
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
- def fuse_qkv_projections(self):
- """
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
- are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
-
- This API is 🧪 experimental.
-
-
- """
- self.original_attn_processors = None
-
- for _, attn_processor in self.attn_processors.items():
- if "Added" in str(attn_processor.__class__.__name__):
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
- self.original_attn_processors = self.attn_processors
-
- for module in self.modules():
- if isinstance(module, Attention):
- module.fuse_projections(fuse=True)
-
- self.set_attn_processor(FusedFluxAttnProcessor2_0())
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
- def unfuse_qkv_projections(self):
- """Disables the fused QKV projection if enabled.
-
-
-
- This API is 🧪 experimental.
-
-
-
- """
- if self.original_attn_processors is not None:
- self.set_attn_processor(self.original_attn_processors)
-
def forward(
self,
hidden_states: torch.Tensor,
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 2df05cb8eb36..cadcedb98a14 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -67,6 +67,9 @@
is_bitsandbytes_version,
is_bs4_available,
is_cosmos_guardrail_available,
+ is_flash_attn_3_available,
+ is_flash_attn_available,
+ is_flash_attn_version,
is_flax_available,
is_ftfy_available,
is_gguf_available,
@@ -90,6 +93,8 @@
is_peft_version,
is_pytorch_retinaface_available,
is_safetensors_available,
+ is_sageattention_available,
+ is_sageattention_version,
is_scipy_available,
is_sentencepiece_available,
is_tensorboard_available,
@@ -108,6 +113,7 @@
is_unidecode_available,
is_wandb_available,
is_xformers_available,
+ is_xformers_version,
requires_backends,
)
from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index 7c04287d33ed..f8f04cc03abd 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -41,6 +41,8 @@
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
DIFFUSERS_REQUEST_TIMEOUT = 60
+DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
+DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index c08ccf7ade27..901aec4b2205 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -258,6 +258,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class AttentionBackendName(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class AuraFlowTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -1368,6 +1383,10 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+def attention_backend(*args, **kwargs):
+ requires_backends(attention_backend, ["torch"])
+
+
class ComponentsManager(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index f12e9de33172..a27c2da648f4 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -220,6 +220,9 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
_nltk_available, _nltk_version = _is_package_available("nltk")
_cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
+_sageattention_available, _sageattention_version = _is_package_available("sageattention")
+_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
+_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
def is_torch_available():
@@ -378,6 +381,18 @@ def is_hpu_available():
return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))
+def is_sageattention_available():
+ return _sageattention_available
+
+
+def is_flash_attn_available():
+ return _flash_attn_available
+
+
+def is_flash_attn_3_available():
+ return _flash_attn_3_available
+
+
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -804,6 +819,51 @@ def is_optimum_quanto_version(operation: str, version: str):
return compare_versions(parse(_optimum_quanto_version), operation, version)
+def is_xformers_version(operation: str, version: str):
+ """
+ Compares the current xformers version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _xformers_available:
+ return False
+ return compare_versions(parse(_xformers_version), operation, version)
+
+
+def is_sageattention_version(operation: str, version: str):
+ """
+ Compares the current sageattention version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _sageattention_available:
+ return False
+ return compare_versions(parse(_sageattention_version), operation, version)
+
+
+def is_flash_attn_version(operation: str, version: str):
+ """
+ Compares the current flash-attention version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _flash_attn_available:
+ return False
+ return compare_versions(parse(_flash_attn_version), operation, version)
+
+
def get_objects_from_module(module):
"""
Returns a dict of object names and values in a module, while skipping private/internal objects
diff --git a/tests/pipelines/chroma/test_pipeline_chroma.py b/tests/pipelines/chroma/test_pipeline_chroma.py
index fc5749f96cd8..5121a2b52d75 100644
--- a/tests/pipelines/chroma/test_pipeline_chroma.py
+++ b/tests/pipelines/chroma/test_pipeline_chroma.py
@@ -7,12 +7,7 @@
from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import torch_device
-from ..test_pipelines_common import (
- FluxIPAdapterTesterMixin,
- PipelineTesterMixin,
- check_qkv_fusion_matches_attn_procs_length,
- check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist
class ChromaPipelineFastTests(
@@ -126,12 +121,10 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
- assert check_qkv_fusion_matches_attn_procs_length(
- pipe.transformer, pipe.transformer.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
diff --git a/tests/pipelines/chroma/test_pipeline_chroma_img2img.py b/tests/pipelines/chroma/test_pipeline_chroma_img2img.py
index 02b20527b2f9..d518e1b7b8d1 100644
--- a/tests/pipelines/chroma/test_pipeline_chroma_img2img.py
+++ b/tests/pipelines/chroma/test_pipeline_chroma_img2img.py
@@ -8,12 +8,7 @@
from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import floats_tensor, torch_device
-from ..test_pipelines_common import (
- FluxIPAdapterTesterMixin,
- PipelineTesterMixin,
- check_qkv_fusion_matches_attn_procs_length,
- check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist
class ChromaImg2ImgPipelineFastTests(
@@ -129,12 +124,10 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
- assert check_qkv_fusion_matches_attn_procs_length(
- pipe.transformer, pipe.transformer.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
index 8d63619c402b..ab4cf3273489 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
@@ -16,11 +16,7 @@
)
from diffusers.utils.torch_utils import randn_tensor
-from ..test_pipelines_common import (
- PipelineTesterMixin,
- check_qkv_fusion_matches_attn_procs_length,
- check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -170,12 +166,10 @@ def test_fused_qkv_projections(self):
original_image_slice = image[0, -3:, -3:, -1]
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
- assert check_qkv_fusion_matches_attn_procs_length(
- pipe.transformer, pipe.transformer.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py
index a848ec615e40..cc8266e1a54c 100644
--- a/tests/pipelines/flux/test_pipeline_flux.py
+++ b/tests/pipelines/flux/test_pipeline_flux.py
@@ -28,8 +28,7 @@
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
- check_qkv_fusion_matches_attn_procs_length,
- check_qkv_fusion_processors_exist,
+ check_qkv_fused_layers_exist,
)
@@ -171,12 +170,10 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
- assert check_qkv_fusion_matches_attn_procs_length(
- pipe.transformer, pipe.transformer.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
diff --git a/tests/pipelines/flux/test_pipeline_flux_control.py b/tests/pipelines/flux/test_pipeline_flux_control.py
index d8d0774e1e32..42283da6fd03 100644
--- a/tests/pipelines/flux/test_pipeline_flux_control.py
+++ b/tests/pipelines/flux/test_pipeline_flux_control.py
@@ -8,11 +8,7 @@
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import torch_device
-from ..test_pipelines_common import (
- PipelineTesterMixin,
- check_qkv_fusion_matches_attn_procs_length,
- check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -140,12 +136,10 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
- assert check_qkv_fusion_matches_attn_procs_length(
- pipe.transformer, pipe.transformer.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
diff --git a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
index a2f7c9171082..0abd08e37300 100644
--- a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
+++ b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
@@ -15,11 +15,7 @@
torch_device,
)
-from ..test_pipelines_common import (
- PipelineTesterMixin,
- check_qkv_fusion_matches_attn_procs_length,
- check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -134,12 +130,10 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
- assert check_qkv_fusion_matches_attn_procs_length(
- pipe.transformer, pipe.transformer.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 13c25ccaa469..387eb6a614f9 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -37,6 +37,7 @@
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
+from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import AttnProcessor
from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
@@ -98,6 +99,20 @@ def check_qkv_fusion_processors_exist(model):
return all(p.startswith("Fused") for p in proc_names)
+def check_qkv_fused_layers_exist(model, layer_names):
+ is_fused_submodules = []
+ for submodule in model.modules():
+ if not isinstance(submodule, AttentionModuleMixin):
+ continue
+ is_fused_attribute_set = submodule.fused_projections
+ is_fused_layer = True
+ for layer in layer_names:
+ is_fused_layer = is_fused_layer and getattr(submodule, layer, None) is not None
+ is_fused = is_fused_attribute_set and is_fused_layer
+ is_fused_submodules.append(is_fused)
+ return all(is_fused_submodules)
+
+
class SDFunctionTesterMixin:
"""
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.