forked from chengzeyi/stable-fast
-
Notifications
You must be signed in to change notification settings - Fork 0
/
optimize_lcm_pipeline.py
286 lines (252 loc) · 10.2 KB
/
optimize_lcm_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
# Default configuration. Each constant is the default for the matching
# command-line option defined in parse_args; None means "not set".

# --- Model loading -----------------------------------------------------------
MODEL = 'SimianLuo/LCM_Dreamshaper_v7'
VARIANT = None
CUSTOM_PIPELINE = 'latent_consistency_txt2img'
# NOTE(review): this scheduler replaces the pipeline's own one when the model
# is loaded; confirm it suits LCM's few-step sampling.
SCHEDULER = 'EulerAncestralDiscreteScheduler'
LORA = None
CONTROLNET = None

# --- Generation --------------------------------------------------------------
PROMPT = 'best quality, realistic, unreal engine, 4K, a beautiful girl'
NEGATIVE_PROMPT = None
STEPS = 4
SEED = None
BATCH = 1
HEIGHT = 768
WIDTH = 768
EXTRA_CALL_KWARGS = None  # JSON string of extra pipeline-call kwargs

# --- Benchmarking and image I/O ----------------------------------------------
WARMUPS = 3
INPUT_IMAGE = None
CONTROL_IMAGE = None
OUTPUT_IMAGE = None
import importlib
import inspect
import argparse
import time
import json
import torch
from PIL import (Image, ImageDraw)
from diffusers.utils import load_image
from sfast.compilers.diffusion_pipeline_compiler import (compile,
CompilationConfig)
def parse_args(argv=None):
    """Parse the benchmark's command-line options.

    Every option defaults to the module-level constant of the same name, so
    the script can be reconfigured either by editing those constants or from
    the command line.

    Args:
        argv: Argument list to parse, without the program name. ``None`` (the
            default) means ``sys.argv[1:]``, exactly like
            ``argparse.ArgumentParser.parse_args``.

    Returns:
        An ``argparse.Namespace`` with one attribute per option; dashes become
        underscores (e.g. ``--custom-pipeline`` -> ``custom_pipeline``).
    """
    parser = argparse.ArgumentParser(
        description='Benchmark a diffusers pipeline (LCM by default), '
        'optionally quantized and compiled.')
    # Model loading.
    parser.add_argument('--model', type=str, default=MODEL,
                        help='model ID or path passed to from_pretrained')
    parser.add_argument('--variant', type=str, default=VARIANT,
                        help='weights variant passed to from_pretrained')
    parser.add_argument('--custom-pipeline', type=str, default=CUSTOM_PIPELINE,
                        help='custom pipeline passed to from_pretrained')
    parser.add_argument('--scheduler', type=str, default=SCHEDULER,
                        help='name of a diffusers scheduler class that '
                        'replaces the pipeline scheduler')
    parser.add_argument('--lora', type=str, default=LORA,
                        help='LoRA weights to load and fuse into the model')
    parser.add_argument('--controlnet', type=str, default=CONTROLNET,
                        help='ControlNet model ID or path')
    # Generation.
    parser.add_argument('--steps', type=int, default=STEPS,
                        help='number of inference steps')
    parser.add_argument('--prompt', type=str, default=PROMPT,
                        help='text prompt')
    parser.add_argument('--negative-prompt', type=str,
                        default=NEGATIVE_PROMPT, help='negative text prompt')
    parser.add_argument('--seed', type=int, default=SEED,
                        help='random seed; omit for unseeded sampling')
    parser.add_argument('--warmups', type=int, default=WARMUPS,
                        help='number of untimed warmup calls')
    parser.add_argument('--batch', type=int, default=BATCH,
                        help='images generated per prompt')
    parser.add_argument('--height', type=int, default=HEIGHT,
                        help='output height in pixels')
    parser.add_argument('--width', type=int, default=WIDTH,
                        help='output width in pixels')
    parser.add_argument('--extra-call-kwargs',
                        type=str,
                        default=EXTRA_CALL_KWARGS,
                        help='JSON object of extra keyword arguments for the '
                        'pipeline call')
    # Image I/O.
    parser.add_argument('--input-image', type=str, default=INPUT_IMAGE,
                        help='init image; switches to an image-to-image '
                        'pipeline')
    parser.add_argument('--control-image', type=str, default=CONTROL_IMAGE,
                        help='ControlNet conditioning image (a placeholder '
                        'is drawn when omitted and --controlnet is set)')
    parser.add_argument('--output-image', type=str, default=OUTPUT_IMAGE,
                        help='path where the first generated image is saved')
    # Optimization.
    parser.add_argument(
        '--compiler',
        type=str,
        default='sfast',
        choices=['none', 'sfast', 'compile', 'compile-max-autotune'],
        help='optimization backend applied after loading')
    parser.add_argument('--quantize', action='store_true',
                        help='dynamically quantize Linear layers to int8')
    parser.add_argument('--no-fusion', action='store_true',
                        help='disable TorchScript fusion')
    return parser.parse_args(argv)
