# main.py — distributed training entry point (112 lines, 87 loc, 3.75 KB)
import torch
from model.Unet import Unet_ecnoder, Unet
from video_diffusion_pytorch.diffusion import GaussianDiffusion
from video_diffusion_pytorch.video_diffusion_pytorch import Trainer
from datasets.cityscape import data_load
import argparse
import os
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import random
from video_diffusion_pytorch.model_creation import create_gaussian_diffusion
import yaml
from types import SimpleNamespace
def arg_parse():
    """Parse the ``--config`` CLI flag and load the YAML config it points to.

    Returns:
        SimpleNamespace: one attribute per top-level key in the YAML file,
        so config values are accessed as ``args.key``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/cityscape.yml", help="config path")
    args = parser.parse_args()
    # Explicit encoding so the config decodes identically on every platform.
    with open(args.config, 'r', encoding='utf-8') as file:
        # safe_load: no arbitrary Python object construction from the config.
        params = yaml.safe_load(file)
    return SimpleNamespace(**params)
def main(rank, world_size, args):
    """Per-GPU worker entry point, launched once per rank by ``mp.spawn``.

    Builds the dataloaders, UNet encoder/denoiser, Gaussian diffusion wrapper
    and trainer, then either trains (optionally resuming) or runs testing.

    Args:
        rank: index of this process / GPU (supplied by ``mp.spawn``).
        world_size: total number of processes in the NCCL group.
        args: SimpleNamespace of config values loaded from the YAML config.
    """
    # Rendezvous address shared by every rank of the NCCL process group.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = args.port
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        # Seed every RNG source so runs are reproducible.
        init_seed = args.random_seed
        torch.manual_seed(init_seed)
        random.seed(init_seed)
        torch.cuda.manual_seed(init_seed)
        # NOTE(review): loaders are created with distributed=False — confirm
        # each rank is not iterating the full dataset redundantly.
        train_dataloader = data_load(args.data_root, stage='train', batch_size=args.batch_size, num_workers=args.dataloader_num_workers, frames_per_sample=14, distributed=False)
        val_dataloader = data_load(args.data_root, stage='val', batch_size=args.batch_size_val, num_workers=args.dataloader_num_workers, frames_per_sample=14, distributed=False)
        # Conditioning encoder and denoising UNet share the same resolution config.
        encoder = Unet_ecnoder(dim=args.dim,
                               dim_mults=args.dim_mults,
                               img_size=args.img_size,
                               n_head_channels=args.n_head_channels,
                               attn_resolutions=args.attn_resolutions
                               )
        model = Unet(dim=args.dim,
                     dim_mults=args.dim_mults,
                     img_size=args.img_size,
                     n_head_channels=args.n_head_channels,
                     attn_resolutions=args.attn_resolutions
                     )
        diffusion = create_gaussian_diffusion(
            steps=args.total_steps,               # total diffusion steps for training
            timestep_respacing=args.time_respacing,  # steps for sampling
            encoder=encoder,
            model=model,
            image_size=args.img_size,
            num_frames=8,
            rank=rank,
            loss_type='l1'                        # L1 or L2
        ).to(rank)
        # One DDP replica per GPU; gradients sync across the process group.
        diffusion = DDP(diffusion, device_ids=[rank])
        trainer = Trainer(
            rank=rank,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            diffusion_model=diffusion,
            train_batch_size=1,
            train_lr=args.train_lr,
            save_and_sample_every=args.save_steps,
            train_num_steps=args.train_steps,     # total training steps
            gradient_accumulate_every=2,          # gradient accumulation steps
            ema_decay=0.995,                      # exponential moving average decay
            amp=args.amp                          # turn on mixed precision
        )
        if args.train:
            # Optionally resume from the trainer's last checkpoint before training.
            if args.resume:
                trainer.load()
            trainer.train()
        else:
            # Evaluation path: load the requested checkpoint, then run the task
            # (longer clips: 30 frames per sample vs 14 for train/val).
            trainer.pth_transfer(args.check_points)
            test_dataloader = data_load(args.data_root, stage='test', batch_size=args.batch_size_test, num_workers=args.dataloader_num_workers, frames_per_sample=30, distributed=False)
            trainer.test(test_dataloader, flag=args.task)
    finally:
        # Pair with init_process_group above so this rank's NCCL resources are
        # released even if training/testing raises.
        dist.destroy_process_group()
if __name__ == "__main__":
    args = arg_parse()
    # mp.spawn prepends the rank: each worker is called as
    # main(rank, world_size=args.gpus, args); join=True blocks until all exit.
    mp.spawn(main, args=(args.gpus, args), nprocs=args.gpus, join=True)
    # Deliberately no dist.destroy_process_group() here: the parent process
    # never calls init_process_group, so destroying here would raise. Each
    # spawned rank is responsible for tearing down its own process group.