Support different reg_reduction in Captum STG #1090
Closed
Summary:
Add a new `str` argument `reg_reduction` to Captum's STG classes, which specifies how the returned regularization should be reduced. Following the design of PyTorch losses, support 3 modes: `sum`, `mean`, and `none`. The default is `sum`.
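The three modes can be sketched as follows. This is an illustrative sketch of the reduction semantics only; `reduce_regularization` and the per-gate tensor shape are assumptions for illustration, not the actual Captum API.

```python
import torch

def reduce_regularization(reg: torch.Tensor, reg_reduction: str = "sum") -> torch.Tensor:
    # `reg` is assumed to hold one regularization term per gate.
    if reg_reduction == "sum":
        return reg.sum()       # scalar: total penalty over all gates
    elif reg_reduction == "mean":
        return reg.mean()      # scalar: penalty averaged over the gate count
    elif reg_reduction == "none":
        return reg             # unreduced: caller applies its own reduction,
                               # e.g. a weighted sum over gates
    raise ValueError(f"unsupported reg_reduction: {reg_reduction}")

gate_reg = torch.tensor([0.2, 0.4, 0.6])
total = reduce_regularization(gate_reg, "sum")    # scalar, 1.2
average = reduce_regularization(gate_reg, "mean") # scalar, 0.4
```

With `none`, the caller receives the full per-gate tensor and can implement custom reductions that the built-in modes do not cover.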
(There may be a need for other modes in the future, such as `weighted_sum`. With a customized `mask`, each gate may handle a different number of elements, so an application may want to use as few elements as possible rather than as few gates. For now, such use cases can choose the `none` option and apply their own reduction.)

Although we previously used `mean`, we decided to change the default to `sum` for 3 reasons:
1. The original paper uses `sum` in both its writing and its implementation {F822978249}.
2. `sum` is taken over each parameter without averaging over the total number of parameters within a model. See PyTorch's implementation.
3. `sum` penalizes every gate equally, but `mean` does not. If a model has 2 STGs, where one has 100 gates and the other has a single gate, the regularization of each gate in the 1st STG will be divided by 100 under `mean`, which makes the 1st STG 100 times weaker than the 2nd STG. This is usually unexpected for users.

Using
`mean` or `sum` will not impact the performance when there is only one BSN layer, because people can tune `reg_weight` to counter the difference. The authors of "Feature Selection using Stochastic Gates" mixed `sum` and `mean` in their implementation.

For backward compatibility, we explicitly specified `reg_reduction = "mean"` for all existing usages in Pyper and MVAI.

Differential Revision: D41991741