train.py
import os
import torch
import numpy as np
from tqdm import tqdm
from modules import Paella
from torch import nn, optim
from warmup_scheduler import GradualWarmupScheduler
from utils import get_dataloader, load_conditional_models
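
# Training hyperparameters and paths (dataset_path and vqmodel_path are left
# blank here and must be filled in before running).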
steps = 100_000
warmup_updates = 10000
batch_size = 16
checkpoint_frequency = 2000
lr = 1e-4
train_device = "cuda"
dataset_path = ""
byt5_model_name = "google/byt5-xl"
vqmodel_path = ""
run_name = "Paella-ByT5-XL-v1"
output_path = "output"
checkpoint_path = f"{run_name}.pt"


def train():
    os.makedirs(output_path, exist_ok=True)
    device = torch.device(train_device)
    dataloader = get_dataloader(dataset_path, batch_size=batch_size)
    checkpoint = torch.load(checkpoint_path, map_location=device) if os.path.exists(checkpoint_path) else None

    model = Paella(byt5_embd=2560).to(device)
    vqgan, (byt5_tokenizer, byt5) = load_conditional_models(byt5_model_name, vqmodel_path, device)

    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_updates)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1, reduction='none')
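
    # Restore model, optimizer, and scheduler state when resuming from a checkpoint.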
    start_iter = 1
    if checkpoint is not None:
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.last_epoch = checkpoint['scheduler_last_step']
        start_iter = checkpoint['scheduler_last_step'] + 1
        del checkpoint
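
    # The tqdm range drives the global step count from start_iter to steps; zipping
    # it with the dataloader assumes the dataloader yields at least `steps` batches
    # (e.g. a streaming/infinite iterator).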
    pbar = tqdm(range(start_iter, steps + 1))
    model.train()
    for i, (images, captions) in zip(pbar, dataloader):
        images = images.to(device)
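
        # Blank the caption for ~5% of batches so the model also learns the
        # unconditional distribution needed for classifier-free guidance.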
        with torch.no_grad():
            if np.random.rand() < 0.05:
                byt5_captions = [''] * len(captions)
            else:
                byt5_captions = captions
            byt5_tokens = byt5_tokenizer(byt5_captions, padding="longest", return_tensors="pt", max_length=768, truncation=True).input_ids.to(device)
            byt_embeddings = byt5(input_ids=byt5_tokens).last_hidden_state

            t = 1 - torch.rand(images.size(0), device=device)
            latents = vqgan.encode(images)[2]
            noised_latents, _ = model.add_noise(latents, t)
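
        # Predict the original token ids from the noised latents. The criterion
        # uses reduction='none', so the per-token loss is averaged before backward.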
        pred = model(noised_latents, t, byt_embeddings)
        loss = criterion(pred, latents).mean()

        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        acc = (pred.argmax(1) == latents).float().mean()
        pbar.set_postfix({'bs': images.size(0), 'loss': loss.item(), 'acc': acc.item(), 'grad_norm': grad_norm.item(), 'lr': optimizer.param_groups[0]['lr'], 'total_steps': scheduler.last_epoch})
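
        # Periodically persist the full training state so the run can resume from checkpoint_path.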
        if i % checkpoint_frequency == 0:
            torch.save({'state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_last_step': scheduler.last_epoch, 'iter': i}, checkpoint_path)


if __name__ == '__main__':
    train()