Support different reg_reduction in Captum STG #1090


Conversation

@aobo-y aobo-y commented Dec 13, 2022

Summary:
Add a new `str` argument `reg_reduction` to the Captum STG classes, which specifies how the returned regularization should be reduced. Following the design of PyTorch losses, it supports three modes: `sum`, `mean`, and `none`. The default is `sum`.
(There may be a need for other modes in the future, such as `weighted_sum`. With a customized `mask`, each gate may handle a different number of elements, and an application may want to use as few elements as possible rather than as few gates. For now, such use cases can choose the `none` option and reduce the result themselves; a sketch of the reduction logic is shown below.)
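For illustration, a minimal sketch of how such a reduction argument can be applied; the helper name `_reduce_reg` is hypothetical, not Captum's actual code:

```python
import torch

def _reduce_reg(gate_reg: torch.Tensor, reg_reduction: str = "sum") -> torch.Tensor:
    # gate_reg is assumed to hold one regularization value per gate.
    if reg_reduction == "sum":
        return gate_reg.sum()
    if reg_reduction == "mean":
        return gate_reg.mean()
    if reg_reduction == "none":
        # Return per-gate values so callers can apply a custom reduction,
        # e.g. a weighted sum over the number of elements each gate covers.
        return gate_reg
    raise ValueError(f"Unsupported reg_reduction: {reg_reduction!r}")
```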

Although we previously used `mean`, we decided to change the default to `sum` for three reasons:
1. The original paper "Learning Sparse Neural Networks through L0 Regularization" used `sum` both in its writing and in its [implementation](https://github.com/AMLab-Amsterdam/L0_regularization/blob/master/l0_layers.py#L70)
2. L^1 and L^2 regularization also `sum` over each parameter without averaging over the total number of parameters in a model. See [PyTorch's implementation](https://github.com/pytorch/pytorch/blob/df569367ef444dc9831ef0dde3bc611bcabcfbf9/torch/optim/adagrad.py#L268)
3. When there are multiple STGs of imbalanced lengths, the results are comparable under `sum` but not under `mean`. If a model has two STGs, where one has 100 gates and the other has a single gate, the regularization of each gate in the first STG is divided by 100 under `mean`, which makes the first STG 100 times weaker than the second. This is usually unexpected for users (see the sketch after this list).
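A small numeric illustration of reason 3, assuming each gate's regularization depends on its own parameter (identity here, for simplicity):

```python
import torch

gates_a = torch.ones(100, requires_grad=True)  # STG with 100 gates
gates_b = torch.ones(1, requires_grad=True)    # STG with a single gate

# mean: each STG contributes the same total penalty, so every gate in the
# 100-gate STG receives 1/100 of the sparsification pressure of the lone gate.
(gates_a.mean() + gates_b.mean()).backward()
print(gates_a.grad[0].item(), gates_b.grad[0].item())  # 0.01 vs 1.0

gates_a.grad = None
gates_b.grad = None

# sum: every gate receives the same pressure regardless of STG size.
(gates_a.sum() + gates_b.sum()).backward()
print(gates_a.grad[0].item(), gates_b.grad[0].item())  # 1.0 vs 1.0
```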

Using `mean` or `sum` will not impact performance when there is only one BSN layer, because users can tune `reg_weight` to counter the difference: with `n` gates, `mean` is just `sum / n`, so rescaling `reg_weight` by the gate count converts one into the other (a quick check follows below). The authors of "Feature selection using Stochastic Gates" mixed `sum` and `mean` in [their implementation](https://github.com/runopti/stg/blob/master/python/stg/models.py#L164-L195)
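The single-layer equivalence as a runnable check; `w` stands in for some chosen `reg_weight`:

```python
import torch

gate_reg = torch.rand(8)   # per-gate regularization of a single STG layer
n = gate_reg.numel()
w = 0.5                    # reg_weight chosen under sum

# w * sum(reg) == (w * n) * mean(reg): tuning reg_weight absorbs the
# difference between the two reductions when there is only one layer.
assert torch.allclose(w * gate_reg.sum(), (w * n) * gate_reg.mean())
```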

For backward compatibility, we explicitly specified `reg_reduction="mean"` for all existing usages in Pyper and MVAI.
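A hedged sketch of such a backward-compatible call site, assuming Captum's `BinaryConcreteStochasticGates` import path and signature (the exact constructor defaults may differ):

```python
import torch
from captum.module import BinaryConcreteStochasticGates

# Existing call sites that relied on the old averaging behavior now pass
# reg_reduction="mean" explicitly so their loss scale is unchanged.
stg = BinaryConcreteStochasticGates(n_gates=10, reg_reduction="mean")
gated_input, reg = stg(torch.randn(4, 10))  # forward returns gated input and regularization
print(gated_input.shape, reg)
```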

Differential Revision: D41991741

fbshipit-source-id: 77f54cf3948e44e943afff795bf473adaa01fa56
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D41991741

@facebook-github-bot

This pull request has been merged in dcb87d3.
