Fixes and improvements when using xformers. #4

Open · wants to merge 4 commits into base: main
15 changes: 7 additions & 8 deletions ldm/modules/attention.py
@@ -8,18 +8,17 @@

 from ldm.modules.diffusionmodules.util import checkpoint

+# CrossAttn precision handling
+import os
+_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")

 try:
     import xformers
     import xformers.ops
-    XFORMERS_IS_AVAILBLE = True
+    XFORMERS_IS_AVAILBLE = os.environ.get("ATTN_XFORMERS", "enabled") == "enabled"
 except:
     XFORMERS_IS_AVAILBLE = False

-# CrossAttn precision handling
-import os
-_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")

 def exists(val):
     return val is not None
@@ -177,9 +176,9 @@ def forward(self, x, context=None, mask=None):
                 sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
         else:
             sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

         del q, k

         if exists(mask):
             mask = rearrange(mask, 'b ... -> b (...)')
             max_neg_value = -torch.finfo(sim.dtype).max
@@ -211,7 +210,7 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
-        self.attention_op: Optional[Any] = None
+        self.attention_op: Optional[Any] = xformers.ops.MemoryEfficientAttentionCutlassOp if hasattr(xformers.ops, "MemoryEfficientAttentionCutlassOp") else None

     def forward(self, x, context=None, mask=None):
         q = self.to_q(x)
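The default attention_op now prefers xformers' Cutlass memory-efficient kernel when the installed build exposes it; the hasattr guard keeps older xformers versions on the previous behaviour (None, which lets xformers auto-select an op). A standalone sketch of the same pattern, outside the class and assuming xformers imports successfully:

```python
from typing import Any, Optional

import xformers.ops

# Prefer the Cutlass kernel when this xformers build provides it; otherwise fall
# back to None so xformers.ops.memory_efficient_attention auto-selects an op.
attention_op: Optional[Any] = (
    xformers.ops.MemoryEfficientAttentionCutlassOp
    if hasattr(xformers.ops, "MemoryEfficientAttentionCutlassOp")
    else None
)
```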
12 changes: 8 additions & 4 deletions ldm/modules/diffusionmodules/model.py
@@ -9,9 +9,10 @@
 from ldm.modules.attention import MemoryEfficientCrossAttention

 try:
+    import os
     import xformers
     import xformers.ops
-    XFORMERS_IS_AVAILBLE = True
+    XFORMERS_IS_AVAILBLE = os.environ.get("ATTN_XFORMERS", "enabled") == "enabled"
 except:
     XFORMERS_IS_AVAILBLE = False
     print("No module 'xformers'. Proceeding without it.")
@@ -234,7 +235,7 @@ def __init__(self, in_channels):
                                         kernel_size=1,
                                         stride=1,
                                         padding=0)
-        self.attention_op: Optional[Any] = None
+        self.attention_op: Optional[Any] = xformers.ops.MemoryEfficientAttentionCutlassOp if hasattr(xformers.ops, "MemoryEfficientAttentionCutlassOp") else None

     def forward(self, x):
         h_ = x
@@ -288,8 +289,11 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
elif attn_type == "vanilla-xformers":
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
return MemoryEfficientAttnBlock(in_channels)
elif type == "memory-efficient-cross-attn":
attn_kwargs["query_dim"] = in_channels
elif attn_type == "memory-efficient-cross-attn":
if attn_kwargs is None:
attn_kwargs = {"query_dim": in_channels}
else:
attn_kwargs["query_dim"] = in_channels
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
elif attn_type == "none":
return nn.Identity(in_channels)
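Two fixes are bundled in this hunk: the branch previously tested the Python builtin type instead of attn_type (so it could never match), and it indexed into attn_kwargs even when the caller passed None. A minimal, self-contained illustration of the failure mode the new guard avoids (the helper name is illustrative, not from the diff):

```python
def set_query_dim(attn_kwargs, in_channels):
    """Illustrative helper mirroring the new make_attn branch."""
    # Old behaviour: attn_kwargs["query_dim"] = in_channels raised
    # TypeError: 'NoneType' object does not support item assignment
    # whenever make_attn was called without attn_kwargs.
    if attn_kwargs is None:
        attn_kwargs = {"query_dim": in_channels}
    else:
        attn_kwargs["query_dim"] = in_channels
    return attn_kwargs


print(set_query_dim(None, 512))           # {'query_dim': 512}
print(set_query_dim({"heads": 8}, 512))   # {'heads': 8, 'query_dim': 512}
```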
5 changes: 2 additions & 3 deletions ldm/modules/encoders/modules.py
@@ -144,9 +144,8 @@ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda",
freeze=True, layer="last"):
super().__init__()
assert layer in self.LAYERS
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
Review comment on this line:
Is this just a cosmetic change (to remove unused variables), or is this needed for any functional changes?
I'd avoid too much divergence from the Stability codebase, even if we don't really agree with the coding style. Ideally I would not have a custom fork at all.

@madrang (author), Dec 26, 2022:
create_model_and_transforms loads the model and also creates the image transforms used for img2img, but those transforms are never used here.
It's faster to call create_model and build only the part that is needed.
That change is already in a PR waiting to be merged: Stability-AI#52

-        del model.visual
-        self.model = model
+        self.model = open_clip.create_model(arch, device=torch.device('cpu'), pretrained=version)
+        del self.model.visual

         self.device = device
         self.max_length = max_length
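Since only the text encoder is kept here (the visual tower is deleted immediately), the image transforms built by create_model_and_transforms are wasted work, which is what the author's reply above points out. A before/after sketch of the two call styles (open_clip API as of late 2022; the arch and pretrained strings are taken from the diff):

```python
import torch
import open_clip

arch, version = "ViT-H-14", "laion2b_s32b_b79k"

# Before: create_model_and_transforms also returns train/val image transforms
# that this text-only encoder never uses.
model, _, _ = open_clip.create_model_and_transforms(
    arch, device=torch.device("cpu"), pretrained=version)

# After: load just the model, then drop the visual tower to free memory.
model = open_clip.create_model(arch, device=torch.device("cpu"), pretrained=version)
del model.visual
```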