training code
Cakeyan committed Dec 10, 2024
1 parent 1a7512f commit a67f8c0
Showing 12 changed files with 2,067 additions and 20 deletions.
73 changes: 73 additions & 0 deletions README.md
@@ -10,6 +10,8 @@

## News 🔥

- [2024/12/10] 🚀 We release the training code for further training / fine-tuning!

- [2024/11/25] 🚀 [Allegro-TI2V](https://huggingface.co/rhymes-ai/Allegro-TI2V) is open sourced!

- [2024/10/30] 🚀 We release multi-card inference code and PAB in [Allegro-VideoSys](https://github.com/nightsnack/Allegro-VideoSys). With VideoSys framework, the inference time can be further reduced to 3 mins (8xH100) and 2 mins (8xH100+PAB). We also opened a PR to the original [VideoSys repo](https://github.com/NUS-HPC-AI-Lab/VideoSys).
@@ -160,6 +162,77 @@
### Multi-Card Inference
For both Allegro & Allegro TI2V: We release multi-card inference code and PAB in [Allegro-VideoSys](https://github.com/nightsnack/Allegro-VideoSys).

### Training / Fine-tuning

1. Download the [Allegro GitHub code](https://github.com/rhymes-ai/Allegro), [Allegro model weights](https://huggingface.co/rhymes-ai/Allegro) and prepare the environment in [requirements.txt](https://github.com/rhymes-ai/Allegro/blob/main/requirements.txt).

2. Our training code loads the dataset from `.parquet` files. We recommend first constructing a `.jsonl` file that stores all data cases in a single JSON list. Each case should be a dict with the video path, frame count, resolution, and caption, like this:

```json
[
{"path": "foo/bar.mp4", "num_frames": 123, "height": 1080, "width": 1920, "cap": "This is a fake caption."}
...
]
```

After that, run [dataset_utils.py](https://github.com/rhymes-ai/Allegro/blob/main/allegro/utils/dataset_utils.py) to convert `.jsonl` into `.parquet`.
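
Alternatively, a minimal conversion sketch with pandas (assuming the field names above and that the whole file holds one JSON list, as in the example; file names are illustrative) could look like:

```python
import json

import pandas as pd

# Minimal stand-in for dataset_utils.py; file names are illustrative.
with open("data.jsonl") as f:
    records = json.load(f)  # one JSON list holding every data case

df = pd.DataFrame(records)  # columns: path, num_frames, height, width, cap
df.to_parquet("data.parquet", index=False)  # requires pyarrow or fastparquet
```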

> The absolute path to each video is constructed by joining `args.data_dir` in [train.py](https://github.com/rhymes-ai/Allegro/blob/main/train.py) with the `path` value from the dataset. Therefore, you may define `path` as a relative path within your dataset and set `args.data_dir` to the dataset root directory when running training.

3. Run Training / Fine-tuning:

```bash
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1

export WANDB_API_KEY=YOUR_WANDB_KEY

accelerate launch \
--num_machines 1 \
--num_processes 8 \
--machine_rank 0 \
--config_file config/accelerate_config.yaml \
train.py \
--project_name Allegro_Finetune_88x720p \
--dit_config /huggingface/rhymes-ai/Allegro/transformer/config.json \
--dit /huggingface/rhymes-ai/Allegro/transformer/ \
--tokenizer /huggingface/rhymes-ai/Allegro/tokenizer \
--text_encoder /huggingface/rhymes-ai/Allegro/text_encoder \
--vae /huggingface/rhymes-ai/Allegro/vae \
--vae_load_mode encoder_only \
--enable_ae_compile \
--dataset t2v \
--data_dir /data_root/ \
--meta_file data.parquet \
--sample_rate 2 \
--num_frames 88 \
--max_height 720 \
--max_width 1280 \
--hw_thr 1.0 \
--hw_aspect_thr 1.5 \
--dataloader_num_workers 10 \
--gradient_checkpointing \
--train_batch_size 1 \
--gradient_accumulation_steps 1 \
--max_train_steps 1000000 \
--learning_rate 1e-4 \
--lr_scheduler constant \
--lr_warmup_steps 0 \
--mixed_precision bf16 \
--report_to wandb \
--allow_tf32 \
--enable_stable_fp32 \
--model_max_length 512 \
--cfg 0.1 \
--checkpointing_steps 100 \
--resume_from_checkpoint latest \
--output_dir ./output/Allegro_Finetune_88x720p
```
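
For reference, a minimal single-node `accelerate` config might look like the sketch below (illustrative only; the `config/accelerate_config.yaml` shipped with the repository is authoritative):

```yaml
# Illustrative single-node, 8-GPU accelerate config; the repository's
# config/accelerate_config.yaml is the authoritative version.
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
mixed_precision: bf16
num_machines: 1
num_processes: 8
```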

4. (Optional) To customize the model architecture, you may create a `.json` config file following the format of [config.json](https://huggingface.co/rhymes-ai/Allegro/blob/main/transformer/config.json). Feel free to use our training code to train a video diffusion model from scratch.
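
Purely as an illustration (these field names are hypothetical; the reference config.json defines the real schema), such a file might override a few architecture hyperparameters:

```json
{
    "num_layers": 32,
    "num_attention_heads": 24,
    "attention_head_dim": 96
}
```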



## Limitation
- The model cannot render celebrities, legible text, specific locations, streets or buildings.

18 changes: 18 additions & 0 deletions allegro/dataset/__init__.py
@@ -0,0 +1,18 @@
from torchvision import transforms
from transformers import T5Tokenizer

from allegro.dataset.allegro_datasets import Allegro_dataset
from allegro.dataset.transform import ToTensorVideo, TemporalRandomCrop, CenterCropResizeVideo

def getdataset(args):
    # Randomly crop a num_frames-long window from each clip along time.
    temporal_sample = TemporalRandomCrop(args.num_frames)
    # Map pixel values from [0, 1] to [-1, 1].
    norm_fun = transforms.Lambda(lambda x: 2. * x - 1.)
    if args.dataset == 't2v':
        transform = transforms.Compose([
            ToTensorVideo(),
            CenterCropResizeVideo((args.max_height, args.max_width)),
            norm_fun
        ])
        tokenizer = T5Tokenizer.from_pretrained(args.tokenizer, cache_dir=args.cache_dir)
        return Allegro_dataset(args, transform=transform, temporal_sample=temporal_sample, tokenizer=tokenizer)
    raise NotImplementedError(args.dataset)
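
For orientation, a minimal sketch of how `getdataset` might be wired into a `DataLoader` (the `args` fields here are assumptions mirroring train.py's CLI flags above):

```python
from torch.utils.data import DataLoader

from allegro.dataset import getdataset

# `args` is assumed to be the parsed train.py namespace, carrying the same
# fields as the CLI flags above (num_frames, max_height, tokenizer, ...).
dataset = getdataset(args)
loader = DataLoader(
    dataset,
    batch_size=args.train_batch_size,
    shuffle=True,
    num_workers=args.dataloader_num_workers,
)

batch = next(iter(loader))
batch['pixel_values']  # video/image tensor, (B, C, T, H, W) in [-1, 1]
batch['input_ids']     # padded T5 token ids for the caption
batch['cond_mask']     # attention mask matching input_ids
```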
177 changes: 177 additions & 0 deletions allegro/dataset/allegro_datasets.py
@@ -0,0 +1,177 @@
import os
import pickle
import random
import numpy as np
import pandas as pd
from einops import rearrange
from collections import OrderedDict
import hashlib
import json

import torch
from torch.utils.data import Dataset
from PIL import Image
from decord import VideoReader
from accelerate.logging import get_logger

from allegro.utils.utils import text_preprocessing, lprint

logger = get_logger(__name__)

def filter_resolution(height, width, max_height, max_width, hw_thr, hw_aspect_thr):
    # Keep a sample only if both sides reach the target size (scaled by hw_thr)
    # and its aspect ratio lies within a factor of hw_aspect_thr of the target.
    aspect = max_height / max_width
    if height >= max_height * hw_thr and width >= max_width * hw_thr and height / width >= aspect / hw_aspect_thr and height / width <= aspect * hw_aspect_thr:
        return True
    return False
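# Example: with max 720x1280, hw_thr=1.0 and hw_aspect_thr=1.5, a 1080x1920 clip
# passes (same 0.5625 aspect ratio, both sides larger), while a 720x720 clip
# fails: its ratio 1.0 falls outside the allowed aspect band [0.375, 0.84375].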

def filter_duration(num_frames, sample_frames, sample_rate):
    # A clip must hold enough source frames to draw sample_frames frames
    # at stride sample_rate.
    target_frames = (sample_frames - 1) * sample_rate + 1
    if num_frames >= target_frames:
        return True
    return False
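# Example: drawing 88 frames at stride 2 spans (88 - 1) * 2 + 1 = 175 source
# frames, so clips shorter than 175 frames are rejected for that stride.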

def random_sample_rate(num_frames, sample_frames, sample_rate):
    # Collect every stride the clip is long enough to support and pick one
    # uniformly at random; return None if none fit.
    supported_sample_rate = [sr for sr in sample_rate if filter_duration(num_frames, sample_frames, sr)]
    if supported_sample_rate:
        return random.choice(supported_sample_rate)
    return None

class Allegro_dataset(Dataset):
    def __init__(self, args, transform, temporal_sample, tokenizer):
        self.data_dir = args.data_dir
        self.meta_file = args.meta_file
        self.num_frames = args.num_frames
        self.sample_rate = sorted(list(map(int, args.sample_rate.split(','))))
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.tokenizer = tokenizer
        self.model_max_length = args.model_max_length
        self.cfg = args.cfg
        self.max_height = args.max_height
        self.max_width = args.max_width
        self.hw_thr = args.hw_thr
        self.hw_aspect_thr = args.hw_aspect_thr
        self.cache_dir = args.cache_dir

        self.filter_data_list()

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        try:
            data = self.data_list.loc[idx]
            if data['path'].endswith('.mp4'):
                return self.get_video(data)
            else:
                return self.get_image(data)
        except Exception as e:
            # `data` may be unbound if the index lookup itself failed, so log the index.
            logger.info(f"Error with {e}, sample index {idx}")
            # Retry with a random sample so one bad file does not stop training.
            return self.__getitem__(random.randint(0, self.__len__() - 1))

    def get_video(self, data):
        vr = VideoReader(os.path.join(self.data_dir, data['path']))
        sr = random_sample_rate(len(vr), self.num_frames, self.sample_rate)
        if sr is None:
            raise ValueError(f'no supported sample rate for num_frames ({len(vr)}), sample_frames ({self.num_frames}), sample_rate ({self.sample_rate})')
        fidx = np.arange(0, len(vr), sr).astype(int)
        # Randomly crop a num_frames-long window from the strided indices.
        sidx, eidx = self.temporal_sample(len(fidx))
        fidx = fidx[sidx: eidx]
        if self.num_frames != len(fidx):
            raise ValueError(f'num_frames ({self.num_frames}) does not equal frame_indices ({len(fidx)})')
        video = vr.get_batch(fidx).asnumpy()
        video = torch.from_numpy(video)
        video = video.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
        video = self.transform(video)
        video = video.transpose(0, 1)  # (T, C, H, W) -> (C, T, H, W)

        # Drop the caption with probability self.cfg for classifier-free guidance.
        text = text_preprocessing(data['cap']) if random.random() > self.cfg else ""
        text_tokens_and_mask = self.tokenizer(
            text,
            max_length=self.model_max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )
        input_ids = text_tokens_and_mask['input_ids']
        cond_mask = text_tokens_and_mask['attention_mask']

        return dict(pixel_values=video, input_ids=input_ids, cond_mask=cond_mask)

    def get_image(self, data):
        # Decode inside a context manager so the PIL file handle is closed
        # here; the name `image` is a tensor from this point on.
        with Image.open(os.path.join(self.data_dir, data['path'])) as img:
            image = torch.from_numpy(np.array(img.convert('RGB')))
        image = rearrange(image, 'h w c -> c h w').unsqueeze(0)  # (1, C, H, W)
        image = self.transform(image)
        image = image.transpose(0, 1)  # (C, 1, H, W)

        # Drop the caption with probability self.cfg for classifier-free guidance.
        text = text_preprocessing(data['cap']) if random.random() > self.cfg else ""
        text_tokens_and_mask = self.tokenizer(
            text,
            max_length=self.model_max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )
        input_ids = text_tokens_and_mask['input_ids']
        cond_mask = text_tokens_and_mask['attention_mask']

        return dict(pixel_values=image, input_ids=input_ids, cond_mask=cond_mask)

    def filter_data_list(self):
        # Filter the metadata by resolution and duration; the result is cached
        # on disk keyed by the filter parameters (see check_cache). The key
        # does not hash the parquet contents, so delete the cached .pkl under
        # cache_dir/data_cache if the metadata file changes in place.
        lprint(f'Filter data {self.meta_file}')
        cache_path = self.check_cache()
        if os.path.exists(cache_path):
            lprint(f'Load cache {cache_path}')
            with open(cache_path, 'rb') as f:
                self.data_list = pickle.load(f)
            lprint(f'Data length: {len(self.data_list)}')
            return

        self.data_list = pd.read_parquet(self.meta_file)
        pick_list = []
        for i in range(len(self.data_list)):
            data = self.data_list.loc[i]
            is_pick = filter_resolution(data['height'], data['width'], self.max_height, self.max_width, self.hw_thr, self.hw_aspect_thr)
            if data['path'].endswith('.mp4'):
                # Check duration against the smallest stride, the easiest to satisfy.
                is_pick = is_pick and filter_duration(data['num_frames'], self.num_frames, self.sample_rate[0])
            pick_list.append(is_pick)
            if i % 1000000 == 0:
                lprint(f'Filter {i}')
        self.data_list = self.data_list.loc[pick_list]
        self.data_list = self.data_list.reset_index(drop=True)
        lprint(f'Data length: {len(self.data_list)}')
        with open(cache_path, 'wb') as f:
            pickle.dump(self.data_list, f)
        lprint(f'Save cache {cache_path}')

    def check_cache(self):
        # Build a deterministic cache file name from the filter parameters.
        unique_identifiers = OrderedDict()
        unique_identifiers['class'] = type(self).__name__
        unique_identifiers['data_dir'] = self.data_dir
        unique_identifiers['meta_file'] = self.meta_file
        unique_identifiers['num_frames'] = self.num_frames
        unique_identifiers['sample_rate'] = self.sample_rate[0]
        unique_identifiers['hw_thr'] = self.hw_thr
        unique_identifiers['hw_aspect_thr'] = self.hw_aspect_thr
        unique_identifiers['max_height'] = self.max_height
        unique_identifiers['max_width'] = self.max_width
        unique_description = json.dumps(
            unique_identifiers, indent=4, default=lambda obj: obj.unique_identifiers
        )
        unique_description_hash = hashlib.md5(unique_description.encode('utf-8')).hexdigest()
        path_to_cache = os.path.join(self.cache_dir, 'data_cache')
        os.makedirs(path_to_cache, exist_ok=True)
        cache_path = os.path.join(
            path_to_cache, f'{unique_description_hash}-{type(self).__name__}-filter_cache.pkl'
        )
        return cache_path