Skip to content

Commit a204e31

Browse files
authored
Flux Validation (#1518)
# This PR implements the validator class for Flux, following the method discussed in the Stable Diffusion 3 paper. The paper shows that creating 8 equidistant timesteps and averaging the loss over them yields a loss that is highly correlated with external validation metrics such as CLIP or FID score. Rather than creating 8 stratified timesteps per sample, this PR's implementation applies only one of these equidistant timesteps to each sample, in a round-robin fashion. Aggregated over many samples in a validation set, this should give a validation score similar to the full-timestep method, while processing more validation samples in the same amount of time. ### Implementations - Integrates the image generation evaluation into the validation step; users can set `validation.save_img_count` to control how many images are generated and saved - Refactors and combines the eval job_config with validation - Adds an `all_timesteps` option to the job_config to choose whether to use round-robin timesteps or full timesteps per sample - Creates a validator class and a validation dataloader for Flux; the validation dataloader handles generating timesteps for the round-robin method of validation ### Enabling all timesteps Developers can enable the full-timestep method of validation by setting `all_timesteps = True` in the Flux validation job config. Enabling `all_timesteps` may require tweaking some hyperparameters (`validation.local_batch_size`, `validation.steps`) to prevent memory spikes and to optimize throughput. Using a ratio of around 1/4 for `validation.local_batch_size` to `training.local_batch_size` will not spike the memory higher than training when `fsdp = 8`. Below we can see the difference between round robin and all timesteps. In the comparison, the total number of validation samples processed is the same, but in the `all_timesteps=True` configuration we have to lower the batch size to prevent memory spikes. All timesteps also achieves a higher throughput (tps), but still processes the total samples of the validation set more slowly.
| Round Robin (batch_size=32, steps=1, fsdp=8) | All Timesteps (batch_size=8, steps=4, fsdp=8) | | ---- | --- | | <img width="682" height="303" alt="Screenshot 2025-08-01 at 3 46 42 PM" src="https://github.com/user-attachments/assets/30328bfe-4c3c-4912-a329-2b94c834b67b" /> | <img width="719" height="308" alt="Screenshot 2025-08-01 at 3 30 10 PM" src="https://github.com/user-attachments/assets/c7325d21-8a7b-41d9-a0d2-74052e425083" /> |
1 parent 90cfba4 commit a204e31

File tree

14 files changed

+475
-96
lines changed

14 files changed

+475
-96
lines changed

torchtitan/components/validate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def validate(
8080
self,
8181
model_parts: list[nn.Module],
8282
step: int,
83-
) -> dict[str, float]:
83+
) -> None:
8484
# Set model to eval mode
8585
model = model_parts[0]
8686
model.eval()

torchtitan/experiments/flux/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .model.args import FluxModelArgs
1818
from .model.autoencoder import AutoEncoderParams
1919
from .model.model import FluxModel
20+
from .validate import build_flux_validator
2021

2122
__all__ = [
2223
"FluxModelArgs",
@@ -117,5 +118,6 @@
117118
build_dataloader_fn=build_flux_dataloader,
118119
build_tokenizer_fn=None,
119120
build_loss_fn=build_mse_loss,
121+
build_validator_fn=build_flux_validator,
120122
)
121123
)

torchtitan/experiments/flux/dataset/flux_dataset.py

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import itertools
78
import math
89
from dataclasses import dataclass
910
from typing import Any, Callable, Optional
@@ -103,6 +104,38 @@ def _cc12m_wds_data_processor(
103104
"image": img,
104105
"clip_tokens": clip_tokens, # type: List[int]
105106
"t5_tokens": t5_tokens, # type: List[int]
107+
"prompt": sample["txt"], # type: str
108+
}
109+
110+
111+
def _coco_data_processor(
112+
sample: dict[str, Any],
113+
t5_tokenizer: FluxTokenizer,
114+
clip_tokenizer: FluxTokenizer,
115+
output_size: int = 256,
116+
) -> dict[str, Any]:
117+
"""
118+
Preprocess COCO dataset sample image and text for Flux model.
119+
120+
Args:
121+
sample: A sample from dataset
122+
t5_encoder: T5 encoder
123+
clip_encoder: CLIP encoder
124+
output_size: The output image size
125+
126+
"""
127+
img = _process_cc12m_image(sample["image"], output_size=output_size)
128+
prompt = sample["caption"]
129+
if isinstance(prompt, list):
130+
prompt = prompt[0]
131+
t5_tokens = t5_tokenizer.encode(prompt)
132+
clip_tokens = clip_tokenizer.encode(prompt)
133+
134+
return {
135+
"image": img,
136+
"clip_tokens": clip_tokens, # type: List[int]
137+
"t5_tokens": t5_tokens, # type: List[int]
138+
"prompt": prompt, # type: str
106139
}
107140

