Skip to content

Commit

Permalink
give an initial conv
Browse files Browse the repository at this point in the history
0.2.2
  • Loading branch information
lucidrains committed May 1, 2022
1 parent 84bf535 commit f3afff9
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'video-diffusion-pytorch',
packages = find_packages(exclude=[]),
version = '0.2.0',
version = '0.2.2',
license='MIT',
description = 'Video Diffusion - Pytorch',
author = 'Phil Wang',
Expand Down
17 changes: 15 additions & 2 deletions video_diffusion_pytorch/video_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
def exists(x):
return x is not None

def is_odd(n):
return (n % 2) == 1

def default(val, d):
if exists(val):
return val
Expand Down Expand Up @@ -296,12 +299,20 @@ def __init__(
dim_mults=(1, 2, 4, 8),
channels = 3,
attn_heads = 8,
use_bert_text_cond = False
use_bert_text_cond = False,
init_dim = None,
init_kernel_size = 7
):
super().__init__()
self.channels = channels

dims = [channels, *map(lambda m: dim * m, dim_mults)]
init_dim = default(init_dim, dim // 3 * 2)
assert is_odd(init_kernel_size)

init_padding = init_kernel_size // 2
self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size, init_kernel_size), padding = (0, init_padding, init_padding))

dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))

time_dim = dim * 4
Expand Down Expand Up @@ -394,6 +405,8 @@ def forward(
):
assert not (self.has_cond and not exists(cond)), 'cond must be passed in if cond_dim specified'

x = self.init_conv(x)

t = self.time_mlp(time) if exists(self.time_mlp) else None

# classifier free guidance
Expand Down

0 comments on commit f3afff9

Please sign in to comment.