improve documentation of STG #1100

74 changes: 45 additions & 29 deletions captum/module/binary_concrete_stochastic_gates.py
@@ -47,8 +47,16 @@ class BinaryConcreteStochasticGates(StochasticGatesBase):
Then use hard-sigmoid rectification to "fold" the parts smaller than 0 or larger
than 1 back to 0 and 1.

More details can be found in the
`original paper <https://arxiv.org/abs/1712.01312>`.
More details can be found in the original paper:
https://arxiv.org/abs/1712.01312

Examples::

>>> n_params = 5 # number of parameters
>>> stg = BinaryConcreteStochasticGates(n_params, reg_weight=0.01)
>>> inputs = torch.randn(3, n_params) # mock inputs with batch size of 3
>>> gated_inputs, reg = stg(inputs) # gate the inputs

"""

def __init__(
@@ -66,42 +74,42 @@ def __init__(
Args:
n_gates (int): number of gates.

mask (Optional[Tensor]): If provided, this allows grouping multiple
mask (Tensor, optional): If provided, this allows grouping multiple
input tensor elements to share the same stochastic gate.
This tensor should be broadcastable to match the input shape
and contain integers in the range 0 to n_gates - 1.
Indices grouped to the same stochastic gate should have the same value.
If not provided, each element in the input tensor
(on dimensions other than dim 0 - batch dim) is gated separately.
(on dimensions other than dim 0, i.e., batch dim) is gated separately.
Default: None

reg_weight (Optional[float]): rescaling weight for L0 regularization term.
reg_weight (float, optional): rescaling weight for L0 regularization term.
Default: 1.0

temperature (float): temperature of the concrete distribution, controls
the degree of approximation, as 0 means the original Bernoulli
temperature (float, optional): temperature of the concrete distribution,
controls the degree of approximation, as 0 means the original Bernoulli
without relaxation. The value should be between 0 and 1.
Default: 2/3

lower_bound (float): the lower bound to "stretch" the binary concrete
distribution
lower_bound (float, optional): the lower bound to "stretch" the binary
concrete distribution
Default: -0.1

upper_bound (float): the upper bound to "stretch" the binary concrete
distribution
upper_bound (float, optional): the upper bound to "stretch" the binary
concrete distribution
Default: 1.1

eps (float): term to improve numerical stability in binary concerete
sampling
eps (float, optional): term to improve numerical stability in binary
concrete sampling
Default: 1e-8

reg_reduction (str, optional): the reduction to apply to
the regularization: 'none'|'mean'|'sum'. 'none': no reduction will be
applied and it will be the same as the return of get_active_probs,
'mean': the sum of the gates non-zero probabilities will be divided by
the number of gates, 'sum': the gates non-zero probabilities will
reg_reduction (str, optional): the reduction to apply to the regularization:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
applied and it will be the same as the return of ``get_active_probs``,
``'mean'``: the sum of the gates non-zero probabilities will be divided
by the number of gates, ``'sum'``: the gates non-zero probabilities will
be summed.
Default: 'sum'
Default: ``'sum'``
"""
super().__init__(
n_gates, mask=mask, reg_weight=reg_weight, reg_reduction=reg_reduction
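
For reference, a minimal sketch of how the documented ``mask`` argument groups input elements under shared gates, and how the returned regularization term can enter a loss. The mask values, shapes, and ``some_task_loss`` are illustrative assumptions, not part of this PR:

>>> import torch
>>> # hypothetical mask: features 0-1 share gate 0, features 2-3 share gate 1
>>> mask = torch.tensor([0, 0, 1, 1])
>>> stg = BinaryConcreteStochasticGates(n_gates=2, mask=mask, reg_weight=0.1)
>>> x = torch.randn(8, 4) # batch of 8, 4 features; mask broadcasts to this shape
>>> gated_x, reg = stg(x) # gated_x has the same shape as x
>>> loss = some_task_loss(gated_x) + reg # hypothetical task loss plus the L0 penalty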
@@ -193,7 +201,7 @@ def _from_pretrained(cls, log_alpha_param: Tensor, *args, **kwargs):
log_alpha_param (Tensor): FloatTensor containing weights for
the pretrained log_alpha

mask (Optional[Tensor]): If provided, this allows grouping multiple
mask (Tensor, optional): If provided, this allows grouping multiple
input tensor elements to share the same stochastic gate.
This tensor should be broadcastable to match the input shape
and contain integers in the range 0 to n_gates - 1.
@@ -202,26 +210,34 @@ def _from_pretrained(cls, log_alpha_param: Tensor, *args, **kwargs):
(on dimensions other than dim 0 - batch dim) is gated separately.
Default: None

reg_weight (Optional[float]): rescaling weight for L0 regularization term.
reg_weight (float, optional): rescaling weight for L0 regularization term.
Default: 1.0

temperature (float): temperature of the concrete distribution, controls
the degree of approximation, as 0 means the original Bernoulli
temperature (float, optional): temperature of the concrete distribution,
controls the degree of approximation, as 0 means the original Bernoulli
without relaxation. The value should be between 0 and 1.
Default: 2/3

lower_bound (float): the lower bound to "stretch" the binary concrete
distribution
lower_bound (float, optional): the lower bound to "stretch" the binary
concrete distribution
Default: -0.1

upper_bound (float): the upper bound to "stretch" the binary concrete
distribution
upper_bound (float, optional): the upper bound to "stretch" the binary
concrete distribution
Default: 1.1

eps (float): term to improve numerical stability in binary concerete
sampling
eps (float, optional): term to improve numerical stability in binary
concrete sampling
Default: 1e-8

reg_reduction (str, optional): the reduction to apply to the regularization:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
applied and it will be the same as the return of ``get_active_probs``,
``'mean'``: the sum of the gates non-zero probabilities will be divided
by the number of gates, ``'sum'``: the gates non-zero probabilities will
be summed.
Default: ``'sum'``

Returns:
stg (BinaryConcreteStochasticGates): StochasticGates instance
"""
49 changes: 32 additions & 17 deletions captum/module/gaussian_stochastic_gates.py
@@ -28,8 +28,15 @@ class GaussianStochasticGates(StochasticGatesBase):
within 0 and 1, gaussian does not have boundaries. So hard-sigmoid rectification
is used to "fold" the parts smaller than 0 or larger than 1 back to 0 and 1.

More details can be found in the
`original paper <https://arxiv.org/abs/1810.04247>`.
More details can be found in the original paper:
https://arxiv.org/abs/1810.04247

Examples::

>>> n_params = 5 # number of gates
>>> stg = GaussianStochasticGates(n_params, reg_weight=0.01)
>>> inputs = torch.randn(3, n_params) # mock inputs with batch size of 3
>>> gated_inputs, reg = stg(inputs) # gate the inputs
"""

def __init__(
@@ -44,28 +51,28 @@ def __init__(
Args:
n_gates (int): number of gates.

mask (Optional[Tensor]): If provided, this allows grouping multiple
mask (Tensor, optional): If provided, this allows grouping multiple
input tensor elements to share the same stochastic gate.
This tensor should be broadcastable to match the input shape
and contain integers in the range 0 to n_gates - 1.
Indices grouped to the same stochastic gate should have the same value.
If not provided, each element in the input tensor
(on dimensions other than dim 0 - batch dim) is gated separately.
(on dimensions other than dim 0, i.e., batch dim) is gated separately.
Default: None

reg_weight (Optional[float]): rescaling weight for L0 regularization term.
reg_weight (float, optional): rescaling weight for L0 regularization term.
Default: 1.0

std (Optional[float]): standard deviation that will be fixed throughout.
Default: 0.5 (by paper reference)
std (float, optional): standard deviation that will be fixed throughout.
Default: 0.5

reg_reduction (str, optional): the reduction to apply to
the regularization: 'none'|'mean'|'sum'. 'none': no reduction will be
applied and it will be the same as the return of get_active_probs,
'mean': the sum of the gates non-zero probabilities will be divided by
the number of gates, 'sum': the gates non-zero probabilities will
reg_reduction (str, optional): the reduction to apply to the regularization:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
applied and it will be the same as the return of ``get_active_probs``,
``'mean'``: the sum of the gates non-zero probabilities will be divided
by the number of gates, ``'sum'``: the gates non-zero probabilities will
be summed.
Default: 'sum'
Default: ``'sum'``
"""
super().__init__(
n_gates, mask=mask, reg_weight=reg_weight, reg_reduction=reg_reduction
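
As with the binary concrete variant, a small illustrative sketch of constructing the Gaussian gates with the arguments documented above (the values are assumptions, not from this PR):

>>> stg = GaussianStochasticGates(n_gates=5, std=0.5, reg_reduction="mean")
>>> gated_x, reg = stg(torch.randn(3, 5)) # reg: mean of the gates' non-zero probabilities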
@@ -126,7 +133,7 @@ def _from_pretrained(cls, mu: Tensor, *args, **kwargs):
Args:
mu (Tensor): FloatTensor containing weights for the pretrained mu

mask (Optional[Tensor]): If provided, this allows grouping multiple
mask (Tensor, optional): If provided, this allows grouping multiple
input tensor elements to share the same stochastic gate.
This tensor should be broadcastable to match the input shape
and contain integers in the range 0 to n_gates - 1.
@@ -135,11 +142,19 @@ def _from_pretrained(cls, mu: Tensor, *args, **kwargs):
(on dimensions other than dim 0 - batch dim) is gated separately.
Default: None

reg_weight (Optional[float]): rescaling weight for L0 regularization term.
reg_weight (float, optional): rescaling weight for L0 regularization term.
Default: 1.0

std (Optional[float]): standard deviation that will be fixed throughout.
Default: 0.5 (by paper reference)
std (float, optional): standard deviation that will be fixed throughout.
Default: 0.5

reg_reduction (str, optional): the reduction to apply to the regularization:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
applied and it will be the same as the return of ``get_active_probs``,
``'mean'``: the sum of the gates non-zero probabilities will be divided
by the number of gates, ``'sum'``: the gates non-zero probabilities will
be summed.
Default: ``'sum'``

Returns:
stg (GaussianStochasticGates): StochasticGates instance
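
A matching hedged sketch for the Gaussian variant, again assuming keyword arguments are forwarded to the constructor:

>>> mu = torch.randn(5) # hypothetical pretrained mu for 5 gates
>>> stg = GaussianStochasticGates._from_pretrained(mu, std=0.5)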
22 changes: 11 additions & 11 deletions captum/module/stochastic_gates_base.py
@@ -39,7 +39,7 @@ def __init__(
Args:
n_gates (int): number of gates.

mask (Optional[Tensor]): If provided, this allows grouping multiple
mask (Tensor, optional): If provided, this allows grouping multiple
input tensor elements to share the same stochastic gate.
This tensor should be broadcastable to match the input shape
and contain integers in the range 0 to n_gates - 1.
@@ -48,16 +48,16 @@
(on dimensions other than dim 0 - batch dim) is gated separately.
Default: None

reg_weight (Optional[float]): rescaling weight for L0 regularization term.
reg_weight (float, optional): rescaling weight for L0 regularization term.
Default: 1.0

reg_reduction (str, optional): the reduction to apply to
the regularization: 'none'|'mean'|'sum'. 'none': no reduction will be
applied and it will be the same as the return of get_active_probs,
'mean': the sum of the gates non-zero probabilities will be divided by
the number of gates, 'sum': the gates non-zero probabilities will
reg_reduction (str, optional): the reduction to apply to the regularization:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
applied and it will be the same as the return of ``get_active_probs``,
``'mean'``: the sum of the gates non-zero probabilities will be divided
by the number of gates, ``'sum'``: the gates non-zero probabilities will
be summed.
Default: 'sum'
Default: ``'sum'``
"""
super().__init__()
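
To make the ``reg_reduction`` options concrete, a hedged sketch using one of the subclasses; the shape shown is what the documented behavior implies, not verified output:

>>> stg = GaussianStochasticGates(n_gates=5, reg_reduction="none")
>>> _, reg = stg(torch.randn(2, 5))
>>> reg.shape # expected: one non-zero probability per gate, i.e. torch.Size([5])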

@@ -143,13 +143,13 @@ def get_gate_values(self, clamp: bool = True) -> Tensor:
optionally clamped within 0 and 1.

Args:
clamp (bool): whether to clamp the gate values or not. As smoothed Bernoulli
variables, gate values are clamped within 0 and 1 by default.
clamp (bool, optional): whether to clamp the gate values or not. As smoothed
Bernoulli variables, gate values are clamped within 0 and 1 by default.
Turn this off to get the raw means of the underlying
distribution (e.g., concrete, gaussian), which can be useful to
differentiate the gates' importance when multiple gate
values are beyond 0 or 1.
Default: True
Default: ``True``

Returns:
Tensor:
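
Finally, a short sketch of the ``clamp`` flag described above, assuming ``stg`` is an already-constructed gates module:

>>> gate_values = stg.get_gate_values() # clamped within [0, 1]
>>> raw_means = stg.get_gate_values(clamp=False) # raw means of the underlying distribution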