diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py index 45b35edb..5d711ded 100644 --- a/dalle_pytorch/dalle_pytorch.py +++ b/dalle_pytorch/dalle_pytorch.py @@ -324,6 +324,7 @@ def __init__( attn_types = None, loss_img_weight = 7, stable = False, + sandwich_norm = False, shift_tokens = True, rotary_emb = True ): @@ -371,6 +372,7 @@ def __init__( image_fmap_size = image_fmap_size, sparse_attn = sparse_attn, stable = stable, + sandwich_norm = sandwich_norm, shift_tokens = shift_tokens, rotary_emb = rotary_emb ) diff --git a/dalle_pytorch/transformer.py b/dalle_pytorch/transformer.py index 2c11ed98..f592e542 100644 --- a/dalle_pytorch/transformer.py +++ b/dalle_pytorch/transformer.py @@ -56,13 +56,16 @@ def forward(self, x, **kwargs): # layer norm class PreNorm(nn.Module): - def __init__(self, dim, fn): + def __init__(self, dim, fn, sandwich = False): super().__init__() self.norm = nn.LayerNorm(dim) + self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity() self.fn = fn def forward(self, x, **kwargs): - return self.fn(self.norm(x), **kwargs) + x = self.norm(x) + x = self.fn(x, **kwargs) + return self.norm_out(x) # feed forward @@ -145,6 +148,7 @@ def __init__( image_fmap_size = None, sparse_attn = False, stable = False, + sandwich_norm = False, shift_tokens = False, rotary_emb = True ): @@ -183,8 +187,8 @@ def __init__( attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff)) layers.append(nn.ModuleList([ - LayerScale(dim, ind + 1, PreNorm(dim, attn)), - LayerScale(dim, ind + 1, PreNorm(dim, ff)) + LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)), + LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich = sandwich_norm)) ])) execute_type = ReversibleSequence if reversible else SequentialSequence diff --git a/setup.py b/setup.py index 992ea1c8..0ea3610c 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ name = 'dalle-pytorch', packages = find_packages(), include_package_data = True, - version = '1.0.8', + version = '1.1.0', license='MIT', description = 'DALL-E - Pytorch', author = 'Phil Wang',