Commit 989f0fc

remove convnext blocks, they do not work well, validated in video diffusion repository
lucidrains committed May 5, 2022
1 parent 84731bb commit 989f0fc
Showing 3 changed files with 3 additions and 54 deletions.
11 changes: 0 additions & 11 deletions README.md
@@ -99,14 +99,3 @@ Samples and model checkpoints will be logged to `./results` periodically
     note = {under review}
 }
 ```
-
-```bibtex
-@misc{liu2022convnet,
-    title = {A ConvNet for the 2020s},
-    author = {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
-    year = {2022},
-    eprint = {2201.03545},
-    archivePrefix = {arXiv},
-    primaryClass = {cs.CV}
-}
-```
44 changes: 2 additions & 42 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -142,39 +142,6 @@ def forward(self, x, time_emb = None):
         h = self.block2(h)
         return h + self.res_conv(x)
 
-class ConvNextBlock(nn.Module):
-    """ https://arxiv.org/abs/2201.03545 """
-
-    def __init__(self, dim, dim_out, *, time_emb_dim = None, mult = 2, norm = True):
-        super().__init__()
-        self.mlp = nn.Sequential(
-            nn.GELU(),
-            nn.Linear(time_emb_dim, dim)
-        ) if exists(time_emb_dim) else None
-
-        self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
-
-        self.net = nn.Sequential(
-            LayerNorm(dim) if norm else nn.Identity(),
-            nn.Conv2d(dim, dim_out * mult, 3, padding = 1),
-            nn.GELU(),
-            LayerNorm(dim_out * mult),
-            nn.Conv2d(dim_out * mult, dim_out, 3, padding = 1)
-        )
-
-        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
-
-    def forward(self, x, time_emb = None):
-        h = self.ds_conv(x)
-
-        if exists(self.mlp) and exists(time_emb):
-            assert exists(time_emb), 'time emb must be passed in'
-            condition = self.mlp(time_emb)
-            h = h + rearrange(condition, 'b c -> b c 1 1')
-
-        h = self.net(h)
-        return h + self.res_conv(x)
-
 class LinearAttention(nn.Module):
     def __init__(self, dim, heads = 4, dim_head = 32):
         super().__init__()
@@ -237,9 +204,7 @@ def __init__(
         dim_mults=(1, 2, 4, 8),
         channels = 3,
         with_time_emb = True,
-        use_convnext = False,
-        resnet_block_groups = 8,
-        convnext_mult = 2
+        resnet_block_groups = 8
     ):
         super().__init__()
 
@@ -253,12 +218,7 @@ def __init__(
         dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
 
-        # resnet or convnext
-
-        if use_convnext:
-            block_klass = partial(ConvNextBlock, mult = convnext_mult)
-        else:
-            block_klass = partial(ResnetBlock, groups = resnet_block_groups)
+        block_klass = partial(ResnetBlock, groups = resnet_block_groups)
 
         # time embeddings
 
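For orientation (not part of the commit): below is a minimal sketch of the `ResnetBlock` that is now the only block type, reconstructed from the context lines above (`h = self.block2(h)`, `return h + self.res_conv(x)`) and the `groups` keyword passed via `block_klass`. It is an approximation for illustration only; the repository's actual code (e.g. exactly where the time embedding is injected) is canonical.

```python
import torch
from torch import nn
from einops import rearrange

def exists(x):
    return x is not None

class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x):
        return self.act(self.norm(self.proj(x)))

class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        # time conditioning MLP, same pattern as in the deleted ConvNextBlock
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out)
        ) if exists(time_emb_dim) else None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):
        h = self.block1(x)

        if exists(self.mlp) and exists(time_emb):
            # broadcast the time embedding over the spatial dimensions
            h = h + rearrange(self.mlp(time_emb), 'b c -> b c 1 1')

        h = self.block2(h)
        return h + self.res_conv(x)

# quick smoke test
block = ResnetBlock(64, 128, time_emb_dim = 256)
x = torch.randn(2, 64, 32, 32)
t = torch.randn(2, 256)
assert block(x, t).shape == (2, 128, 32, 32)
```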
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'denoising-diffusion-pytorch',
   packages = find_packages(),
-  version = '0.11.2',
+  version = '0.12.0',
   license='MIT',
   description = 'Denoising Diffusion Probabilistic Models - Pytorch',
   author = 'Phil Wang',
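The minor-version bump (0.11.2 to 0.12.0) signals the breaking constructor change above. A hypothetical migration for downstream call sites, assuming the usual top-level import from this package:

```python
from denoising_diffusion_pytorch import Unet

# Pre-0.12.0 call sites that used the removed kwargs now fail with
# TypeError: __init__() got an unexpected keyword argument 'use_convnext'
# model = Unet(dim = 64, use_convnext = True, convnext_mult = 2)

# 0.12.0 onward: ResnetBlock is the only block type; only its knob remains
model = Unet(dim = 64, dim_mults = (1, 2, 4, 8), resnet_block_groups = 8)
```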
