10
10
import torch_tensorrt
11
11
from accelerate .hooks import remove_hook_from_module
12
12
from diffusers import FluxPipeline
13
- from diffusers .models .transformers .transformer_flux import FluxTransformer2DModel
14
- from torch_tensorrt .dynamo ._defaults import DEBUG_LOGGING_DIR
15
13
16
14
DEVICE = "cuda:0"
17
15
@@ -23,6 +21,7 @@ def compile_model(
23
21
]:
24
22
use_explicit_typing = False
25
23
if args .use_sdpa :
24
+ # The use_sdpa option currently does not work correctly with the Flux model, so we don't use it
26
25
# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
27
26
sys .path .append (os .path .join (os .path .dirname (__file__ ), "../dynamo" ))
28
27
import register_sdpa
@@ -55,13 +54,6 @@ def compile_model(
55
54
torch_dtype = torch .float16 ,
56
55
).to (torch .float16 )
57
56
58
- # # Use a small transformer for debugging
59
- # if args.debug:
60
- # pipe.transformer = FluxTransformer2DModel(
61
- # num_layers=1, num_single_layers=1, guidance_embeds=True
62
- # )
63
- # pipe.to(torch.float16)
64
-
65
57
if args .low_vram_mode :
66
58
pipe .enable_model_cpu_offload ()
67
59
else :
@@ -135,70 +127,11 @@ def forward_loop(mod):
135
127
pipe .enable_sequential_cpu_offload ()
136
128
remove_hook_from_module (pipe .transformer , recurse = True )
137
129
pipe .transformer .to (DEVICE )
138
- if args .use_dynamo :
139
- dummy_inputs = {
140
- "hidden_states" : torch .randn (
141
- (batch_size , 4096 , 64 ), dtype = torch .float16
142
- ).to (DEVICE ),
143
- "encoder_hidden_states" : torch .randn (
144
- (batch_size , 512 , 4096 ), dtype = torch .float16
145
- ).to (DEVICE ),
146
- "pooled_projections" : torch .randn (
147
- (batch_size , 768 ), dtype = torch .float16
148
- ).to (DEVICE ),
149
- "timestep" : torch .tensor ([1.0 ] * batch_size , dtype = torch .float16 ).to (
150
- DEVICE
151
- ),
152
- "txt_ids" : torch .randn ((512 , 3 ), dtype = torch .float16 ).to (DEVICE ),
153
- "img_ids" : torch .randn ((4096 , 3 ), dtype = torch .float16 ).to (DEVICE ),
154
- "guidance" : torch .tensor ([1.0 ] * batch_size , dtype = torch .float32 ).to (
155
- DEVICE
156
- ),
157
- "joint_attention_kwargs" : {},
158
- "return_dict" : False ,
159
- }
160
- from modelopt .torch .quantization .utils import export_torch_mode
161
-
162
- with export_torch_mode ():
163
- ep = torch .export .export (
164
- backbone ,
165
- args = (),
166
- kwargs = dummy_inputs ,
167
- dynamic_shapes = dynamic_shapes ,
168
- strict = False ,
169
- )
170
- if args .debug :
171
- with torch_tensorrt .dynamo .Debugger (
172
- "graphs" ,
173
- logging_dir = DEBUG_LOGGING_DIR ,
174
- # capture_fx_graph_after=["remove_num_users_is_0_nodes"],
175
- save_engine_profile = True ,
176
- profile_format = "trex" ,
177
- engine_builder_monitor = True ,
178
- ):
179
- trt_gm = torch_tensorrt .dynamo .compile (
180
- ep , inputs = dummy_inputs , ** settings
181
- )
182
- else :
183
- trt_gm = torch_tensorrt .dynamo .compile (ep , inputs = dummy_inputs , ** settings )
184
- pipe .transformer = trt_gm
185
- pipe .transformer .config = backbone .config
186
- else :
187
- if args .debug :
188
- with torch_tensorrt .dynamo .Debugger (
189
- "graphs" ,
190
- logging_dir = DEBUG_LOGGING_DIR ,
191
- capture_fx_graph_after = ["remove_num_users_is_0_nodes" ],
192
- save_engine_profile = True ,
193
- profile_format = "trex" ,
194
- engine_builder_monitor = True ,
195
- ):
196
- trt_gm = torch_tensorrt .MutableTorchTensorRTModule (backbone , ** settings )
197
- else :
198
- trt_gm = torch_tensorrt .MutableTorchTensorRTModule (backbone , ** settings )
199
- if dynamic_shapes :
200
- trt_gm .set_expected_dynamic_shape_range ((), dynamic_shapes )
201
- pipe .transformer = trt_gm
130
+
131
+ trt_gm = torch_tensorrt .MutableTorchTensorRTModule (backbone , ** settings )
132
+ if dynamic_shapes :
133
+ trt_gm .set_expected_dynamic_shape_range ((), dynamic_shapes )
134
+ pipe .transformer = trt_gm
202
135
seed = 42
203
136
image = pipe (
204
137
["Beach and Kids" ],
@@ -208,7 +141,7 @@ def forward_loop(mod):
208
141
generator = torch .Generator ("cuda" ).manual_seed (seed ),
209
142
).images
210
143
print (f"generated { len (image )} images" )
211
- image [0 ].save ("warmup1 .png" )
144
+ image [0 ].save ("beach_kids .png" )
212
145
213
146
torch .cuda .empty_cache ()
214
147
@@ -336,22 +269,11 @@ def main(args):
336
269
default = "fp16" ,
337
270
help = "Select the data type to use (fp4 or fp8 or int8 or fp16)" ,
338
271
)
339
- parser .add_argument (
340
- "--use_dynamo" ,
341
- action = "store_true" ,
342
- help = "Use dynamo compile" ,
343
- default = False ,
344
- )
345
272
parser .add_argument (
346
273
"--fp4_mha" ,
347
274
action = "store_true" ,
348
275
help = "Use NVFP4_FP8_MHA_CONFIG config instead of NVFP4_FP8_MHA_CONFIG" ,
349
276
)
350
- parser .add_argument (
351
- "--debug" ,
352
- action = "store_true" ,
353
- help = "Use debug mode" ,
354
- )
355
277
parser .add_argument (
356
278
"--low_vram_mode" ,
357
279
action = "store_true" ,
0 commit comments