
Commit 4a9edd9

fix: remove flash attention for now
flavioschneider committed Jan 19, 2023
1 parent 76876ee commit 4a9edd9
Showing 2 changed files with 20 additions and 6 deletions.
17 changes: 13 additions & 4 deletions a_unet/blocks.py
@@ -3,8 +3,6 @@
 
 import torch
 import torch.nn.functional as F
-import xformers
-import xformers.ops
 from einops import pack, rearrange, reduce, repeat, unpack
 from torch import Tensor, einsum, nn
 from typing_extensions import TypeGuard
@@ -230,18 +228,28 @@ def ConvNextV2Block(dim: int, channels: int) -> nn.Module:
 
 
 def AttentionBase(features: int, head_features: int, num_heads: int) -> nn.Module:
+    scale = head_features**-0.5
     mid_features = head_features * num_heads
     to_out = nn.Linear(in_features=mid_features, out_features=features, bias=False)
 
     def forward(
         q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
     ) -> Tensor:
-        # Use memory efficient attention
-        out = xformers.ops.memory_efficient_attention(q, k, v)
+        h = num_heads
+        # Split heads
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+        # Compute similarity matrix and add eventual mask
+        sim = einsum("... n d, ... m d -> ... n m", q, k) * scale
+        # Get attention matrix with softmax
+        attn = sim.softmax(dim=-1)
+        # Compute values
+        out = einsum("... n m, ... m d -> ... n d", attn, v)
+        out = rearrange(out, "b h n d -> b n (h d)")
         return to_out(out)
 
     return Module([to_out], forward)
 
+
 def LinearAttentionBase(features: int, head_features: int, num_heads: int) -> nn.Module:
     scale = head_features**-0.5
     mid_features = head_features * num_heads
@@ -262,6 +270,7 @@ def forward(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
 
     return Module([to_out], forward)
 
+
 def FixedEmbedding(max_length: int, features: int):
     embedding = nn.Embedding(max_length, features)
 
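For context, the change above swaps xformers' fused kernel for a plain einsum implementation of scaled dot-product attention. Below is a minimal standalone sketch, not part of the commit (the names qh/kh/vh and the test dimensions are illustrative assumptions), that mirrors the new forward pass and checks it against PyTorch's built-in F.scaled_dot_product_attention, which plays the same role as the removed xformers.ops.memory_efficient_attention call:

# Hypothetical sketch, not repository code: verify the explicit einsum
# attention against PyTorch's fused kernel (requires torch >= 2.0).
import torch
import torch.nn.functional as F
from einops import rearrange

b, n, h, d = 2, 16, 4, 32  # batch, tokens, heads, head features (assumed)
scale = d**-0.5
q, k, v = (torch.randn(b, n, h * d) for _ in range(3))

# Explicit path, mirroring the new AttentionBase.forward
qh, kh, vh = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in (q, k, v))
sim = torch.einsum("... n d, ... m d -> ... n m", qh, kh) * scale
attn = sim.softmax(dim=-1)
out = torch.einsum("... n m, ... m d -> ... n d", attn, vh)
out = rearrange(out, "b h n d -> b n (h d)")

# Fused reference; defaults to the same d**-0.5 scaling
ref = F.scaled_dot_product_attention(qh, kh, vh)
ref = rearrange(ref, "b h n d -> b n (h d)")
assert torch.allclose(out, ref, atol=1e-4)

The two paths agree numerically; the practical difference is that the explicit version materializes the full n-by-n similarity matrix, which fused kernels avoid by tiling, so memory use grows quadratically with sequence length here.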
9 changes: 7 additions & 2 deletions setup.py
@@ -3,15 +3,20 @@
 setup(
     name="a-unet",
     packages=find_packages(exclude=[]),
-    version="0.0.13",
+    version="0.0.14",
     license="MIT",
     description="A-UNet",
     long_description_content_type="text/markdown",
     author="Flavio Schneider",
     author_email="archinetai@protonmail.com",
     url="https://github.com/archinetai/a-unet",
     keywords=["artificial intelligence", "deep learning"],
-    install_requires=["torch>=1.6", "data-science-types>=0.2", "einops>=0.6.0", "xformers>=0.0.13"],
+    install_requires=[
+        "torch>=1.6",
+        "data-science-types>=0.2",
+        "einops>=0.6.0",
+        "xformers>=0.0.13",
+    ],
     classifiers=[
         "Development Status :: 4 - Beta",
         "Intended Audience :: Developers",