def load_model(pipeline_cls,
               model,
               variant=None,
               custom_pipeline=None,
               scheduler=None,
               lora=None,
               controlnet=None):
    """Build a float16 pipeline and move it to CUDA.

    Args:
        pipeline_cls: Class whose ``from_pretrained`` constructs the pipeline.
        model: Model ID or path handed to ``from_pretrained``.
        variant: Optional weights variant forwarded to ``from_pretrained``.
        custom_pipeline: Optional custom pipeline forwarded to
            ``from_pretrained``.
        scheduler: Optional name of a scheduler class exported by the
            ``diffusers`` package; it is rebuilt from the loaded pipeline's
            scheduler config and replaces the original scheduler.
        lora: Optional LoRA weights, loaded and then fused into the model.
        controlnet: Optional ControlNet ID or path, loaded in float16 and
            passed to ``from_pretrained``.

    Returns:
        The pipeline, with ``safety_checker`` set to None, after ``.to(cuda)``.
    """
    # Only forward the optional loader arguments that were actually given.
    candidates = {'custom_pipeline': custom_pipeline, 'variant': variant}
    loader_kwargs = {
        key: value
        for key, value in candidates.items() if value is not None
    }

    if controlnet is not None:
        from diffusers import ControlNetModel
        loader_kwargs['controlnet'] = ControlNetModel.from_pretrained(
            controlnet, torch_dtype=torch.float16)

    pipe = pipeline_cls.from_pretrained(model,
                                        torch_dtype=torch.float16,
                                        **loader_kwargs)

    if scheduler is not None:
        diffusers_module = importlib.import_module('diffusers')
        replacement_cls = getattr(diffusers_module, scheduler)
        pipe.scheduler = replacement_cls.from_config(pipe.scheduler.config)

    if lora is not None:
        pipe.load_lora_weights(lora)
        pipe.fuse_lora()

    pipe.safety_checker = None
    pipe.to(torch.device('cuda'))
    return pipe
def compile_model(model):
    """Compile ``model`` with stable-fast.

    Starts from ``CompilationConfig.Default()``. xformers and Triton are each
    enabled only if importing them succeeds; otherwise a message is printed
    and that backend is skipped. CUDA Graph capture is always enabled.

    Returns:
        The result of stable-fast's ``compile(model, config)``.
    """
    config = CompilationConfig.Default()

    # (module to probe, config flag to switch on, message when it is missing)
    # xformers and Triton are suggested for the best performance, but Triton
    # may take a long time to generate, compile and tune kernels, and it can
    # be slow when GPU VRAM is insufficient or the architecture is old --
    # disable it if you hit that problem.
    optional_backends = (
        ('xformers', 'enable_xformers', 'xformers not installed, skip'),
        ('triton', 'enable_triton', 'Triton not installed, skip'),
    )
    for module_name, flag_name, missing_message in optional_backends:
        try:
            __import__(module_name)
        except ImportError:
            print(missing_message)
        else:
            setattr(config, flag_name, True)

    # CUDA Graph reduces CPU overhead for small batches and resolutions. It
    # handles dynamic shapes at the cost of extra GPU memory; with little
    # VRAM, high resolutions, or Windows/WSL "shared VRAM" it can use memory
    # less efficiently and slow inference down -- disable it if that happens.
    config.enable_cuda_graph = True
    return compile(model, config)
