Commit
now accepts timestep conditioning, for ddpm
lucidrains committed Dec 10, 2022
1 parent b4730fe commit 8de344e
Showing 2 changed files with 63 additions and 20 deletions.
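In practice, the change means the U-Net can now be constructed with `condition_on_timestep = True` and called with a DDPM timestep alongside the video tensor. A minimal usage sketch follows; the class name `SpaceTimeUnet`, the input channel count, and the tensor shapes are assumptions based on the repository's public API and are not spelled out in this diff.

```python
import torch
from make_a_video_pytorch import SpaceTimeUnet   # class name assumed from the repo's public API

unet = SpaceTimeUnet(
    dim = 64,
    dim_mult = (1, 2, 4, 8),
    self_attns = (False, False, False, True),
    temporal_compression = (False, True, True, True),
    condition_on_timestep = True                  # new flag introduced by this commit
)

video = torch.randn(1, 3, 4, 32, 32)              # (batch, channels, frames, height, width) - shapes assumed
timestep = torch.tensor([500.])                   # ddpm noise level; must be a float tensor (SinusoidalPosEmb asserts this)

pred = unet(video, timestep = timestep, enable_time = True)   # e.g. the noise estimate in ddpm training (interpretation assumed)
```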
81 changes: 62 additions & 19 deletions make_a_video_pytorch/make_a_video.py
@@ -1,6 +1,8 @@
- import torch
+ import math
+ import functools
+ from operator import mul

import torch
from torch import nn, einsum

from einops import rearrange, pack, unpack
@@ -22,6 +24,24 @@ def divisible_by(numer, denom):

mlist = nn.ModuleList

+ # for time conditioning
+
+ class SinusoidalPosEmb(nn.Module):
+     def __init__(self, dim, theta = 10000):
+         super().__init__()
+         self.theta = theta
+         self.dim = dim
+
+     def forward(self, x):
+         dtype, device = x.dtype, x.device
+         assert dtype == torch.float, 'input to sinusoidal pos emb must be a float type'
+
+         half_dim = self.dim // 2
+         emb = math.log(self.theta) / (half_dim - 1)
+         emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
+         emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
+         return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)
+
# layernorm 3d

class LayerNorm(nn.Module):
@@ -226,17 +246,17 @@ def __init__(
dim,
dim_out,
*,
- time_cond_dim = None,
+ timestep_cond_dim = None,
groups = 8
):
super().__init__()

- self.time_mlp = None
+ self.timestep_mlp = None

- if exists(time_cond_dim):
- self.time_mlp = nn.Sequential(
+ if exists(timestep_cond_dim):
+ self.timestep_mlp = nn.Sequential(
nn.SiLU(),
- nn.Linear(time_cond_dim, dim_out * 2)
+ nn.Linear(timestep_cond_dim, dim_out * 2)
)

self.block1 = Block(dim, dim_out, groups = groups)
@@ -246,14 +266,17 @@ def __init__(
def forward(
self,
x,
- time_emb = None,
+ timestep_emb = None,
enable_time = True
):
+ assert not (exists(timestep_emb) ^ exists(self.timestep_mlp))

scale_shift = None
- if exists(self.time_mlp) and exists(time_emb):
- time_emb = self.time_mlp(time_emb)
- time_emb = rearrange(time_emb, 'b c -> b c 1 1')
+ if exists(self.timestep_mlp) and exists(timestep_emb):
+ time_emb = self.timestep_mlp(timestep_emb)
+ to_einsum_eq = 'b c 1 1 1' if x.ndim == 5 else 'b c 1 1'
+ time_emb = rearrange(time_emb, f'b c -> {to_einsum_eq}')
scale_shift = time_emb.chunk(2, dim = 1)

h = self.block1(x, scale_shift = scale_shift, enable_time = enable_time)
@@ -367,7 +390,8 @@ def __init__(
self_attns = (False, False, False, True),
temporal_compression = (False, True, True, True),
attn_dim_head = 64,
- attn_heads = 8
+ attn_heads = 8,
+ condition_on_timestep = True
):
super().__init__()
assert len(dim_mult) == len(self_attns) == len(temporal_compression)
@@ -376,6 +400,20 @@ def __init__(
dims = [dim, *map(lambda mult: mult * dim, dim_mult)]
dim_in_out = zip(dims[:-1], dims[1:])

+ # timestep conditioning for DDPM, not to be confused with the time dimension of the video
+
+ self.to_timestep_cond = None
+ timestep_cond_dim = (dim * 4) if condition_on_timestep else None
+
+ if condition_on_timestep:
+     self.to_timestep_cond = nn.Sequential(
+         SinusoidalPosEmb(dim),
+         nn.Linear(dim, timestep_cond_dim),
+         nn.SiLU()
+     )
+
+ # layers
+
self.downs = mlist([])
self.ups = mlist([])

@@ -386,21 +424,21 @@ def __init__(

mid_dim = dims[-1]

- self.mid_block1 = ResnetBlock(mid_dim, mid_dim)
+ self.mid_block1 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim = timestep_cond_dim)
self.mid_attn = SpatioTemporalAttention(dim = mid_dim)
- self.mid_block2 = ResnetBlock(mid_dim, mid_dim)
+ self.mid_block2 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim = timestep_cond_dim)

for _, self_attend, (dim_in, dim_out), compress_time in zip(range(num_layers), self_attns, dim_in_out, temporal_compression):

self.downs.append(mlist([
- ResnetBlock(dim_in, dim_out),
+ ResnetBlock(dim_in, dim_out, timestep_cond_dim = timestep_cond_dim),
ResnetBlock(dim_out, dim_out),
SpatioTemporalAttention(dim = dim_out, **attn_kwargs) if self_attend else None,
Downsample(dim_out, downsample_time = compress_time)
]))

self.ups.append(mlist([
- ResnetBlock(dim_out * 2, dim_in),
+ ResnetBlock(dim_out * 2, dim_in, timestep_cond_dim = timestep_cond_dim),
ResnetBlock(dim_in, dim_in),
SpatioTemporalAttention(dim = dim_in, **attn_kwargs) if self_attend else None,
Upsample(dim_out, upsample_time = compress_time)
@@ -415,14 +453,19 @@ def __init__(
def forward(
self,
x,
+ timestep = None,
enable_time = True
):
+ assert not (exists(self.to_timestep_cond) ^ exists(timestep))
+
+ t = self.to_timestep_cond(rearrange(timestep, '... -> (...)')) if exists(timestep) else None

x = self.conv_in(x, enable_time = enable_time)

hiddens = []

for block1, block2, maybe_attention, downsample in self.downs:
- x = block1(x, enable_time = enable_time)
+ x = block1(x, t, enable_time = enable_time)
x = block2(x, enable_time = enable_time)

if exists(maybe_attention):
@@ -432,15 +475,15 @@ def forward(

x = downsample(x, enable_time = enable_time)

- x = self.mid_block1(x, enable_time = enable_time)
+ x = self.mid_block1(x, t, enable_time = enable_time)
x = self.mid_attn(x, enable_time = enable_time)
- x = self.mid_block2(x, enable_time = enable_time)
+ x = self.mid_block2(x, t, enable_time = enable_time)

for block1, block2, maybe_attention, upsample in reversed(self.ups):
x = upsample(x, enable_time = enable_time)
x = torch.cat((hiddens.pop() * self.skip_scale, x), dim = 1)

- x = block1(x, enable_time = enable_time)
+ x = block1(x, t, enable_time = enable_time)
x = block2(x, enable_time = enable_time)

if exists(maybe_attention):
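Taken together, the new path works as follows: the scalar DDPM timestep is embedded sinusoidally by SinusoidalPosEmb, projected to a dim * 4 conditioning vector by to_timestep_cond, and each ResnetBlock's timestep_mlp maps that vector to a per-channel scale and shift for the block's features. Below is a minimal sketch of that data flow; the final x * (scale + 1) + shift form is an assumption in the usual FiLM style, since the Block internals are not part of this diff.

```python
import math
import torch
from torch import nn
from einops import rearrange

dim, dim_out, batch = 64, 128, 2

# 1. sinusoidal embedding of the ddpm timestep (mirrors SinusoidalPosEmb above)
timestep = torch.tensor([10., 500.])                        # one float timestep per batch element
half_dim = dim // 2
freqs = torch.exp(torch.arange(half_dim, dtype = torch.float) * -(math.log(10000) / (half_dim - 1)))
emb = rearrange(timestep, 'i -> i 1') * rearrange(freqs, 'j -> 1 j')
emb = torch.cat((emb.sin(), emb.cos()), dim = -1)           # (batch, dim)

# 2. to_timestep_cond projects to the conditioning dimension (dim * 4 in this commit)
to_timestep_cond = nn.Sequential(nn.Linear(dim, dim * 4), nn.SiLU())
t = to_timestep_cond(emb)                                   # (batch, dim * 4)

# 3. each ResnetBlock's timestep_mlp turns t into a per-channel scale and shift
timestep_mlp = nn.Sequential(nn.SiLU(), nn.Linear(dim * 4, dim_out * 2))
time_emb = rearrange(timestep_mlp(t), 'b c -> b c 1 1 1')   # video features are (b, c, frames, h, w)
scale, shift = time_emb.chunk(2, dim = 1)

features = torch.randn(batch, dim_out, 4, 16, 16)
conditioned = features * (scale + 1) + shift                # assumed FiLM-style use inside Block
```

Note the rearrange to 'b c 1 1 1', which lets the conditioning broadcast over the video's frame and spatial dimensions; the image case only needs 'b c 1 1', which is exactly the branch handled by to_einsum_eq in the diff above.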
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'make-a-video-pytorch',
packages = find_packages(exclude=[]),
- version = '0.0.4',
+ version = '0.0.5',
license='MIT',
description = 'Make-A-Video - Pytorch',
author = 'Phil Wang',
