Optim-wip: Fix duplicated target bug #919
Conversation
The change under review:

```python
# Filter out duplicate targets
target = list(dict.fromkeys(target))
```
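A minimal sketch of what this one-liner does, using strings to stand in for the module targets used in the real code: `dict.fromkeys` removes duplicates while preserving first-seen order, unlike a plain `set`.

```python
# dict.fromkeys builds a dict whose keys are the unique items, in
# insertion order (guaranteed in Python 3.7+); listing the keys
# therefore de-duplicates while keeping the original ordering.
target = ["layer1", "layer2", "layer1", "layer3", "layer2"]
target = list(dict.fromkeys(target))
print(target)  # ['layer1', 'layer2', 'layer3']
```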
@ProGamerGov, why would someone pass a duplicated target here? Shouldn't we add an assert here instead?
@NarineK There are a few reasons why there can be duplicates here.
For example, optimization with transparency will use NaturalImage or a transform as the target for one or more alpha-channel-related objectives. If the user is working with a CLIP model, an L2 penalty objective will also use one of the same targets.
Using multiple different penalties on the same target will also create duplicates without that line. In the Optimizing with Transparency notebook, a duplicate would be created in the final section when both a blurring penalty and an L2 penalty use the same target layer.
@ProGamerGov, are you saying that in this line the targets will be duplicated (example from the Optimizing with Transparency notebook)?

```python
loss_fn = loss_fn - MeanAlphaChannelPenalty(transforms[0])
loss_fn = loss_fn - (9 * BlurActivations(transforms[0], channel_index=3))
```

Is it because we are concatenating the other target to the current target here?

```python
target = (self.target if isinstance(self.target, list) else [self.target]) + (
    other.target if isinstance(other.target, list) else [other.target]
)
```

I was also wondering why we concatenate the targets in the line above. It looks like we are not concatenating if self.target is a list, but otherwise we concatenate it with other.target. Could we elaborate on this logic a bit?
@NarineK Yes, those lines will result in a duplicated target, because we concatenate the target lists for every operation involving multiple loss objectives. I'll look at doing a more detailed write-up of how it works in another PR.
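To make the concatenation behavior concrete, here is a hypothetical minimal sketch; the `Loss` class and its structure are illustrative stand-ins, not Captum's actual implementation. Composing two objectives that reference the same module joins their target lists, producing a duplicate.

```python
# Hypothetical minimal sketch of the target-concatenation behavior
# discussed above; not Captum's real loss classes.
class Loss:
    def __init__(self, target):
        self.target = target

    def __add__(self, other):
        combined = Loss(None)
        # Mirror of the concatenation line quoted earlier: each side's
        # target is wrapped in a list if needed, then the lists are joined.
        combined.target = (
            self.target if isinstance(self.target, list) else [self.target]
        ) + (other.target if isinstance(other.target, list) else [other.target])
        return combined


a = Loss("conv4")  # e.g. an L2 penalty on a layer
b = Loss("conv4")  # e.g. a blur penalty on the same layer
composite = a + b
print(composite.target)  # ['conv4', 'conv4'] -> duplicate
print(list(dict.fromkeys(composite.target)))  # ['conv4'] after the fix
```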
@ProGamerGov, but we don't concatenate them if self.target is a list? Perhaps you can rework this PR since it is very small, or else we need to document in this PR that the logic requires refinement before we merge it.
@NarineK It seems like it could be a bit complicated to change at the moment. Most loss objectives store a self.target value that is then used to collect the target activations:

```python
def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
    activations = targets_to_values[self.target]
    return activations
```

The self.target value then becomes a list when combined with another loss objective in the resulting CompositeLoss instance. The two original objectives can still call and use their own self.target values. The InputOptimization module also uses a list of targets, but it does not overwrite the original self.target to build its list.
Currently, a hook is created in ModuleOutputsHook for every instance of a target in the target list. Each captured set of activations for the same hook overwrites the previous set, potentially hurting performance, as only one set of activations per target is returned. This bug also causes the warning messages in ModuleOutputsHook to repeat every iteration. This PR solves the issue by ensuring that duplicate target values are removed.
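An illustrative sketch of the overwrite behavior described above, assuming captured activations are stored in a dict keyed by target (this is a toy model, not Captum's real ModuleOutputsHook):

```python
# Toy model: a duplicated target registers a second "hook" whose capture
# writes under the same dict key, clobbering the first entry, so the
# extra hook is pure overhead.
targets = ["conv4", "conv4"]  # duplicated target
captured = {}
for i, t in enumerate(targets):
    captured[t] = f"activations_from_hook_{i}"
print(captured)  # {'conv4': 'activations_from_hook_1'} -- first capture lost

# After de-duplication, only one hook per unique target is registered:
captured = {}
for i, t in enumerate(dict.fromkeys(targets)):
    captured[t] = f"activations_from_hook_{i}"
print(captured)  # {'conv4': 'activations_from_hook_0'}
```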