You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
# This pr implements the validator class for flux following the method
discussed in Stable Diffusion 3 paper.
The paper shows that creating 8 equidistant timesteps and calculating
the average loss on them will result in a highly correlated loss to
external validation methods such as CLIP or FID score.
This pr's implementation rather than creating 8 stratified timesteps per
sample, only applies one of these equidistant timesteps to each sample
in a round-robin fashion. Aggregated over many samples in a validation
set, this should give a similar validation score as the full timestep
method, but will process more validation samples quickly.
### Implementations
- Integrates the image generation evaluation in the validation step,
users can
- Refactors and combines eval job_config with validation
- Adds an `all_timesteps` option to the job_config to choose whether to
use round robin timesteps or full timesteps per sample
- Creates validator class and validation dataloader for flux, validator
dataloader handles generating timesteps for round-robin method of
validation
### Enabling all timesteps
Developers can enable the full timestamp method of validation by setting
`all_timesteps = True` in the flux validation job config. Enabling
all_timesteps may require tweaking some hyperparams
`validation.local_batch_size, validation.steps` to prevent spiking
memory and optimizing throughput. By using a ratio of around 1/4 for
`validation.local_batch_size` to `training.local_batch_size` will not
spike the memory higher than training when `fsdp = 8`.
Below we can see the difference between round robin and all timesteps.
In the comparison the total number of validation samples processed is
the same, but in `all_timesteps=True` configuration we have to lower the
batch size to prevent memory spiking. All timesteps also achieves a
higher throughput (tps) but still processes total samples of validation
set more slowly.
| Round Robin (batch_size=32, steps=1, fsdp=8) | All Timesteps
(batch_size=8, steps=4, fsdp=8) |
| ---- | --- |
| <img width="682" height="303" alt="Screenshot 2025-08-01 at 3 46
42 PM"
src="https://github.com/user-attachments/assets/30328bfe-4c3c-4912-a329-2b94c834b67b"
/> | <img width="719" height="308" alt="Screenshot 2025-08-01 at 3 30
10 PM"
src="https://github.com/user-attachments/assets/c7325d21-8a7b-41d9-a0d2-74052e425083"
/> |
0 commit comments