Support different reg_reduction in Captum STG #1090
Closed
Summary:
Add a new `str` argument `reg_reduction` to the Captum STG classes, which specifies how the returned regularization should be reduced. Following the design of PyTorch losses, support 3 modes: `sum`, `mean`, and `none`. The default is `sum`.

(There may be a need for other modes in the future, such as `weighted_sum`. With a customized `mask`, each gate may handle a different number of elements, and an application may want to use as few elements as possible rather than as few gates. For now, such use cases can use the `none` option and reduce the result themselves.)

Although we previously used
`mean`, we decided to change the default to `sum` for 3 reasons:

1. The original paper uses `sum` both in its writing and in its implementation {F822978249}
2. PyTorch's weight decay sums over each parameter without averaging over the total number of parameters within a model. See PyTorch's implementation.
3. Multiple STGs within one model stay balanced under `sum` but not under `mean`. If the model has 2 STGs, where one has 100 gates and the other has a single gate, the regularization of each gate in the 1st STG will be divided by 100 under `mean`, making the 1st STG 100 times weaker than the 2nd STG. This is usually unexpected for users.

Using `mean` or `sum` will not impact performance when there is only one BSN layer, because people can tune `reg_weight` to counter the difference. The authors of "Feature selection using Stochastic Gates" mixed `sum` and `mean` in their implementation.

For backward compatibility, we explicitly specified
`reg_reduction = "mean"` for all existing usages in Pyper and MVAI.

Differential Revision: D41991741
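As a rough illustration of the semantics above, here is a framework-agnostic sketch. `reduce_regularization` is a hypothetical helper operating on plain Python lists, not Captum's actual tensor-based API, and the per-gate penalty value of 0.5 is made up; the assertions at the bottom demonstrate the 100-gate vs 1-gate imbalance described in reason 3:

```python
from typing import List, Union


def reduce_regularization(gate_reg: List[float],
                          reg_reduction: str = "sum") -> Union[float, List[float]]:
    # Hypothetical sketch of the reg_reduction argument; the real
    # implementation works on tensors inside Captum's STG classes.
    if reg_reduction == "sum":
        return sum(gate_reg)                  # default: every gate contributes equally
    if reg_reduction == "mean":
        return sum(gate_reg) / len(gate_reg)  # averaged over this layer's gates
    if reg_reduction == "none":
        return gate_reg                       # caller reduces, e.g. a weighted scheme
    raise ValueError(f"unsupported reg_reduction: {reg_reduction!r}")


# Reason 3: two STG layers with an identical per-gate penalty of 0.5.
stg_100_gates = [0.5] * 100
stg_1_gate = [0.5]

# Under "mean", the 100-gate layer's total penalty equals the 1-gate layer's,
# so each individual gate in it is regularized 100x more weakly.
assert reduce_regularization(stg_100_gates, "mean") == reduce_regularization(stg_1_gate, "mean")

# Under "sum", each gate carries the same weight regardless of layer size.
assert reduce_regularization(stg_100_gates, "sum") == 100 * reduce_regularization(stg_1_gate, "sum")
```

Since `none` returns the unreduced per-gate values, a caller wanting something like `weighted_sum` can apply its own weights before summing.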