From 7e0838e38dae6dcf55ceded259ee552d5bd5feb7 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 19 Mar 2023 10:14:35 -0700
Subject: [PATCH] follow spatiotemporal attention with a feedforward, and add
 the highly effective token shift along the time axis in the hidden layer

---
 README.md                            | 10 +++++
 make_a_video_pytorch/make_a_video.py | 55 ++++++++++++++++++++--------
 setup.py                             |  2 +-
 3 files changed, 50 insertions(+), 17 deletions(-)

diff --git a/README.md b/README.md
index 2e5ac7d..8c777dc 100644
--- a/README.md
+++ b/README.md
@@ -158,3 +158,13 @@ video_as_images_out = unet(video, enable_time = False)
     url = {https://openreview.net/forum?id=oapKSVM2bcj}
 }
 ```
+
+```bibtex
+@article{Dong2021AttentionIN,
+    title   = {Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth},
+    author  = {Yihe Dong and Jean-Baptiste Cordonnier and Andreas Loukas},
+    journal = {ArXiv},
+    year    = {2021},
+    volume  = {abs/2103.03404}
+}
+```

diff --git a/make_a_video_pytorch/make_a_video.py b/make_a_video_pytorch/make_a_video.py
index 71d9bb3..d09d6d6 100644
--- a/make_a_video_pytorch/make_a_video.py
+++ b/make_a_video_pytorch/make_a_video.py
@@ -58,18 +58,35 @@ def forward(self, x):
 
 # feedforward
 
+def shift_token(t):
+    t, t_shift = t.chunk(2, dim = 1)
+    t_shift = F.pad(t_shift, (0, 0, 0, 0, 1, -1), value = 0.)
+    return torch.cat((t, t_shift), dim = 1)
+
 class GEGLU(nn.Module):
     def forward(self, x):
-        x, gate = x.chunk(2, dim = -1)
+        x, gate = x.chunk(2, dim = 1)
         return x * F.gelu(gate)
 
-def FeedForward(dim, mult = 4):
-    inner_dim = int(dim * mult * 2 / 3)
-    return nn.Sequential(
-        nn.Linear(dim, inner_dim, bias = False),
-        GEGLU(),
-        nn.Linear(inner_dim, bias = False)
-    )
+class FeedForward(nn.Module):
+    def __init__(self, dim, mult = 4):
+        super().__init__()
+
+        inner_dim = int(dim * mult * 2 / 3)
+        self.proj_in = nn.Sequential(
+            nn.Conv3d(dim, inner_dim * 2, 1, bias = False),
+            GEGLU()
+        )
+
+        self.proj_out = nn.Conv3d(inner_dim, dim, 1, bias = False)
+
+    def forward(self, x, enable_time = True):
+        x = self.proj_in(x)
+
+        if enable_time:
+            x = shift_token(x)
+
+        return self.proj_out(x)
 
 # best relative positional encoding
 
@@ -242,7 +259,6 @@ def forward(
         return x
 
 # factorized spatial temporal attention from Ho et al.
-# todo - take care of relative positional biases + rotary embeddings
 
 class SpatioTemporalAttention(nn.Module):
     def __init__(
@@ -250,7 +266,9 @@ def __init__(
         dim,
         *,
         dim_head = 64,
-        heads = 8
+        heads = 8,
+        add_feed_forward = True,
+        ff_mult = 4
     ):
         super().__init__()
         self.spatial_attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
@@ -259,6 +277,11 @@ def __init__(
         self.temporal_attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
         self.temporal_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 1)
 
+        if not add_feed_forward:
+            return
+
+        self.ff = FeedForward(dim = dim, mult = ff_mult)
+
     def forward(
         self,
         x,
@@ -282,17 +305,17 @@ def forward(
         else:
             x = rearrange(x, 'b (h w) c -> b c h w', h = h, w = w)
 
-        if not enable_time:
-            return x
+        if enable_time:
 
-        x = rearrange(x, 'b c f h w -> (b h w) f c')
+            x = rearrange(x, 'b c f h w -> (b h w) f c')
 
-        time_rel_pos_bias = self.temporal_rel_pos_bias(x.shape[1])
+            time_rel_pos_bias = self.temporal_rel_pos_bias(x.shape[1])
 
-        x = self.temporal_attn(x, rel_pos_bias = time_rel_pos_bias) + x
+            x = self.temporal_attn(x, rel_pos_bias = time_rel_pos_bias) + x
 
-        x = rearrange(x, '(b h w) f c -> b c f h w', w = w, h = h)
+            x = rearrange(x, '(b h w) f c -> b c f h w', w = w, h = h)
 
+        x = self.ff(x, enable_time = enable_time) + x
         return x
 
 # resnet block
diff --git a/setup.py b/setup.py
index d8e6bc8..0986aa9 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'make-a-video-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.10',
+  version = '0.1.0',
  license='MIT',
  description = 'Make-A-Video - Pytorch',
  author = 'Phil Wang',
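
For context on what the patch adds: the new `FeedForward` halves its hidden channels and shifts one half forward by one frame (`shift_token`) before projecting back out, and `SpatioTemporalAttention` now applies this feedforward (with a residual) after the temporal attention. A minimal usage sketch, assuming the package is installed and `SpatioTemporalAttention` is exported at the package level as in the repository README; the tensor sizes are illustrative, not taken from the patch:

```python
import torch
from make_a_video_pytorch import SpatioTemporalAttention

# spatial then temporal attention, now followed by the token-shift feedforward
attn = SpatioTemporalAttention(
    dim = 256,
    dim_head = 64,
    heads = 8,
    add_feed_forward = True,  # new in this patch; set False to omit the feedforward
    ff_mult = 4               # new in this patch; hidden dim multiplier for the feedforward
)

video = torch.randn(1, 256, 8, 16, 16)  # (batch, channels, frames, height, width)

out = attn(video, enable_time = True)   # token shift is applied along the frame axis
assert out.shape == video.shape
```

With `enable_time = False` (frame-wise, image-only training), `FeedForward` skips the shift, so the same weights can be used on both images and videos.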