Skip to content

Commit

Permalink
add the main contributions of the paper
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 17, 2022
1 parent c272866 commit 225f304
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 0 deletions.
85 changes: 85 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,91 @@ The gist of the paper comes down to, take a SOTA text-to-image model (here they

<a href="https://www.youtube.com/watch?v=AcvmyqGgMh8">AI Coffee Break explanation</a>

## Install

```bash
$ pip install make-a-video
```

## Usage

Passing in video features

```python
import torch
from make_a_video_pytorch import Pseudo3DConv, SpatioTemporalAttention

conv = Pseudo3DConv(
dim = 256,
kernel_size = 3
)

attn = SpatioTemporalAttention(
dim = 256,
dim_head = 64,
heads = 8
)

video = torch.randn(1, 256, 8, 16, 16) # (batch, features, frames, height, width)

conv_out = conv(video) # (1, 256, 8, 16, 16)
attn_out = attn(video) # (1, 256, 8, 16, 16)
```

Passing in images (if one were to pretrain on images first, both temporal convolution and attention will be automatically skipped)

```python
import torch
from make_a_video_pytorch import Pseudo3DConv, SpatioTemporalAttention

conv = Pseudo3DConv(
dim = 256,
kernel_size = 3
)

attn = SpatioTemporalAttention(
dim = 256,
dim_head = 64,
heads = 8
)

images = torch.randn(1, 256, 16, 16) # (batch, features, height, width)

conv_out = conv(images) # (1, 256, 16, 16)
attn_out = attn(images) # (1, 256, 16, 16)
```

You can also control the two modules so that when fed 3-dimensional features, it only does training spatially

```python
import torch
from make_a_video_pytorch import Pseudo3DConv, SpatioTemporalAttention

conv = Pseudo3DConv(
dim = 256,
kernel_size = 3
)

attn = SpatioTemporalAttention(
dim = 256,
dim_head = 64,
heads = 8
)

video = torch.randn(1, 256, 8, 16, 16) # (batch, features, frames, height, width)

# below it will not train across time

conv_out = conv(video, convolve_across_time = False) # (1, 256, 8, 16, 16)
attn_out = attn(video, attend_across_time = False) # (1, 256, 8, 16, 16)
```

## Todo

- [ ] wire up <a href="https://github.com/lucidrains/dalle2-pytorch">dalle2-pytorch</a> unet with pseudo 3d convs + spatial temporal attention
- [ ] give attention the best positional embeddings research has to offer
- [ ] soup up the attention

## Citations

```bibtex
Expand Down
1 change: 1 addition & 0 deletions make_a_video_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from make_a_video_pytorch.make_a_video import Pseudo3DConv, SpatioTemporalAttention
162 changes: 162 additions & 0 deletions make_a_video_pytorch/make_a_video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import torch
from torch import nn, einsum
from einops import rearrange

# helper functions

def exists(val):
return val is not None

def default(val, d):
return val if exists(val) else d

# layernorm 3d

class LayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.g = nn.Parameter(torch.ones(dim))

def forward(self, x):
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) * (var + eps).rsqrt() * self.g

# helper classes

class Attention(nn.Module):
def __init__(
self,
dim,
dim_head = 64,
heads = 8
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads

self.norm = LayerNorm(dim)

self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)

nn.init.zeros_(self.to_out.weight.data) # identity with skip connection

def forward(self, x):
x = self.norm(x)

q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

q = q * self.scale

sim = einsum('b h i d, b h j d -> b h i j', q, k)

attn = sim.softmax(dim = -1)

out = einsum('b h i j, b h j d -> b h i d', attn, v)

out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)

# main contribution - pseudo 3d conv

class Pseudo3DConv(nn.Module):
def __init__(
self,
dim,
*,
kernel_size,
dim_out = None,
temporal_kernel_size = None,
**kwargs
):
super().__init__()
dim_out = default(dim_out, dim)
temporal_kernel_size = default(temporal_kernel_size, kernel_size)

self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, padding = kernel_size // 2)
self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size, padding = temporal_kernel_size // 2)

nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity
nn.init.zeros_(self.temporal_conv.bias.data)

def forward(
self,
x,
convolve_across_time = True
):
b, c, *_, h, w = x.shape

is_video = x.ndim == 5
convolve_across_time &= is_video

if is_video:
x = rearrange(x, 'b c f h w -> (b f) c h w')

x = self.spatial_conv(x)

if is_video:
x = rearrange(x, '(b f) c h w -> b c f h w', b = b)

if not convolve_across_time:
return x

x = rearrange(x, 'b c f h w -> (b h w) c f')

x = self.temporal_conv(x)

x = rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w)

return x

# factorized spatial temporal attention from Ho et al.
# todo - take care of relative positional biases + rotary embeddings

class SpatioTemporalAttention(nn.Module):
def __init__(
self,
dim,
*,
dim_head = 64,
heads = 8
):
super().__init__()
self.spatial_attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
self.temporal_attn = Attention(dim = dim, dim_head = dim_head, heads = heads)

def forward(
self,
x,
attend_across_time = True
):
b, c, *_, h, w = x.shape
is_video = x.ndim == 5
attend_across_time &= is_video

if is_video:
x = rearrange(x, 'b c f h w -> (b f) (h w) c')
else:
x = rearrange(x, 'b c h w -> b (h w) c')

x = self.spatial_attn(x) + x

if is_video:
x = rearrange(x, '(b f) (h w) c -> b c f h w', b = b, h = h, w = w)
else:
x = rearrange(x, 'b (h w) c -> b c h w', h = h, w = w)

if not attend_across_time:
return x

x = rearrange(x, 'b c f h w -> (b h w) f c')

x = self.temporal_attn(x) + x

x = rearrange(x, '(b h w) f c -> b c f h w', w = w, h = h)

return x

0 comments on commit 225f304

Please sign in to comment.