108141

@@ -126,6 +159,11 @@ class TextToImageDatasetConfig:
126159
),
127160
data_processor=_cc12m_wds_data_processor,
128161
),
162+
"coco-validation": TextToImageDatasetConfig(
163+
path="howard-hou/COCO-Text",
164+
loader=lambda path: load_dataset(path, split="validation", streaming=True),
165+
data_processor=_coco_data_processor,
166+
),
129167
}
130168

131169

@@ -242,8 +280,9 @@ def __iter__(self):
242280

243281
# skip low quality image or image with color channel = 1
244282
if sample_dict["image"] is None:
283+
sample = sample.get("__key__", "unknown")
245284
logger.warning(
246-
f"Low quality image {sample['__key__']} is skipped in Flux Dataloader."
285+
f"Low quality image {sample} is skipped in Flux Dataloader."
247286
)
248287
continue
249288

@@ -308,3 +347,87 @@ def build_flux_dataloader(
308347
dp_world_size=dp_world_size,
309348
batch_size=batch_size,
310349
)
350+
351+
352+
class FluxValidationDataset(FluxDataset):
353+
"""
354+
Adds logic to generate timesteps for flux validation method described in SD3 paper
355+
356+
Args:
357+
generate_timesteps (bool): Generate stratified timesteps in round-robin style for validation
358+
"""
359+
360+
def __init__(
361+
self,
362+
dataset_name: str,
363+
dataset_path: Optional[str],
364+
t5_tokenizer: BaseTokenizer,
365+
clip_tokenizer: BaseTokenizer,
366+
job_config: Optional[JobConfig] = None,
367+
dp_rank: int = 0,
368+
dp_world_size: int = 1,
369+
generate_timesteps: bool = True,
370+
) -> None:
371+
# Call parent constructor correctly
372+
super().__init__(
373+
dataset_name=dataset_name,
374+
dataset_path=dataset_path,
375+
t5_tokenizer=t5_tokenizer,
376+
clip_tokenizer=clip_tokenizer,
377+
job_config=job_config,
378+
dp_rank=dp_rank,
379+
dp_world_size=dp_world_size,
380+
infinite=False,
381+
)
382+
383+
# Initialize timestep generation for validation
384+
self.generate_timesteps = generate_timesteps
385+
if self.generate_timesteps:
386+
# Generate stratified timesteps as described in SD3 paper
387+
val_timesteps = [1 / 8 * (i + 0.5) for i in range(8)]
388+
self.timestep_cycle = itertools.cycle(val_timesteps)
389+
390+
def __iter__(self):
391+
# Get parent iterator and add timesteps to each sample
392+
parent_iterator = super().__iter__()
393+
394+
for sample_dict, labels in parent_iterator:
395+
# Add timestep to the sample dict if timestep generation is enabled
396+
if self.generate_timesteps:
397+
sample_dict["timestep"] = next(self.timestep_cycle)
398+
399+
yield sample_dict, labels
400+
401+
402+
def build_flux_validation_dataloader(
403+
dp_world_size: int,
404+
dp_rank: int,
405+
job_config: JobConfig,
406+
# This parameter is not used, keep it for compatibility
407+
tokenizer: BaseTokenizer | None,
408+
generate_timestamps: bool = True,
409+
) -> ParallelAwareDataloader:
410+
"""Build a data loader for HuggingFace datasets."""
411+
dataset_name = job_config.validation.dataset
412+
dataset_path = job_config.validation.dataset_path
413+
batch_size = job_config.validation.local_batch_size
414+
415+
t5_tokenizer, clip_tokenizer = build_flux_tokenizer(job_config)
416+
417+
ds = FluxValidationDataset(
418+
dataset_name=dataset_name,
419+
dataset_path=dataset_path,
420+
t5_tokenizer=t5_tokenizer,
421+
clip_tokenizer=clip_tokenizer,
422+
job_config=job_config,
423+
dp_rank=dp_rank,
424+
dp_world_size=dp_world_size,
425+
generate_timesteps=generate_timestamps,
426+
)
427+
428+
return ParallelAwareDataloader(
429+
dataset=ds,
430+
dp_rank=dp_rank,
431+
dp_world_size=dp_world_size,
432+
batch_size=batch_size,
433+
)

