Commit e4e101f
add sandwich norm, from the CogView paper, for stabilizing training even more, hidden behind feature flag

lucidrains committed Oct 19, 2021
1 parent 15d2f35 commit e4e101f
Showing 3 changed files with 11 additions and 5 deletions.
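
For context: sandwich normalization (Sandwich-LN in the CogView paper) adds a second LayerNorm at the output of each residual sublayer, on top of the usual pre-norm at its input, which keeps activation scales bounded in deep or mixed-precision training. The flag defaults to False, so existing behavior is unchanged. A minimal sketch of opting in; the constructor values below are illustrative, in the style of the repository README:

from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 512,
    hidden_dim = 64
)

dalle = DALLE(
    dim = 512,
    vae = vae,                # any VAE conforming to the repo's interface
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 6,
    heads = 8,
    sandwich_norm = True      # new flag from this commit; defaults to False
)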
2 changes: 2 additions & 0 deletions dalle_pytorch/dalle_pytorch.py
@@ -324,6 +324,7 @@ def __init__(
         attn_types = None,
         loss_img_weight = 7,
         stable = False,
+        sandwich_norm = False,
         shift_tokens = True,
         rotary_emb = True
     ):
@@ -371,6 +372,7 @@ def __init__(
             image_fmap_size = image_fmap_size,
             sparse_attn = sparse_attn,
             stable = stable,
+            sandwich_norm = sandwich_norm,
             shift_tokens = shift_tokens,
             rotary_emb = rotary_emb
         )
12 changes: 8 additions & 4 deletions dalle_pytorch/transformer.py
@@ -56,13 +56,16 @@ def forward(self, x, **kwargs):
 # layer norm
 
 class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
+    def __init__(self, dim, fn, sandwich = False):
         super().__init__()
         self.norm = nn.LayerNorm(dim)
+        self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
         self.fn = fn
 
     def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
+        x = self.norm(x)
+        x = self.fn(x, **kwargs)
+        return self.norm_out(x)
 
 # feed forward
 
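In equation form, the change turns each wrapped sublayer from fn(LayerNorm(x)) into LayerNorm(fn(LayerNorm(x))); with the default sandwich = False, the nn.Identity() output norm preserves the original pre-norm behavior exactly. A self-contained sketch of the module as it reads after this commit, with a toy shape check:

import torch
from torch import nn

# PreNorm as defined in the diff above
class PreNorm(nn.Module):
    def __init__(self, dim, fn, sandwich = False):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        x = self.fn(x, **kwargs)
        return self.norm_out(x)

# toy check: sandwich = True applies LayerNorm -> fn -> LayerNorm
block = PreNorm(16, nn.Linear(16, 16), sandwich = True)
out = block(torch.randn(2, 4, 16))  # shape preserved: (2, 4, 16)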
@@ -145,6 +148,7 @@ def __init__(
         image_fmap_size = None,
         sparse_attn = False,
         stable = False,
+        sandwich_norm = False,
         shift_tokens = False,
         rotary_emb = True
     ):
@@ -183,8 +187,8 @@ def __init__(
                attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
 
            layers.append(nn.ModuleList([
-                LayerScale(dim, ind + 1, PreNorm(dim, attn)),
-                LayerScale(dim, ind + 1, PreNorm(dim, ff))
+                LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)),
+                LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich = sandwich_norm))
            ]))
 
        execute_type = ReversibleSequence if reversible else SequentialSequence
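Downstream, the wrapped blocks are applied residually by SequentialSequence (or ReversibleSequence), so with the flag enabled both the attention and feed-forward branches are normalized at entry and exit. A schematic of one layer using the PreNorm sketch above, with plain nn.Linear modules as hypothetical stand-ins for the repo's Attention and FeedForward:

import torch
from torch import nn

dim = 64

# hypothetical stand-ins; the real wrapped modules are Attention and FeedForward
attn = PreNorm(dim, nn.Linear(dim, dim), sandwich = True)
ff   = PreNorm(dim, nn.Linear(dim, dim), sandwich = True)

x = torch.randn(1, 32, dim)
x = x + attn(x)  # residual attention branch, norm on both ends
x = x + ff(x)    # residual feed-forward branch, norm on both ends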
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
   name = 'dalle-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '1.0.8',
+  version = '1.1.0',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
