diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index 509cd873..4ad6ce63 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -8,18 +8,17 @@
 from ldm.modules.diffusionmodules.util import checkpoint
 
 
+# CrossAttn precision handling
+import os
+_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
 try:
     import xformers
     import xformers.ops
-    XFORMERS_IS_AVAILBLE = True
+    XFORMERS_IS_AVAILBLE = os.environ.get("ATTN_XFORMERS", "enabled") == "enabled"
 except:
     XFORMERS_IS_AVAILBLE = False
 
-# CrossAttn precision handling
-import os
-_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
-
 
 def exists(val):
     return val is not None
 
@@ -177,9 +176,9 @@ def forward(self, x, context=None, mask=None):
                 sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
         else:
             sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-        
+
         del q, k
-    
+
         if exists(mask):
             mask = rearrange(mask, 'b ... -> b (...)')
             max_neg_value = -torch.finfo(sim.dtype).max
@@ -211,7 +210,7 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
 
         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
-        self.attention_op: Optional[Any] = None
+        self.attention_op: Optional[Any] = xformers.ops.MemoryEfficientAttentionCutlassOp if hasattr(xformers.ops, "MemoryEfficientAttentionCutlassOp") else None
 
     def forward(self, x, context=None, mask=None):
         q = self.to_q(x)
diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py
index b089eebb..06c58bc2 100644
--- a/ldm/modules/diffusionmodules/model.py
+++ b/ldm/modules/diffusionmodules/model.py
@@ -9,9 +9,10 @@
 from ldm.modules.attention import MemoryEfficientCrossAttention
 
 try:
+    import os
     import xformers
     import xformers.ops
-    XFORMERS_IS_AVAILBLE = True
+    XFORMERS_IS_AVAILBLE = os.environ.get("ATTN_XFORMERS", "enabled") == "enabled"
 except:
     XFORMERS_IS_AVAILBLE = False
     print("No module 'xformers'. Proceeding without it.")
@@ -234,7 +235,7 @@ def __init__(self, in_channels):
                                         kernel_size=1,
                                         stride=1,
                                         padding=0)
-        self.attention_op: Optional[Any] = None
+        self.attention_op: Optional[Any] = xformers.ops.MemoryEfficientAttentionCutlassOp if hasattr(xformers.ops, "MemoryEfficientAttentionCutlassOp") else None
 
     def forward(self, x):
         h_ = x
@@ -288,8 +289,11 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     elif attn_type == "vanilla-xformers":
         print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
         return MemoryEfficientAttnBlock(in_channels)
-    elif type == "memory-efficient-cross-attn":
-        attn_kwargs["query_dim"] = in_channels
+    elif attn_type == "memory-efficient-cross-attn":
+        if attn_kwargs is None:
+            attn_kwargs = {"query_dim": in_channels}
+        else:
+            attn_kwargs["query_dim"] = in_channels
         return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
     elif attn_type == "none":
         return nn.Identity(in_channels)
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index 4edd5496..07dd7116 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -144,9 +144,8 @@ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda",
                  freeze=True, layer="last"):
         super().__init__()
         assert layer in self.LAYERS
-        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
-        del model.visual
-        self.model = model
+        self.model = open_clip.create_model(arch, device=torch.device('cpu'), pretrained=version)
+        del self.model.visual
 
         self.device = device
         self.max_length = max_length
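
Not part of the patch: a minimal usage sketch of the two environment variables this diff reads
(ATTN_PRECISION and ATTN_XFORMERS). Both are evaluated once at import time, so they must be set
before the ldm modules are imported; the values shown are illustrative.

    import os

    # Any value other than "fp32" skips the forced-fp32 softmax path in CrossAttention.forward.
    os.environ["ATTN_PRECISION"] = "fp32"

    # Any value other than "enabled" makes attention.py and diffusionmodules/model.py behave
    # as if xformers were not installed, even when it is importable.
    os.environ["ATTN_XFORMERS"] = "disabled"

    from ldm.modules import attention
    print(attention.XFORMERS_IS_AVAILBLE)  # False here, because ATTN_XFORMERS != "enabled"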