Commit 3b61879

remove profile, debug flag
1 parent 6dc056d commit 3b61879

4 files changed: +12, -431 lines changed

examples/apps/README.md

Lines changed: 5 additions & 0 deletions
````diff
@@ -23,6 +23,11 @@ python flux_demo.py
 
 ### Using Different Precision Modes
 
+- FP4 mode:
+```bash
+python flux_demo.py --dtype fp4
+```
+
 - FP8 mode:
 ```bash
 python flux_demo.py --dtype fp8
````

examples/apps/flux_demo.py

Lines changed: 7 additions & 85 deletions
```diff
@@ -10,8 +10,6 @@
 import torch_tensorrt
 from accelerate.hooks import remove_hook_from_module
 from diffusers import FluxPipeline
-from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
-from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR
 
 DEVICE = "cuda:0"
 
@@ -23,6 +21,7 @@ def compile_model(
 ]:
     use_explicit_typing = False
     if args.use_sdpa:
+        # currently use sdpa is not working correctly with flux model, so we don't use it
         # Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
         sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo"))
         import register_sdpa
@@ -55,13 +54,6 @@ def compile_model(
         torch_dtype=torch.float16,
     ).to(torch.float16)
 
-    # # Use a small transformer for debugging
-    # if args.debug:
-    #     pipe.transformer = FluxTransformer2DModel(
-    #         num_layers=1, num_single_layers=1, guidance_embeds=True
-    #     )
-    #     pipe.to(torch.float16)
-
     if args.low_vram_mode:
         pipe.enable_model_cpu_offload()
     else:
@@ -135,70 +127,11 @@ def forward_loop(mod):
         pipe.enable_sequential_cpu_offload()
         remove_hook_from_module(pipe.transformer, recurse=True)
         pipe.transformer.to(DEVICE)
-    if args.use_dynamo:
-        dummy_inputs = {
-            "hidden_states": torch.randn(
-                (batch_size, 4096, 64), dtype=torch.float16
-            ).to(DEVICE),
-            "encoder_hidden_states": torch.randn(
-                (batch_size, 512, 4096), dtype=torch.float16
-            ).to(DEVICE),
-            "pooled_projections": torch.randn(
-                (batch_size, 768), dtype=torch.float16
-            ).to(DEVICE),
-            "timestep": torch.tensor([1.0] * batch_size, dtype=torch.float16).to(
-                DEVICE
-            ),
-            "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
-            "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
-            "guidance": torch.tensor([1.0] * batch_size, dtype=torch.float32).to(
-                DEVICE
-            ),
-            "joint_attention_kwargs": {},
-            "return_dict": False,
-        }
-        from modelopt.torch.quantization.utils import export_torch_mode
-
-        with export_torch_mode():
-            ep = torch.export.export(
-                backbone,
-                args=(),
-                kwargs=dummy_inputs,
-                dynamic_shapes=dynamic_shapes,
-                strict=False,
-            )
-        if args.debug:
-            with torch_tensorrt.dynamo.Debugger(
-                "graphs",
-                logging_dir=DEBUG_LOGGING_DIR,
-                # capture_fx_graph_after=["remove_num_users_is_0_nodes"],
-                save_engine_profile=True,
-                profile_format="trex",
-                engine_builder_monitor=True,
-            ):
-                trt_gm = torch_tensorrt.dynamo.compile(
-                    ep, inputs=dummy_inputs, **settings
-                )
-        else:
-            trt_gm = torch_tensorrt.dynamo.compile(ep, inputs=dummy_inputs, **settings)
-        pipe.transformer = trt_gm
-        pipe.transformer.config = backbone.config
-    else:
-        if args.debug:
-            with torch_tensorrt.dynamo.Debugger(
-                "graphs",
-                logging_dir=DEBUG_LOGGING_DIR,
-                capture_fx_graph_after=["remove_num_users_is_0_nodes"],
-                save_engine_profile=True,
-                profile_format="trex",
-                engine_builder_monitor=True,
-            ):
-                trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
-        else:
-            trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
-        if dynamic_shapes:
-            trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
-        pipe.transformer = trt_gm
+
+    trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
+    if dynamic_shapes:
+        trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
+    pipe.transformer = trt_gm
     seed = 42
     image = pipe(
         ["Beach and Kids"],
@@ -208,7 +141,7 @@ def forward_loop(mod):
         generator=torch.Generator("cuda").manual_seed(seed),
     ).images
     print(f"generated {len(image)} images")
-    image[0].save("warmup1.png")
+    image[0].save("beach_kids.png")
 
     torch.cuda.empty_cache()
 
@@ -336,22 +269,11 @@ def main(args):
         default="fp16",
         help="Select the data type to use (fp4 or fp8 or int8 or fp16)",
     )
-    parser.add_argument(
-        "--use_dynamo",
-        action="store_true",
-        help="Use dynamo compile",
-        default=False,
-    )
     parser.add_argument(
         "--fp4_mha",
         action="store_true",
         help="Use NVFP4_FP8_MHA_CONFIG config instead of NVFP4_FP8_MHA_CONFIG",
     )
-    parser.add_argument(
-        "--debug",
-        action="store_true",
-        help="Use debug mode",
-    )
     parser.add_argument(
         "--low_vram_mode",
         action="store_true",
```
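Net effect of the flux_demo.py changes: the `--use_dynamo` and `--debug` branches are gone, so the transformer backbone is always compiled through `MutableTorchTensorRTModule`. A minimal sketch of that remaining path, wrapped in a hypothetical helper for illustration; `backbone`, `settings`, `dynamic_shapes`, and `pipe` are assumed to be prepared as in the rest of flux_demo.py:

```python
import torch_tensorrt


def compile_flux_backbone(backbone, settings, dynamic_shapes, pipe):
    # The single compile path left after this commit: wrap the FLUX
    # transformer in a MutableTorchTensorRTModule, register the expected
    # dynamic-shape range if one was built, and swap it into the pipeline.
    trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
    if dynamic_shapes:
        trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
    pipe.transformer = trt_gm
    return trt_gm
```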
