Skip to content

Commit

Permalink
give an initial conv
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 1, 2022
1 parent 84bf535 commit 86568e7
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions video_diffusion_pytorch/video_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
def exists(x):
return x is not None

def is_odd(n):
return (n % 2) == 1

def default(val, d):
if exists(val):
return val
Expand Down Expand Up @@ -296,12 +299,20 @@ def __init__(
dim_mults=(1, 2, 4, 8),
channels = 3,
attn_heads = 8,
use_bert_text_cond = False
use_bert_text_cond = False,
init_dim = None,
init_kernel_size = 7
):
super().__init__()
self.channels = channels

dims = [channels, *map(lambda m: dim * m, dim_mults)]
init_dim = default(init_dim, dim // 3 * 2)
assert is_odd(init_kernel_size)

init_padding = init_kernel_size // 2
self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size, init_kernel_size), padding = (0, init_padding, init_padding))

dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))

time_dim = dim * 4
Expand Down

0 comments on commit 86568e7

Please sign in to comment.