torchtitan/experiments/flux/job_config.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class Encoder:
3636

3737

3838
@dataclass
39-
class Eval:
39+
class Validation:
4040
enable_classifier_free_guidance: bool = False
4141
"""Whether to use classifier-free guidance during sampling"""
4242
classifier_free_guidance_scale: float = 5.0
@@ -45,8 +45,13 @@ class Eval:
4545
"""How many denoising steps to sample when generating an image"""
4646
eval_freq: int = 100
4747
"""Frequency of evaluation/sampling during training"""
48+
save_img_count: int = 1
49+
""" How many images to generate and save during validation, starting from
50+
the beginning of validation set, -1 means generate on all samples"""
4851
save_img_folder: str = "img"
4952
"""Directory to save image generated/sampled from the model"""
53+
all_timesteps: bool = False
54+
"""Whether to generate all stratified timesteps per sample or use round robin"""
5055

5156

5257
@dataclass
@@ -57,4 +62,4 @@ class JobConfig:
5762

5863
training: Training = field(default_factory=Training)
5964
encoder: Encoder = field(default_factory=Encoder)
60-
eval: Eval = field(default_factory=Eval)
65+
validation: Validation = field(default_factory=Validation)

torchtitan/experiments/flux/sampling.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,9 @@ def generate_image(
9393
img_height = 16 * (job_config.training.img_size // 16)
9494
img_width = 16 * (job_config.training.img_size // 16)
9595

96-
enable_classifier_free_guidance = job_config.eval.enable_classifier_free_guidance
96+
enable_classifier_free_guidance = (
97+
job_config.validation.enable_classifier_free_guidance
98+
)
9799

98100
# Tokenize the prompt. Unsqueeze to add a batch dimension.
99101
clip_tokens = clip_tokenizer.encode(prompt).unsqueeze(0)
@@ -132,7 +134,7 @@ def generate_image(
132134
model=model,
133135
img_width=img_width,
134136
img_height=img_height,
135-
denoising_steps=job_config.eval.denoising_steps,
137+
denoising_steps=job_config.validation.denoising_steps,
136138
clip_encodings=batch["clip_encodings"],
137139
t5_encodings=batch["t5_encodings"],
138140
enable_classifier_free_guidance=enable_classifier_free_guidance,
@@ -142,7 +144,7 @@ def generate_image(
142144
empty_clip_encodings=(
143145
empty_batch["clip_encodings"] if enable_classifier_free_guidance else None
144146
),
145-
classifier_free_guidance_scale=job_config.eval.classifier_free_guidance_scale,
147+
classifier_free_guidance_scale=job_config.validation.classifier_free_guidance_scale,
146148
)
147149

148150
img = autoencoder.decode(img)

torchtitan/experiments/flux/tests/integration_tests.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ def build_test_list():
6464
"Checkpoint Integration Test - Save Model Only fp32",
6565
"last_save_model_only_fp32",
6666
),
67+
OverrideDefinitions(
68+
[["--validation.enabled"]], "Flux Validation Test", "validation"
69+
),
6770
# Parallelism tests.
6871
OverrideDefinitions(
6972
[

torchtitan/experiments/flux/tests/test_generate_image.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,12 @@ def test_generate_image(self):
5757
"--training.img_size",
5858
str(img_width),
5959
# eval params
60-
"--eval.denoising_steps",
60+
"--validation.denoising_steps",
6161
str(num_steps),
62-
"--eval.enable_classifier_free_guidance",
63-
"--eval.classifier_free_guidance_scale",
62+
"--validation.enable_classifier_free_guidance",
63+
"--validation.classifier_free_guidance_scale",
6464
str(classifier_free_guidance_scale),
65-
"--eval.save_img_folder",
65+
"--validation.save_img_folder",
6666
"img",
6767
]
6868
)
@@ -120,7 +120,7 @@ def test_generate_image(self):
120120
save_image(
121121
name=f"img_unit_test_{config.training.seed}.jpg",
122122
output_dir=os.path.join(
123-
config.job.dump_folder, config.eval.save_img_folder
123+
config.job.dump_folder, config.validation.save_img_folder
124124
),
125125
x=image,
126126
add_sampling_metadata=True,

torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,9 @@ def test_load_dataset(self):
7979
for i in range(0, num_steps):
8080
input_data, labels = next(it)
8181

82-
assert len(input_data) == 2 # (clip_encodings, t5_encodings)
82+
assert (
83+
len(input_data) == 3
84+
) # (clip_encodings, t5_encodings, prompt)
8385
assert labels.shape == (batch_size, 3, 256, 256)
8486
assert input_data["clip_tokens"].shape == (
8587
batch_size,

torchtitan/experiments/flux/train.py

Lines changed: 10 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,18 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import os
8-
from typing import Iterable, Optional
8+
from typing import Optional
99

1010
import torch
11-
from torch.distributed.fsdp import FSDPModule
1211

1312
from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP
1413
from torchtitan.distributed import utils as dist_utils
1514
from torchtitan.tools.logging import init_logger, logger
1615
from torchtitan.train import Trainer
1716

18-
from .dataset.tokenizer import build_flux_tokenizer
1917
from .infra.parallelize import parallelize_encoders
2018
from .model.autoencoder import load_ae
2119
from .model.hf_embedder import FluxEmbedder
22-
from .sampling import generate_image, save_image
2320
from .utils import (
2421
create_position_encoding_for_latents,
2522
pack_latents,
@@ -81,6 +78,15 @@ def __init__(self, job_config: JobConfig):
8178
job_config=job_config,
8279
)
8380

81+
if job_config.validation.enabled:
82+
self.validator.flux_init(
83+
device=self.device,
84+
_dtype=self._dtype,
85+
autoencoder=self.autoencoder,
86+
t5_encoder=self.t5_encoder,
87+
clip_encoder=self.clip_encoder,
88+
)
89+
8490
def forward_backward_step(
8591
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
8692
) -> torch.Tensor:
@@ -147,64 +153,6 @@ def forward_backward_step(
147153

148154
return loss
149155

150-
def eval_step(self, prompt: str = "A photo of a cat"):
151-
"""
152-
Evaluate the Flux model.
153-
1) generate and save images every few steps. Currently, we run the eval and on the same
154-
prompts across all DP ranks. We will change this behavior to run on validation set prompts.
155-
Due to random noise generation, results could be different across DP ranks cause we assign
156-
different random seeds to each DP rank.
157-
2) [TODO] Calculate loss with fixed t value on validation set.
158-
"""
159-
160-
t5_tokenizer, clip_tokenizer = build_flux_tokenizer(self.job_config)
161-
162-
image = generate_image(
163-
device=self.device,
164-
dtype=self._dtype,
165-
job_config=self.job_config,
166-
model=self.model_parts[0],
167-
prompt=prompt, # TODO(jianiw): change this to a prompt from validation set
168-
autoencoder=self.autoencoder,
169-
t5_tokenizer=t5_tokenizer,
170-
clip_tokenizer=clip_tokenizer,
171-
t5_encoder=self.t5_encoder,
172-
clip_encoder=self.clip_encoder,
173-
)
174-
175-
save_image(
176-
name=f"image_rank{str(torch.distributed.get_rank())}_{self.step}.png",
177-
output_dir=os.path.join(
178-
self.job_config.job.dump_folder, self.job_config.eval.save_img_folder
179-
),
180-
x=image,
181-
add_sampling_metadata=True,
182-
prompt=prompt,
183-
)
184-
185-
# Reshard after run forward pass in eval_step.
186-
# This is to ensure the model weights are sharded the same way for checkpoint saving.
187-
for module in self.model_parts[0].modules():
188-
if isinstance(module, FSDPModule):
189-
module.reshard()
190-
191-
def train_step(
192-
self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
193-
):
194-
super().train_step(data_iterator)
195-
196-
# Evaluate the model during training
197-
if (
198-
self.step % self.job_config.eval.eval_freq == 0
199-
or self.step == self.job_config.training.steps
200-
):
201-
model = self.model_parts[0]
202-
model.eval()
203-
# We need to set reshard_after_forward before last forward pass.
204-
# So the model wieghts are sharded the same way for checkpoint saving.
205-
self.eval_step()
206-
model.train()
207-
208156

209157
if __name__ == "__main__":
210158
init_logger()

0 commit comments

Comments
 (0)