-
Notifications
You must be signed in to change notification settings - Fork 233
/
inference_multigpu.py
123 lines (100 loc) · 4.67 KB
/
inference_multigpu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import torch
import sys
import argparse
import random
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from diffusers.utils import export_to_video
from pyramid_dit import PyramidDiTForVideoGeneration
from trainer_misc import init_distributed_mode, init_sequence_parallel_group
import PIL
from PIL import Image
def get_args():
    """Parse the command-line options of the multi-GPU inference script.

    Returns:
        argparse.Namespace: the parsed options ``model_name``,
        ``model_dtype``, ``model_path``, ``variant``, ``task``, ``temp``,
        ``sp_group_size`` and ``sp_proc_num``.
    """
    # Automatic -h/--help is switched off (add_help=False).
    cli = argparse.ArgumentParser('Pytorch Multi-process Script', add_help=False)

    # One (flag, add_argument keyword arguments) entry per option, registered
    # in declaration order.
    option_table = (
        ('--model_name', {
            'default': 'pyramid_flux',
            'type': str,
            'help': "The model name",
            'choices': ["pyramid_flux", "pyramid_mmdit"],
        }),
        ('--model_dtype', {
            'default': 'bf16',
            'type': str,
            'help': "The Model Dtype: bf16",
        }),
        ('--model_path', {
            'default': '/home/jinyang06/models/pyramid-flow',
            'type': str,
            'help': 'Set it to the downloaded checkpoint dir',
        }),
        ('--variant', {
            'default': 'diffusion_transformer_768p',
            'type': str,
        }),
        ('--task', {
            'default': 't2v',
            'type': str,
            'choices': ['i2v', 't2v'],
        }),
        ('--temp', {
            'default': 16,
            'type': int,
            'help': 'The generated latent num, num_frames = temp * 8 + 1',
        }),
        ('--sp_group_size', {
            'default': 2,
            'type': int,
            'help': "The number of gpus used for inference, should be 2 or 4",
        }),
        ('--sp_proc_num', {
            'default': -1,
            'type': int,
            'help': "The number of process used for video training, default=-1 means using all process.",
        }),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()
def _resolve_torch_dtype(model_dtype):
    """Map the ``--model_dtype`` string to the torch dtype used for autocast.

    Args:
        model_dtype: the dtype name given on the command line.

    Returns:
        ``torch.bfloat16`` for ``"bf16"``, ``torch.float16`` for ``"fp16"``
        and ``torch.float32`` for every other value.
    """
    known = {"bf16": torch.bfloat16, "fp16": torch.float16}
    return known.get(model_dtype, torch.float32)


def _resolve_resolution(variant):
    """Return the ``(width, height)`` to generate for a model variant.

    Args:
        variant: the ``--variant`` string given on the command line.

    Returns:
        ``(1280, 768)`` for ``'diffusion_transformer_768p'`` and
        ``(640, 384)`` for ``'diffusion_transformer_384p'``.

    Raises:
        ValueError: if ``variant`` is neither of the two names above.
    """
    resolutions = {
        'diffusion_transformer_768p': (1280, 768),
        'diffusion_transformer_384p': (640, 384),
    }
    if variant not in resolutions:
        raise ValueError(
            f"Unsupported variant {variant!r}; expected one of {sorted(resolutions)}"
        )
    return resolutions[variant]


def main():
    """Run sequence-parallel text-to-video or image-to-video inference.

    Reads the command-line options, initialises the distributed and
    sequence-parallel groups, loads the model onto the GPU and generates one
    sample video for the selected ``--task``. Only rank 0 writes the
    resulting ``.mp4``; every rank waits on a barrier before returning.

    Raises:
        ValueError: if the world size differs from ``--sp_group_size``, or if
            ``--variant`` / ``--task`` is not a supported value.
    """
    args = get_args()

    # setup DDP
    # NOTE(review): args.world_size and args.rank are not defined by
    # get_args(); presumably init_distributed_mode() sets them - confirm.
    init_distributed_mode(args)

    # Explicit raise rather than `assert`: asserts are stripped under
    # `python -O`, which would let a mismatched launch go unnoticed.
    if args.world_size != args.sp_group_size:
        raise ValueError("The sequence parallel size should be DDP world size")

    # Enable sequence parallel
    init_sequence_parallel_group(args)

    # The video generation config. Resolved before the checkpoint is loaded
    # so that an unsupported variant fails immediately.
    width, height = _resolve_resolution(args.variant)

    device = torch.device('cuda')
    rank = args.rank
    model_dtype = args.model_dtype

    model = PyramidDiTForVideoGeneration(
        args.model_path,
        model_dtype,
        model_name=args.model_name,
        model_variant=args.variant,
    )

    model.vae.to(device)
    model.dit.to(device)
    model.text_encoder.to(device)
    model.vae.enable_tiling()

    torch_dtype = _resolve_torch_dtype(model_dtype)
    # float32 is not a CUDA autocast target, so autocast is enabled only for
    # the half-precision dtypes (this also covers unrecognised dtype names,
    # which resolve to float32).
    use_autocast = torch_dtype != torch.float32

    if args.task == 't2v':
        prompt = "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors"

        # torch.autocast("cuda", ...) replaces the deprecated
        # torch.cuda.amp.autocast(...).
        with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch_dtype, enabled=use_autocast):
            frames = model.generate(
                prompt=prompt,
                num_inference_steps=[20, 20, 20],
                video_num_inference_steps=[10, 10, 10],
                height=height,
                width=width,
                temp=args.temp,
                guidance_scale=7.0,         # The guidance for the first frame, set it to 7 for 384p variant
                video_guidance_scale=5.0,   # The guidance for the other video latent
                output_type="pil",
                save_memory=True,           # If you have enough GPU memory, set it to `False` to improve vae decoding speed
                cpu_offloading=False,       # If OOM, set it to True to reduce memory usage
                inference_multigpu=True,
            )

        # Every rank runs generation, but only rank 0 writes the file.
        if rank == 0:
            export_to_video(frames, "./text_to_video_sample.mp4", fps=24)

    elif args.task == 'i2v':
        image_path = 'assets/the_great_wall.jpg'
        # The conditioning image is resized to the generation resolution.
        image = Image.open(image_path).convert("RGB")
        image = image.resize((width, height))

        prompt = "FPV flying over the Great Wall"

        with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch_dtype, enabled=use_autocast):
            frames = model.generate_i2v(
                prompt=prompt,
                input_image=image,
                num_inference_steps=[10, 10, 10],
                temp=args.temp,
                video_guidance_scale=4.0,
                output_type="pil",
                save_memory=True,           # If you have enough GPU memory, set it to `False` to improve vae decoding speed
                cpu_offloading=False,       # If OOM, set it to True to reduce memory usage
                inference_multigpu=True,
            )

        # Every rank runs generation, but only rank 0 writes the file.
        if rank == 0:
            export_to_video(frames, "./image_to_video_sample.mp4", fps=24)

    else:
        # Unreachable through get_args() (argparse restricts --task), kept as
        # an explicit guard instead of an assert.
        raise ValueError(f"Unsupported task {args.task!r}; expected 't2v' or 'i2v'")

    # Keep all ranks alive until rank 0 has finished writing the video.
    torch.distributed.barrier()
# Script entry point: run inference only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()