Optim-wip: Composable loss improvements #828
base: optim-wip
Conversation
* Added `operator.floordiv` support.
* Added the `basic_torch_module_op` function that should allow for the composability of many common torch operations.
* Added the `rmodule_op` function for handling the three "r" versions of math operations.
* Improved documentation.
* Renamed `basic_torch_module_op` to `custom_composable_op`.
* Removed the reduction op from the "r" module calls, as it's not required.
* Custom loss objectives can support any number of batch dimension values.
Hi, thank you for making this, but I may be missing some context/history here. Why do we need the "composable loss" in Captum? PyTorch has already provided a convention for losses:

```python
def new_loss(output_tensor):
    return nn.SomeLoss(output_tensor) + some_other_loss(output_tensor) + torch.linalg.norm(output_tensor)
```

I think our optim loss can work the same way without "composable loss" and even be more flexible. For example:

```python
deepdream = DeepDream(target)
layeractivation = LayerActivation(target)

def new_loss(targets_to_values: ModuleOutputMapping):
    loss = deepdream(targets_to_values) + layeractivation(targets_to_values)
    # can also use pytorch loss
    return loss + nn.SomeLoss(targets_to_values[target])
```
```python
    target = loss.target
    return CompositeLoss(loss_fn, name=name, target=target)


class BaseLoss(Loss):
    def __init__(
        self,
        target: Union[nn.Module, List[nn.Module]] = [],
```
If the `target` can be `List[nn.Module]`, many losses below cannot directly use it as a dict key in `targets_to_values[self.target]`. Did I miss anything?
@aobo-y Losses like `ActivationInterpolation` have multiple targets (the Faceted loss as well, in an upcoming PR), but `BaseLoss` works off of a single `target` variable.

The `BaseLoss` class is called in the `__init__` functions of loss classes like so:

```python
# Single target
BaseLoss.__init__(self, target, batch_index)

# Multiple targets
BaseLoss.__init__(self, [target1, target2])
```

The loss class itself will indicate via the `target: List[nn.Module]` type hint that multiple targets are supported / required, or it handles things internally by passing the targets as a list to `BaseLoss`, like in `ActivationInterpolation`.

The `ActivationInterpolation` loss class can be found here: https://github.com/ProGamerGov/captum/blob/optim-wip-composable-loss-improvements/captum/optim/_core/loss.py#L506
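For a concrete picture of the multi-target pattern described above, here is a simplified sketch (not the actual `ActivationInterpolation` implementation linked above). It assumes `BaseLoss` stores whatever it is given on `self.target`, that losses are called with a `ModuleOutputMapping` as in the earlier examples, and that the `ModuleOutputMapping` import path is as guessed below:

```python
import torch
import torch.nn as nn

from captum.optim._core.loss import BaseLoss
from captum.optim._utils.typing import ModuleOutputMapping  # assumed import path


class TwoLayerInterpolationSketch(BaseLoss):
    """Hypothetical multi-target loss, for illustration only."""

    def __init__(self, target1: nn.Module, target2: nn.Module) -> None:
        # Multiple targets are forwarded to BaseLoss as a list.
        BaseLoss.__init__(self, [target1, target2])

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        # Assumes self.target holds the list, so each entry can be used as a
        # key into the captured activations.
        activ1 = targets_to_values[self.target[0]]
        activ2 = targets_to_values[self.target[1]]
        return (activ1 * activ2).mean()
```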
Sure, but cases like `DeepDream` and some others directly inherit `BaseLoss`'s `__init__` definition, where `target` can be a list while it actually should not: https://github.com/ProGamerGov/captum/blob/optim-wip-composable-loss-improvements/captum/optim/_core/loss.py#L393-L407

If these losses have different assumptions about what their targets should be, why do we abstract `target` into the base class? The base class `BaseLoss` does not need `target` anyway. Each class can define its own `target` in `__init__`. Or we can have two other intermediate abstract classes, `SingleTargetLoss` and `MultiTargetsLoss`.

But anyway, this is just for discussion. It has nothing to do with this PR. We can leave it to future updates if needed.
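For illustration, a minimal sketch of what such intermediate classes could look like. The class names come from the suggestion above and are hypothetical; the `BaseLoss.__init__(self, target, batch_index)` calling pattern is taken from earlier in this thread, and the `batch_index` type is a simplifying assumption:

```python
from typing import List, Optional

import torch.nn as nn

from captum.optim._core.loss import BaseLoss


class SingleTargetLoss(BaseLoss):
    """Hypothetical intermediate class for losses keyed on exactly one module."""

    def __init__(self, target: nn.Module, batch_index: Optional[int] = None) -> None:
        # Forward a single target, matching BaseLoss.__init__(self, target, batch_index).
        BaseLoss.__init__(self, target, batch_index)


class MultiTargetsLoss(BaseLoss):
    """Hypothetical intermediate class for losses that operate on several modules."""

    def __init__(self, targets: List[nn.Module]) -> None:
        # Forward all targets as a list, as ActivationInterpolation does.
        BaseLoss.__init__(self, targets)
```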
Oh, yeah, I see what you mean now. In the original code, I think that Ludwig had `SingleTargetObjective` & `MultiObjective` for handling these cases: https://github.com/ludwigschubert/captum/blob/f1fd0729dece59564a7c10b7b397617d8a09a247/captum/optim/optim/objectives.py#L108

It'd probably be best to leave this to a future PR if we decide on the changes.

@aobo-y Originally, Captum's loss functions were set up similar to the simple class-like functions that Lucid uses. Upon review, we then changed the losses to use classes instead. Ludwig (one of the main Lucid developers) designed the initial optim module to utilize a Lucid-like composable loss system. One of the main benefits of the composable loss system is ease of use and built-in target tracking (the list of targets has to be created regardless of whether or not we use composable losses, and doing it this way means the user doesn't have to repeat the loss targets in multiple locations). It also allows for easy-to-use handling of things like batch-specific targeting.
This PR can be skipped for now.
This PR adds a few simple improvements to the `CompositeLoss` class and its features:

* Added the `__pos__` and `__abs__` unary operators to `CompositeLoss`. These appear to be the only other basic operators that make sense to add support for.
* Added `operator.floordiv` support. The current operator is being deprecated, but the operator symbol itself is likely not going to be removed; instead, its functionality will be changed: Python Array API Compatibility Tracker pytorch#58743
* Made `CompositeLoss`'s reduction operation a global variable that can be changed by users. This should improve the generality of the optim module, and it makes it possible to disable this aspect of `CompositeLoss`.
* Added the `torch.mean` and `torch.sum` reduction operations to `CompositeLoss`. These are common operations, so there are likely use cases that can benefit from them. Example usage: `loss_fn.mean()` & `loss_fn.sum()`.
* Added the `custom_composable_op` function that should allow for the composability of many Python & PyTorch operations, as well as custom user operations. This should let users cover any operations that aren't covered by default in Captum.
* Added the `rmodule_op` function for handling the three "r" versions of math operations. This helps simplify the code.
* `2.0 ** loss_obj`, `2.0 / loss_obj`, & `2.0 // loss_obj` all work without the reduction op, so I've removed it for those cases.
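To make these pieces concrete, here is a rough usage sketch based only on the descriptions above. It is a sketch, not documented API: the model and target module are placeholders, the import path follows the `loss.py` links earlier in this thread, and `custom_composable_op` is omitted since its exact signature isn't shown here:

```python
import torch.nn as nn

from captum.optim._core.loss import DeepDream, LayerActivation

# Placeholder model and target module for the sketch.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
target = model[1]

loss1 = DeepDream(target)
loss2 = LayerActivation(target)

# Unary operators (__pos__ / __abs__) added in this PR.
composed = +loss1 + abs(loss2)

# The three "r" operations handled by rmodule_op; per the description above,
# these work without the reduction op.
r_pow = 2.0 ** loss1
r_div = 2.0 / loss1
r_floordiv = 2.0 // loss1

# Explicit reduction operations added to CompositeLoss.
reduced = (loss1 + loss2).mean()
summed = (loss1 + loss2).sum()
```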