diffusion_trt.py
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
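# This script swaps a diffusion pipeline's backbone (transformer or UNet) with a
# TensorRT-compiled DeviceModel: it exports the (optionally ModelOpt-quantized)
# backbone to ONNX, builds or loads a TRT engine, and generates an image with it.
# Example invocation (flags are defined in main() below):
#   python diffusion_trt.py --model flux-dev --restore-from <quantized checkpoint>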
import argparse

import torch
from onnx_utils.export import (
    generate_dummy_inputs_and_dynamic_axes_and_shapes,
    get_io_shapes,
    remove_nesting,
    update_dynamic_axes,
)
from quantize import create_pipeline

import modelopt.torch.opt as mto
from modelopt.torch._deploy._runtime import RuntimeRegistry
from modelopt.torch._deploy._runtime.tensorrt.constants import SHA_256_HASH_LENGTH
from modelopt.torch._deploy._runtime.tensorrt.tensorrt_utils import prepend_hash_to_bytes
from modelopt.torch._deploy.device_model import DeviceModel
from modelopt.torch._deploy.utils import get_onnx_bytes_and_metadata


def generate_image(pipe, prompt, image_name):
    """Run the pipeline for 30 denoising steps with a fixed seed and save the resulting image."""
    seed = 42
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=30,
        generator=torch.Generator("cuda").manual_seed(seed),
    ).images[0]
    image.save(image_name)
    print(f"Image generated and saved as {image_name}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="flux-dev",
        choices=["sdxl-1.0", "sdxl-turbo", "sd3-medium", "flux-dev", "flux-schnell"],
    )
    parser.add_argument(
        "--override-model-path",
        type=str,
        default=None,
        help="Path to the model if not using default paths in MODEL_ID mapping.",
    )
    parser.add_argument(
        "--model-dtype",
        type=str,
        default="Half",
        choices=["Half", "BFloat16", "Float"],
        help="Precision used to load the model.",
    )
    parser.add_argument(
        "--restore-from", type=str, default=None, help="Path to the modelopt quantized checkpoint"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="a photo of an astronaut riding a horse on mars",
        help="Input text prompt for the model",
    )
    parser.add_argument(
        "--onnx-load-path", type=str, default="", help="Path to load the ONNX model"
    )
    parser.add_argument(
        "--trt-engine-load-path", type=str, default=None, help="Path to load the TRT engine"
    )
    parser.add_argument(
        "--dq-only", action="store_true", help="Converts the ONNX model to a dq_only model"
    )
    parser.add_argument(
        "--torch", action="store_true", help="Generate an image using the torch pipeline"
    )
    parser.add_argument("--save-image-as", type=str, default=None, help="Name of the image to save")
    args = parser.parse_args()

    image_name = args.save_image_as if args.save_image_as else f"{args.model}.png"

    pipe = create_pipeline(args.model, args.model_dtype, args.override_model_path)

    # Save the backbone of the pipeline; it is moved to the GPU later for export
    add_embedding = None
    backbone = None
    if hasattr(pipe, "transformer"):
        backbone = pipe.transformer
    elif hasattr(pipe, "unet"):
        backbone = pipe.unet
        add_embedding = backbone.add_embedding
    else:
        raise ValueError("Pipeline does not have a transformer or unet backbone")
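    # Optionally restore the ModelOpt quantization state into the backbone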
    if args.restore_from:
        mto.restore(backbone, args.restore_from)
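    # With --torch, run the (optionally restored) PyTorch backbone directly and exit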
    if args.torch:
        if hasattr(pipe, "transformer"):
            pipe.transformer = backbone
        elif hasattr(pipe, "unet"):
            pipe.unet = backbone
        pipe.to("cuda")
        generate_image(pipe, args.prompt, image_name)
        return

    backbone.to("cuda")
    # Generate dummy inputs for the backbone
    dummy_inputs, dynamic_axes, dynamic_shapes = generate_dummy_inputs_and_dynamic_axes_and_shapes(
        args.model, backbone
    )

    # Postprocess the dynamic axes to match the input and output names with DeviceModel
    if args.onnx_load_path == "":
        update_dynamic_axes(args.model, dynamic_axes)
    compilation_args = dynamic_shapes

    # We only need to remove the nesting for SDXL models as they contain the nested input
    # added_cond_kwargs, which are renamed by the DeviceModel
    ignore_nesting = False
    if args.onnx_load_path != "" and args.model in ["sdxl-1.0", "sdxl-turbo"]:
        remove_nesting(compilation_args)
        ignore_nesting = True
    # Define deployment configuration
    deployment = {
        "runtime": "TRT",
        "version": "10.3",
        "precision": "stronglyTyped",
        "onnx_opset": "17",
        "verbose": "false",
    }
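    # Look up the TensorRT runtime client that builds and runs the engine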
    client = RuntimeRegistry.get(deployment)

    # Export onnx model and get some required names from it
    onnx_bytes, metadata = get_onnx_bytes_and_metadata(
        model=backbone,
        dummy_input=dummy_inputs,
        onnx_load_path=args.onnx_load_path,
        dynamic_axes=dynamic_axes,
        onnx_opset=int(deployment["onnx_opset"]),
        remove_exported_model=False,
        dq_only=args.dq_only,
    )
    if not args.trt_engine_load_path:
        # Compile the TRT engine from the exported ONNX model
        compiled_model = client.ir_to_compiled(onnx_bytes, compilation_args)
        # Save the TRT engine for future use
        with open(f"{args.model}.plan", "wb") as f:
            # Strip the SHA-256 hash from the compiled model; it is only used to maintain state in the trt_client
            f.write(compiled_model[SHA_256_HASH_LENGTH:])
    else:
        with open(args.trt_engine_load_path, "rb") as f:
            compiled_model = f.read()
        # Prepend the SHA-256 hash to the compiled model; it is used to maintain state in the trt_client
        compiled_model = prepend_hash_to_bytes(compiled_model)
    # The output shapes will need to be specified for models with dynamic output dimensions
    device_model = DeviceModel(
        client,
        compiled_model,
        metadata,
        compilation_args,
        get_io_shapes(args.model, args.onnx_load_path, dynamic_shapes),
        ignore_nesting,
    )
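    # SDXL pipelines read the UNet's add_embedding when preparing the added conditioning,
    # so expose it on the compiled device model as well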
    if hasattr(pipe, "unet") and add_embedding:
        setattr(device_model, "add_embedding", add_embedding)

    # Move the backbone back to the CPU and swap the compiled device model into the pipeline
    backbone.to("cpu")
    if hasattr(pipe, "unet"):
        pipe.unet = device_model
    elif hasattr(pipe, "transformer"):
        pipe.transformer = device_model
    else:
        raise ValueError("Pipeline does not have a transformer or unet backbone")
    pipe.to("cuda")

    generate_image(pipe, args.prompt, image_name)
    print(f"Image generated using the {args.model} model and saved as {image_name}")
    print(f"Inference latency of the pipeline backbone: {device_model.get_latency()} ms")
if __name__ == "__main__":
    main()