
adjustments to doublestreamblock
DavidLandup0 committed Sep 30, 2024
1 parent 252a814 commit d6d626c
Showing 1 changed file with 27 additions and 53 deletions: keras_hub/src/models/flux/flux_layers.py
@@ -181,56 +181,36 @@ def call(self, x):
         )
 
 
 class DoubleStreamBlock(keras.Model):
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        mlp_ratio: float,
-        qkv_bias: bool = False,
-    ):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
         super().__init__()
 
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
         self.num_heads = num_heads
         self.hidden_size = hidden_size
 
         self.img_mod = Modulation(hidden_size, double=True)
-        self.img_norm1 = keras.layers.LayerNormalization(
-            elementwise_affine=False, eps=1e-6
-        )
-        self.img_attn = SelfAttention(
-            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
-        )
+        self.img_norm1 = keras.layers.LayerNormalization(epsilon=1e-6)
+        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
 
-        self.img_norm2 = keras.layers.LayerNormalization(
-            elementwise_affine=False, eps=1e-6
-        )
-        self.img_mlp = keras.Sequential(
-            [
-                keras.layers.Dense(mlp_hidden_dim, use_bias=True),
-                keras.layers.Activation("tanh"),
-                keras.layers.Dense(hidden_size, use_bias=True),
-            ]
-        )
+        self.img_norm2 = keras.layers.LayerNormalization(epsilon=1e-6)
+        self.img_mlp = keras.Sequential([
+            keras.layers.Dense(mlp_hidden_dim, use_bias=True),
+            keras.layers.Activation("tanh"),
+            keras.layers.Dense(hidden_size, use_bias=True)
+        ])
 
         self.txt_mod = Modulation(hidden_size, double=True)
-        self.txt_norm1 = keras.layers.LayerNormalization(
-            elementwise_affine=False, eps=1e-6
-        )
-        self.txt_attn = SelfAttention(
-            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
-        )
+        self.txt_norm1 = keras.layers.LayerNormalization(epsilon=1e-6)
+        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
 
-        self.txt_norm2 = keras.layers.LayerNormalization(
-            elementwise_affine=False, eps=1e-6
-        )
-        self.txt_mlp = keras.Sequential(
-            [
-                keras.layers.Dense(mlp_hidden_dim, use_bias=True),
-                keras.layers.Activation("tanh"),
-                keras.layers.Dense(hidden_size, use_bias=True),
-            ]
-        )
+        self.txt_norm2 = keras.layers.LayerNormalization(epsilon=1e-6)
+        self.txt_mlp = keras.Sequential([
+            keras.layers.Dense(mlp_hidden_dim, use_bias=True),
+            keras.layers.Activation("tanh"),
+            keras.layers.Dense(hidden_size, use_bias=True)
+        ])
 
     def call(self, img, txt, vec, pe):
         img_mod1, img_mod2 = self.img_mod(vec)
@@ -240,19 +220,15 @@ def call(self, img, txt, vec, pe):
         img_modulated = self.img_norm1(img)
         img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
         img_qkv = self.img_attn.qkv(img_modulated)
-        img_q, img_k, img_v = rearrange(
-            img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
-        )
-        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
+        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        img_q, img_k = self.img_attn.norm(img_q, img_k)
 
         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
         txt_qkv = self.txt_attn.qkv(txt_modulated)
-        txt_q, txt_k, txt_v = rearrange(
-            txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
-        )
-        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
+        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)
 
         # run actual attention
         q = keras.ops.concatenate((txt_q, img_q), axis=2)
@@ -264,13 +240,11 @@ def call(self, img, txt, vec, pe):
 
         # calculate the img bloks
         img = img + img_mod1.gate * self.img_attn.proj(img_attn)
-        img = img + img_mod2.gate * self.img_mlp(
-            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
-        )
+        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
 
         # calculate the txt bloks
         txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt = txt + txt_mod2.gate * self.txt_mlp(
-            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
-        )
+        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
         return img, txt

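The normalization changes in this diff replace the torch-style elementwise_affine/eps arguments with Keras' LayerNormalization, which takes epsilon and exposes its affine parameters through separate center/scale flags (both True by default). A small standalone sketch of that API for reference; it is not part of the commit.

# Keras LayerNormalization API, shown on a dummy tensor.
import numpy as np
import keras

x = np.random.randn(2, 8, 64).astype("float32")

# As constructed in the new code: learnable scale and offset enabled by default.
norm = keras.layers.LayerNormalization(epsilon=1e-6)
print(norm(x).shape)  # (2, 8, 64)

# A non-affine variant (the behavior of torch's elementwise_affine=False)
# would be expressed in Keras via the center/scale flags.
plain_norm = keras.layers.LayerNormalization(epsilon=1e-6, center=False, scale=False)
print(plain_norm(x).shape)  # (2, 8, 64)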