class IterationProfiler:
    """Measure denoising-step throughput with CUDA timing events.

    Pass ``callback_on_step_end`` to a pipeline call. The first invocation
    records the start event. Every later invocation records a new end event
    and increments ``num_iterations``, so the counter equals the number of
    steps completed between the two recorded events.
    """

    def __init__(self):
        self.begin = None  # event recorded at the first step end
        self.end = None  # event recorded at the most recent step end
        self.num_iterations = 0  # steps elapsed between begin and end

    def get_iter_per_sec(self):
        """Return iterations per second, or None if fewer than two step ends
        were recorded.

        Blocks until the end event has completed before reading the elapsed
        time.
        """
        start, stop = self.begin, self.end
        if start is None or stop is None:
            return None
        stop.synchronize()
        elapsed_ms = start.elapsed_time(stop)
        return self.num_iterations / elapsed_ms * 1000.0

    def callback_on_step_end(self, pipe, i, t, callback_kwargs):
        """Record a timing event for this step and return ``callback_kwargs``
        unchanged."""
        marker = torch.cuda.Event(enable_timing=True)
        marker.record()
        if self.begin is None:
            self.begin = marker
        else:
            self.end = marker
            self.num_iterations += 1
        return callback_kwargs
def _load_resized_image(path, width, height):
    """Load ``path`` with diffusers' ``load_image`` and Lanczos-resize it to
    ``width`` x ``height``."""
    image = load_image(path)
    return image.resize((width, height), Image.LANCZOS)


