From b6e0a17c5488b923d884272f7e46170352b0f0d5 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sat, 18 Mar 2023 21:26:31 -0700
Subject: [PATCH] make continuous positional bias calculations more efficient,
 while generalizing to any number of dimensions

---
 make_a_video_pytorch/make_a_video.py | 52 ++++++++++++++++++----------
 setup.py                             |  2 +-
 2 files changed, 35 insertions(+), 19 deletions(-)

diff --git a/make_a_video_pytorch/make_a_video.py b/make_a_video_pytorch/make_a_video.py
index 8d7d4cd..71d9bb3 100644
--- a/make_a_video_pytorch/make_a_video.py
+++ b/make_a_video_pytorch/make_a_video.py
@@ -3,6 +3,7 @@
 from operator import mul
 
 import torch
+import torch.nn.functional as F
 from torch import nn, einsum
 
 from einops import rearrange, repeat, pack, unpack
@@ -81,13 +82,10 @@ def __init__(
         dim,
         heads,
         num_dims = 1,
-        layers = 2,
-        log_dist = True,
-        cache_rel_pos = False
+        layers = 2
     ):
         super().__init__()
         self.num_dims = num_dims
-        self.log_dist = log_dist
 
         self.net = nn.ModuleList([])
         self.net.append(nn.Sequential(nn.Linear(self.num_dims, dim), nn.SiLU()))
@@ -97,9 +95,6 @@ def __init__(
 
         self.net.append(nn.Linear(dim, heads))
 
-        self.cache_rel_pos = cache_rel_pos
-        self.register_buffer('rel_pos', None, persistent = False)
-
     @property
     def device(self):
         return next(self.parameters()).device
@@ -107,23 +102,44 @@ def device(self):
     def forward(self, *dimensions):
         device = self.device
 
-        if not exists(self.rel_pos) or not self.cache_rel_pos:
-            positions = [torch.arange(d, device = device) for d in dimensions]
-            grid = torch.stack(torch.meshgrid(*positions, indexing = 'ij'))
-            grid = rearrange(grid, 'c ... -> (...) c')
-            rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
+        shape = torch.tensor(dimensions, device = device)
+        rel_pos_shape = 2 * shape - 1
+
+        # calculate strides
+
+        strides = torch.flip(rel_pos_shape, (0,)).cumprod(dim = -1)
+        strides = torch.flip(F.pad(strides, (1, -1), value = 1), (0,))
+
+        # get all positions and calculate all the relative distances
+
+        positions = [torch.arange(d, device = device) for d in dimensions]
+        grid = torch.stack(torch.meshgrid(*positions, indexing = 'ij'), dim = -1)
+        grid = rearrange(grid, '... c -> (...) c')
+        rel_dist = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
 
-            if self.log_dist:
-                rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
+        # get all relative positions across all dimensions
 
-            self.register_buffer('rel_pos', rel_pos, persistent = False)
+        rel_positions = [torch.arange(-d + 1, d, device = device) for d in dimensions]
+        rel_pos_grid = torch.stack(torch.meshgrid(*rel_positions, indexing = 'ij'), dim = -1)
+        rel_pos_grid = rearrange(rel_pos_grid, '... c -> (...) c')
 
-        rel_pos = self.rel_pos.float()
+        # mlp input
+
+        bias = rel_pos_grid.float()
 
         for layer in self.net:
-            rel_pos = layer(rel_pos)
+            bias = layer(bias)
+
+        # convert relative distances to indices of the bias
+
+        rel_dist += (shape - 1) # make sure all positive
+        rel_dist *= strides
+        rel_dist_indices = rel_dist.sum(dim = -1)
+
+        # now select the bias for each unique relative position combination
 
-        return rearrange(rel_pos, 'i j h -> h i j')
+        bias = bias[rel_dist_indices]
+        return rearrange(bias, 'i j h -> h i j')
 
 # helper classes
 
diff --git a/setup.py b/setup.py
index ff33156..d8e6bc8 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'make-a-video-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.9',
+  version = '0.0.10',
   license='MIT',
   description = 'Make-A-Video - Pytorch',
  author = 'Phil Wang',