Commit b68bd6b: automate more
lucidrains committed Apr 24, 2022 (1 parent: c633e63)

Showing 3 changed files with 19 additions and 4 deletions.
README.md (2 changes: 1 addition, 1 deletion)
@@ -174,9 +174,9 @@ If you have a better idea how this is done, just open a github issue.
 - [x] relative positional encodings in attention (space and time) - use T5 relative positional bias instead of what they used
 - [x] add a forward keyword argument that arrests attention across time (as reported / claimed in the paper, this type of image + video simultaneous training improves results)
 - [x] consider doing a 3d version of CLIP, so one can eventually apply the lessons of DALL-E2 to video https://github.com/lucidrains/dalle2-video
+- [x] offer way for Trainer to curtail or pad frames, if gif is too long
 - [ ] find a good torchvideo-like library (torchvideo seems immature) for training on fireworks
 - [ ] project text into 4-8 tokens, and use them as memory key / values to condition both time and space in attention blocks
-- [ ] offer way for Trainer to curtail or pad frames, if gif is too long
 - [ ] prepare a jax version for large scale TPU training
 - [ ] have Trainer take care of conditional video synthesis, with text offered as corresponding {video_filename}.txt within the same folder
 - [ ] see if ffcv or squirrel-core is a good fit
setup.py (2 changes: 1 addition, 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'video-diffusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.5',
+  version = '0.1.6',
   license='MIT',
   description = 'Video Diffusion - Pytorch',
   author = 'Phil Wang',
video_diffusion_pytorch/video_diffusion_pytorch.py (19 changes: 17 additions, 2 deletions)
@@ -659,7 +659,7 @@ def gif_to_tensor(path, channels = 3, transform = T.ToTensor()):
     tensors = tuple(map(transform, seek_all_images(img, channels = channels)))
     return torch.stack(tensors, dim = 1)
 
-def identity(t):
+def identity(t, *args, **kwargs):
     return t
 
 def normalize_img(t):
@@ -668,6 +668,17 @@ def normalize_img(t):
 def unnormalize_img(t):
     return (t + 1) * 0.5
 
+def cast_num_frames(t, *, frames):
+    f = t.shape[1]
+
+    if f == frames:
+        return t
+
+    if f > frames:
+        return t[:, :frames]
+
+    return F.pad(t, (0, 0, 0, 0, 0, frames - f))
+
 class Dataset(data.Dataset):
     def __init__(
         self,
@@ -676,6 +687,7 @@ def __init__(
         channels = 3,
         num_frames = 16,
         horizontal_flip = False,
+        force_num_frames = True,
         exts = ['gif']
     ):
         super().__init__()
@@ -684,6 +696,8 @@ def __init__(
         self.channels = channels
         self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
 
+        self.cast_num_frames_fn = partial(cast_num_frames, frames = num_frames) if force_num_frames else identity
+
         self.transform = T.Compose([
             T.Resize(image_size),
             T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity),
@@ -697,7 +711,8 @@ def __len__(self):

     def __getitem__(self, index):
         path = self.paths[index]
-        return gif_to_tensor(path, self.channels, transform = self.transform)
+        tensor = gif_to_tensor(path, self.channels, transform = self.transform)
+        return self.cast_num_frames_fn(tensor)
 
 # trainer class
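In effect, Dataset now yields fixed-length clips: cast_num_frames truncates any video tensor longer than num_frames and zero-pads any shorter one along the frame dimension (gif_to_tensor stacks frames on dim 1, so tensors are shaped (channels, frames, height, width)). A minimal sketch of the new behavior follows; the './gifs' folder is a hypothetical placeholder, and the import path assumes the module layout shown in this diff:

import torch
from video_diffusion_pytorch.video_diffusion_pytorch import Dataset, cast_num_frames

# a (channels, frames, height, width) tensor with 20 frames
video = torch.randn(3, 20, 64, 64)

# longer than requested: truncated to the first 10 frames
assert cast_num_frames(video, frames = 10).shape == (3, 10, 64, 64)

# shorter than requested: zero-padded at the end up to 32 frames
assert cast_num_frames(video, frames = 32).shape == (3, 32, 64, 64)

# Dataset applies the same cast to every item, so batches collate cleanly
dataset = Dataset('./gifs', image_size = 64, num_frames = 16)

# force_num_frames = False restores the old variable-length behavior
dataset = Dataset('./gifs', image_size = 64, num_frames = 16, force_num_frames = False)

The widened identity(t, *args, **kwargs) signature makes identity a general-purpose no-op that can stand in for functions called with extra arguments, which is what the `partial(...) if force_num_frames else identity` default relies on.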