def _placeholder_control_image(width, height):
    """Return a ``width`` x ``height`` RGB image with a white ellipse whose
    bounding box spans the middle half of each axis, on ``Image.new``'s
    default (black) background."""
    image = Image.new('RGB', (width, height))
    draw = ImageDraw.Draw(image)
    draw.ellipse((width // 4, height // 4, width // 4 * 3, height // 4 * 3),
                 fill=(255, 255, 255))
    return image


def _quantize_linear_layers(module):
    """Dynamically quantize every ``torch.nn.Linear`` in ``module`` to qint8,
    in place, and return the result.

    Raises:
        RuntimeError: if diffusers' ``USE_PEFT_BACKEND`` flag is False; this
            script only supports quantizing with the PEFT backend enabled.
    """
    from diffusers.utils import USE_PEFT_BACKEND
    # Explicit check rather than ``assert`` so it still fires under -O.
    if not USE_PEFT_BACKEND:
        raise RuntimeError('--quantize requires diffusers.utils.'
                           'USE_PEFT_BACKEND to be True')
    return torch.quantization.quantize_dynamic(module, {torch.nn.Linear},
                                               dtype=torch.qint8,
                                               inplace=True)


def _apply_compiler(model, compiler):
    """Optimize ``model`` with the backend named by ``--compiler``.

    ``'none'`` returns it untouched, ``'sfast'`` delegates to
    ``compile_model``, and ``'compile'`` / ``'compile-max-autotune'`` wrap the
    UNet, the ControlNet (if the pipeline has one) and the VAE with
    ``torch.compile``.

    Raises:
        ValueError: for any other backend name.
    """
    if compiler == 'none':
        return model
    if compiler == 'sfast':
        return compile_model(model)
    if compiler in ('compile', 'compile-max-autotune'):
        mode = 'max-autotune' if compiler == 'compile-max-autotune' else None
        model.unet = torch.compile(model.unet, mode=mode)
        if hasattr(model, 'controlnet'):
            model.controlnet = torch.compile(model.controlnet, mode=mode)
        model.vae = torch.compile(model.vae, mode=mode)
        return model
    raise ValueError(f'Unknown compiler: {compiler}')


def main():
    """Run the benchmark once.

    Loads the pipeline, then optionally quantizes and compiles it, and runs
    the warmup calls. It then times one pipeline call and prints the images
    to the terminal, along with wall time, iterations per second (when the
    pipeline accepts ``callback_on_step_end``) and peak CUDA memory.
    """
    args = parse_args()

    # Parse --extra-call-kwargs once, before the slow model load and
    # compilation, so malformed input fails immediately.
    if args.extra_call_kwargs is None:
        extra_call_kwargs = {}
    else:
        extra_call_kwargs = json.loads(args.extra_call_kwargs)
        if not isinstance(extra_call_kwargs, dict):
            raise ValueError('--extra-call-kwargs must be a JSON object, got '
                             f'{type(extra_call_kwargs).__name__}')

    if args.input_image is None:
        from diffusers import AutoPipelineForText2Image as pipeline_cls
    else:
        from diffusers import AutoPipelineForImage2Image as pipeline_cls

    model = load_model(
        pipeline_cls,
        args.model,
        variant=args.variant,
        custom_pipeline=args.custom_pipeline,
        scheduler=args.scheduler,
        lora=args.lora,
        controlnet=args.controlnet,
    )

    # A falsy size (None or 0) falls back to the UNet's native resolution.
    height = args.height or model.unet.config.sample_size * model.vae_scale_factor
    width = args.width or model.unet.config.sample_size * model.vae_scale_factor

    if args.quantize:
        model.unet = _quantize_linear_layers(model.unet)
        if hasattr(model, 'controlnet'):
            model.controlnet = _quantize_linear_layers(model.controlnet)

    if args.no_fusion:
        torch.jit.set_fusion_strategy([('STATIC', 0), ('DYNAMIC', 0)])

    model = _apply_compiler(model, args.compiler)

    if args.input_image is None:
        input_image = None
    else:
        input_image = _load_resized_image(args.input_image, width, height)

    if args.control_image is not None:
        control_image = _load_resized_image(args.control_image, width, height)
    elif args.controlnet is not None:
        control_image = _placeholder_control_image(width, height)
    else:
        control_image = None

    def get_kwarg_inputs():
        # Build fresh kwargs per call: with --seed each call gets a newly
        # seeded generator, so every run starts from the same random state.
        kwarg_inputs = dict(
            prompt=args.prompt,
            negative_prompt=args.negative_prompt,
            height=height,
            width=width,
            num_inference_steps=args.steps,
            num_images_per_prompt=args.batch,
            generator=None if args.seed is None else torch.Generator(
                device='cuda').manual_seed(args.seed),
        )
        # User extras override the defaults above instead of raising
        # TypeError on a duplicate keyword.
        kwarg_inputs.update(extra_call_kwargs)
        if input_image is not None:
            kwarg_inputs['image'] = input_image
        if control_image is not None:
            # With no init image the control image goes in ``image``;
            # otherwise it is passed separately as ``control_image``.
            if input_image is None:
                kwarg_inputs['image'] = control_image
            else:
                kwarg_inputs['control_image'] = control_image
        return kwarg_inputs

    # NOTE: Warm it up.
    # The initial calls will trigger compilation and might be very slow.
    # After that, it should be very fast.
    if args.warmups > 0:
        print('Begin warmup')
        for _ in range(args.warmups):
            model(**get_kwarg_inputs())
        print('End warmup')

    # Let's see it!
    # Note: Progress bar might work incorrectly due to the async nature of CUDA.
    kwarg_inputs = get_kwarg_inputs()
    iter_profiler = IterationProfiler()
    if 'callback_on_step_end' in inspect.signature(model).parameters:
        kwarg_inputs[
            'callback_on_step_end'] = iter_profiler.callback_on_step_end
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump during the measurement.
    begin = time.perf_counter()
    output_images = model(**kwarg_inputs).images
    end = time.perf_counter()

    # Let's view it in terminal!
    from sfast.utils.term_image import print_image

    for image in output_images:
        print_image(image, max_width=80)

    print(f'Inference time: {end - begin:.3f}s')
    iter_per_sec = iter_profiler.get_iter_per_sec()
    if iter_per_sec is not None:
        print(f'Iterations per second: {iter_per_sec:.3f}')
    peak_mem = torch.cuda.max_memory_allocated()
    print(f'Peak memory: {peak_mem / 1024**3:.3f}GiB')

    if args.output_image is not None:
        output_images[0].save(args.output_image)
if __name__ == '__main__':
    # Run the benchmark only when executed as a script, not when imported.
    main()