
Make timesteps work in the standard way when Huber loss is used #1628

Merged
merged 1 commit into kohya-ss:dev on Sep 25, 2024

Conversation

@recris commented Sep 21, 2024

Tested with SDXL LoRA.

When Huber loss is selected, the sampled timestep is the same for the entire batch. This PR changes the behavior to match what is used for L2 loss.

Instead of passing a scalar huber_c around, this turns it into a batch-sized Tensor object.

```diff
 def conditional_loss(
-    model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1
+    model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor]
```
@recris (Author) commented:

All callers I could find explicitly pass every argument, so it should be safe to remove the default values.

@recris commented Sep 21, 2024

I think we can make some additional changes to more easily integrate Huber loss into the upcoming SD3/Flux work:

  • Split the get_timesteps_and_huber_c method to separate timestep sampling from the Huber "delta" calculation; this would simplify getting the Huber values in the Flux trainer later (see the sketch after this list).
  • Add an explicit message stating that the SNR schedule is not supported in models where this concept does not apply.

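A rough sketch of what that split could look like (function names, signatures, and the schedule handling below are illustrative assumptions, not the existing sd-scripts API):

```python
import torch

# Hypothetical split; names and signatures are illustrative only.
def sample_timesteps(noise_scheduler, batch_size: int, device: torch.device) -> torch.Tensor:
    # One timestep per sample, matching the L2 behavior.
    return torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (batch_size,), device=device
    )

def get_huber_c(args, noise_scheduler, timesteps: torch.Tensor) -> torch.Tensor:
    # Per-sample Huber delta derived from the already-sampled timesteps.
    if args.huber_schedule == "exponential":
        t = timesteps.float() / noise_scheduler.config.num_train_timesteps
        return torch.exp(torch.log(torch.tensor(args.huber_c)) * t)
    if args.huber_schedule == "snr":
        raise NotImplementedError("SNR schedule is not supported for this model type")
    return torch.full((timesteps.shape[0],), args.huber_c, device=timesteps.device)
```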
@kohya-ss Should this be merged as-is or can I implement the above changes?

@recris commented Sep 21, 2024

After spending an afternoon going down a rabbit hole, I concluded that the current Huber loss code is simply incorrect.

This is the textbook definition of Huber loss (from wikipedia):

$$
L_\delta(a) = \begin{cases} \frac{1}{2}a^2 & \text{for } |a| \le \delta \\ \delta\left(|a| - \frac{1}{2}\delta\right) & \text{otherwise} \end{cases}
$$

This is a piecewise-defined function. What we have in the code is:

```python
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
```

Which is not the same and appears to be an approximation, at best. The code seems to be identical to a version that exists in the diffusers package, which is probably the source.

Looking at the PyTorch internal implementation we have:

```python
z = (input - target).abs()
loss = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta))
```

It matches the textbook definition.

I've been getting some inconsistent results in Flux using Huber loss when following the recommendations in the original paper. I suspect this might be one of the causes.

Smooth L1 loss seems to have the same problem outlined above.

I am currently testing changes that use the same logic as PyTorch does.
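For reference, here is a minimal sketch of a piecewise Huber term in the PyTorch style, written for a per-sample huber_c tensor as proposed in this PR (the reshaping convention is an assumption about how the batch dimension is laid out):

```python
import torch

def huber_loss_piecewise(model_pred: torch.Tensor, target: torch.Tensor, huber_c: torch.Tensor) -> torch.Tensor:
    # huber_c is assumed to be shaped (B,); broadcast it over the remaining latent dims.
    delta = huber_c.view(-1, *([1] * (model_pred.ndim - 1)))
    z = (model_pred - target).abs()
    # Quadratic below delta, linear above, matching the textbook (and PyTorch) definition.
    return torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta))
```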

@recris commented Sep 21, 2024

After some additional measurements I think the current formulation and defaults for Huber schedule are not suitable for use in Flux.

The primary reason is that the mean latent loss term (and its variance) is much higher than what we find in prior models (like SDXL).

A default like huber_c=0.1 is too low for either the constant or the exponential schedule. A quick measurement of typical average latent abs(predicted - target) values gave me around ~0.5 (std deviation ~0.4). This means a Huber delta of 0.1 will result in applying an MAE-like estimation to most latent values.

If I understand the math correctly, given that MAE grows linearly, this means we'll get smaller gradient magnitudes, on average, compared to MSE (which has quadratic behavior). As a result, learning gets severely dampened. I've done a few test runs with different huber_c values and could see a correlation between the parameter value and the effectiveness of the training over a preset number of steps.
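As a quick sanity check of that claim with the numbers above: for a residual $|a| \approx 0.5$ and $\delta = 0.1$ we are in the linear branch of the piecewise Huber loss, so the per-element gradient magnitude is $\delta = 0.1$, whereas MSE gives $2|a| = 1.0$, i.e. roughly ten times larger.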

When it comes to the exponential schedule we have an additional problem. The current formulation (per the paper) is:

delta = exp(log(delta_0) * t)

Where delta_0 is the huber_c parameter and t is the sampled timestep (normalized to [0,1]). This means delta has an upper bound of 1.0 (when delta_0 <= 1.0) and this leads to a problem similar to the constant schedule case.

We need a way to control the upper bound of this function, probably with an additional s parameter, like delta = s * exp(log(delta_0) * t) or by shifting t like delta = exp(log(delta_0) * (t + s)).
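Something like the following, where the s parameter is a hypothetical new knob rather than an existing option:

```python
import torch

def exponential_huber_schedule(timesteps: torch.Tensor, num_train_timesteps: int,
                               delta_0: float = 0.1, s: float = 1.0, shift: bool = False) -> torch.Tensor:
    # t normalized to [0, 1]; delta_0 corresponds to the existing huber_c parameter.
    t = timesteps.float() / num_train_timesteps
    log_d0 = torch.log(torch.tensor(delta_0))
    if shift:
        # delta = exp(log(delta_0) * (t + s)): shifting t moves the upper bound to delta_0 ** s.
        return torch.exp(log_d0 * (t + s))
    # delta = s * exp(log(delta_0) * t): scaling changes the upper bound from 1.0 to s.
    return s * torch.exp(log_d0 * t)
```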

recris mentioned this pull request on Sep 21, 2024
@kohya-ss (Owner) commented

Thank you for this PR.

> @kohya-ss Should this be merged as-is or can I implement the above changes?

I think this PR can be merged as-is.

> After spending an afternoon going down a rabbit hole, I concluded that the current Huber loss code is simply incorrect.

I don't fully understand the mathematical background, but kabachuha, the author of #1228, called it a "Pseudo-Huber Loss" in huggingface/diffusers#7527, so I think this difference is intentional.

Are the results significantly different and can we be sure that Wikipedia's implementation is better?

@recris commented Sep 23, 2024

> Are the results significantly different and can we be sure that Wikipedia's implementation is better?

In the last few days I've been testing both approaches, and both seem to work well enough. The problems I was facing (in Flux testing) seem to be related to the current Huber schedules and default values, which are inadequate for the new model. I have another set of proposed changes, but I'll leave that for a later PR.

@kohya-ss (Owner) commented

Thank you! I will merge this PR as is.

kohya-ss merged commit c1d16a7 into kohya-ss:dev on Sep 25, 2024
1 check passed