diff --git a/captum/module/binary_concrete_stochastic_gates.py b/captum/module/binary_concrete_stochastic_gates.py
index c3cabc39b4..01d0e83ace 100644
--- a/captum/module/binary_concrete_stochastic_gates.py
+++ b/captum/module/binary_concrete_stochastic_gates.py
@@ -47,8 +47,16 @@ class BinaryConcreteStochasticGates(StochasticGatesBase):
     Then use hard-sigmoid rectification to "fold" the parts smaller than 0 or larger
     than 1 back to 0 and 1.

-    More details can be found in the
-    `original paper <https://arxiv.org/abs/1712.01312>`.
+    More details can be found in the original paper:
+    https://arxiv.org/abs/1712.01312
+
+    Examples::
+
+        >>> n_params = 5  # number of gates
+        >>> stg = BinaryConcreteStochasticGates(n_params, reg_weight=0.01)
+        >>> inputs = torch.randn(3, n_params)  # mock inputs with batch size of 3
+        >>> gated_inputs, reg = stg(inputs)  # gate the inputs
+
     """

     def __init__(
@@ -66,42 +74,42 @@ def __init__(
         Args:
             n_gates (int): number of gates.

-            mask (Optional[Tensor]): If provided, this allows grouping multiple
+            mask (Tensor, optional): If provided, this allows grouping multiple
                 input tensor elements to share the same stochastic gate.
                 This tensor should be broadcastable to match the input shape
                 and contain integers in the range 0 to n_gates - 1.
                 Indices grouped to the same stochastic gate should have the same value.
                 If not provided, each element in the input tensor
-                (on dimensions other than dim 0 - batch dim) is gated separately.
+                (on dimensions other than dim 0, i.e., batch dim) is gated separately.
                 Default: None

-            reg_weight (Optional[float]): rescaling weight for L0 regularization term.
+            reg_weight (float, optional): rescaling weight for L0 regularization term.
                 Default: 1.0

-            temperature (float): temperature of the concrete distribution, controls
-                the degree of approximation, as 0 means the original Bernoulli
+            temperature (float, optional): temperature of the concrete distribution,
+                controls the degree of approximation, as 0 means the original Bernoulli
                 without relaxation. The value should be between 0 and 1.
                 Default: 2/3

-            lower_bound (float): the lower bound to "stretch" the binary concrete
-                distribution
+            lower_bound (float, optional): the lower bound to "stretch" the binary
+                concrete distribution
                 Default: -0.1

-            upper_bound (float): the upper bound to "stretch" the binary concrete
-                distribution
+            upper_bound (float, optional): the upper bound to "stretch" the binary
+                concrete distribution
                 Default: 1.1

-            eps (float): term to improve numerical stability in binary concerete
-                sampling
+            eps (float, optional): term to improve numerical stability in binary
+                concrete sampling
                 Default: 1e-8

-            reg_reduction (str, optional): the reduction to apply to
-                the regularization: 'none'|'mean'|'sum'. 'none': no reduction will be
-                applied and it will be the same as the return of get_active_probs,
-                'mean': the sum of the gates non-zero probabilities will be divided by
-                the number of gates, 'sum': the gates non-zero probabilities will
+            reg_reduction (str, optional): the reduction to apply to the regularization:
+                ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
+                applied and it will be the same as the return of ``get_active_probs``,
+                ``'mean'``: the sum of the gates' non-zero probabilities will be divided
+                by the number of gates, ``'sum'``: the gates' non-zero probabilities will
                 be summed.
-                Default: 'sum'
+                Default: ``'sum'``
         """
         super().__init__(
             n_gates, mask=mask, reg_weight=reg_weight, reg_reduction=reg_reduction
@@ -193,7 +201,7 @@ def _from_pretrained(cls, log_alpha_param: Tensor, *args, **kwargs):
             log_alpha_param (Tensor): FloatTensor containing weights for
                 the pretrained log_alpha

-            mask (Optional[Tensor]): If provided, this allows grouping multiple
+            mask (Tensor, optional): If provided, this allows grouping multiple
                 input tensor elements to share the same stochastic gate.
                 This tensor should be broadcastable to match the input shape
                 and contain integers in the range 0 to n_gates - 1.
@@ -202,26 +210,34 @@ def _from_pretrained(cls, log_alpha_param: Tensor, *args, **kwargs):
                 (on dimensions other than dim 0 - batch dim) is gated separately.
                 Default: None

-            reg_weight (Optional[float]): rescaling weight for L0 regularization term.
+            reg_weight (float, optional): rescaling weight for L0 regularization term.
                 Default: 1.0

-            temperature (float): temperature of the concrete distribution, controls
-                the degree of approximation, as 0 means the original Bernoulli
+            temperature (float, optional): temperature of the concrete distribution,
+                controls the degree of approximation, as 0 means the original Bernoulli
                 without relaxation. The value should be between 0 and 1.
                 Default: 2/3

-            lower_bound (float): the lower bound to "stretch" the binary concrete
-                distribution
+            lower_bound (float, optional): the lower bound to "stretch" the binary
+                concrete distribution
                 Default: -0.1

-            upper_bound (float): the upper bound to "stretch" the binary concrete
-                distribution
+            upper_bound (float, optional): the upper bound to "stretch" the binary
+                concrete distribution
                 Default: 1.1

-            eps (float): term to improve numerical stability in binary concerete
-                sampling
+            eps (float, optional): term to improve numerical stability in binary
+                concrete sampling
                 Default: 1e-8

+            reg_reduction (str, optional): the reduction to apply to the regularization:
+                ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
+                applied and it will be the same as the return of ``get_active_probs``,
+                ``'mean'``: the sum of the gates' non-zero probabilities will be divided
+                by the number of gates, ``'sum'``: the gates' non-zero probabilities will
+                be summed.
+                Default: ``'sum'``
+
         Returns:
             stg (BinaryConcreteStochasticGates): StochasticGates instance
         """
diff --git a/captum/module/gaussian_stochastic_gates.py b/captum/module/gaussian_stochastic_gates.py
index 13054c55f5..f837cb2d7a 100644
--- a/captum/module/gaussian_stochastic_gates.py
+++ b/captum/module/gaussian_stochastic_gates.py
@@ -28,8 +28,15 @@ class GaussianStochasticGates(StochasticGatesBase):
     within 0 and 1, gaussian does not have boundaries. So hard-sigmoid rectification
     is used to "fold" the parts smaller than 0 or larger than 1 back to 0 and 1.

-    More details can be found in the
-    `original paper <https://arxiv.org/abs/1810.04247>`.
+    More details can be found in the original paper:
+    https://arxiv.org/abs/1810.04247
+
+    Examples::
+
+        >>> n_params = 5  # number of gates
+        >>> stg = GaussianStochasticGates(n_params, reg_weight=0.01)
+        >>> inputs = torch.randn(3, n_params)  # mock inputs with batch size of 3
+        >>> gated_inputs, reg = stg(inputs)  # gate the inputs
     """

     def __init__(
@@ -44,28 +51,28 @@ def __init__(
         Args:
             n_gates (int): number of gates.

-            mask (Optional[Tensor]): If provided, this allows grouping multiple
+            mask (Tensor, optional): If provided, this allows grouping multiple
                 input tensor elements to share the same stochastic gate.
                 This tensor should be broadcastable to match the input shape
                 and contain integers in the range 0 to n_gates - 1.
                 Indices grouped to the same stochastic gate should have the same value.
                 If not provided, each element in the input tensor
-                (on dimensions other than dim 0 - batch dim) is gated separately.
+                (on dimensions other than dim 0, i.e., batch dim) is gated separately.
                 Default: None

-            reg_weight (Optional[float]): rescaling weight for L0 regularization term.
+            reg_weight (float, optional): rescaling weight for L0 regularization term.
                 Default: 1.0

-            std (Optional[float]): standard deviation that will be fixed throughout.
-                Default: 0.5 (by paper reference)
+            std (float, optional): standard deviation that will be fixed throughout.
+                Default: 0.5

-            reg_reduction (str, optional): the reduction to apply to
-                the regularization: 'none'|'mean'|'sum'. 'none': no reduction will be
-                applied and it will be the same as the return of get_active_probs,
-                'mean': the sum of the gates non-zero probabilities will be divided by
-                the number of gates, 'sum': the gates non-zero probabilities will
+            reg_reduction (str, optional): the reduction to apply to the regularization:
+                ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
+                applied and it will be the same as the return of ``get_active_probs``,
+                ``'mean'``: the sum of the gates' non-zero probabilities will be divided
+                by the number of gates, ``'sum'``: the gates' non-zero probabilities will
                 be summed.
-                Default: 'sum'
+                Default: ``'sum'``
         """
         super().__init__(
             n_gates, mask=mask, reg_weight=reg_weight, reg_reduction=reg_reduction
@@ -126,7 +133,7 @@ def _from_pretrained(cls, mu: Tensor, *args, **kwargs):
         Args:
             mu (Tensor): FloatTensor containing weights for the pretrained mu

-            mask (Optional[Tensor]): If provided, this allows grouping multiple
+            mask (Tensor, optional): If provided, this allows grouping multiple
                 input tensor elements to share the same stochastic gate.
                 This tensor should be broadcastable to match the input shape
                 and contain integers in the range 0 to n_gates - 1.
@@ -135,11 +142,19 @@ def _from_pretrained(cls, mu: Tensor, *args, **kwargs):
                 (on dimensions other than dim 0 - batch dim) is gated separately.
                 Default: None

-            reg_weight (Optional[float]): rescaling weight for L0 regularization term.
+            reg_weight (float, optional): rescaling weight for L0 regularization term.
                 Default: 1.0

-            std (Optional[float]): standard deviation that will be fixed throughout.
-                Default: 0.5 (by paper reference)
+            std (float, optional): standard deviation that will be fixed throughout.
+                Default: 0.5
+
+            reg_reduction (str, optional): the reduction to apply to the regularization:
+                ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
+                applied and it will be the same as the return of ``get_active_probs``,
+                ``'mean'``: the sum of the gates' non-zero probabilities will be divided
+                by the number of gates, ``'sum'``: the gates' non-zero probabilities will
+                be summed.
+                Default: ``'sum'``

         Returns:
             stg (GaussianStochasticGates): StochasticGates instance
diff --git a/captum/module/stochastic_gates_base.py b/captum/module/stochastic_gates_base.py
index 16691a4e36..76f01281d9 100644
--- a/captum/module/stochastic_gates_base.py
+++ b/captum/module/stochastic_gates_base.py
@@ -39,7 +39,7 @@ def __init__(
         Args:
             n_gates (int): number of gates.
-            mask (Optional[Tensor]): If provided, this allows grouping multiple
+            mask (Tensor, optional): If provided, this allows grouping multiple
                 input tensor elements to share the same stochastic gate.
                 This tensor should be broadcastable to match the input shape
                 and contain integers in the range 0 to n_gates - 1.
@@ -48,16 +48,16 @@ def __init__(
                 (on dimensions other than dim 0 - batch dim) is gated separately.
                 Default: None

-            reg_weight (Optional[float]): rescaling weight for L0 regularization term.
+            reg_weight (float, optional): rescaling weight for L0 regularization term.
                 Default: 1.0

-            reg_reduction (str, optional): the reduction to apply to
-                the regularization: 'none'|'mean'|'sum'. 'none': no reduction will be
-                applied and it will be the same as the return of get_active_probs,
-                'mean': the sum of the gates non-zero probabilities will be divided by
-                the number of gates, 'sum': the gates non-zero probabilities will
+            reg_reduction (str, optional): the reduction to apply to the regularization:
+                ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be
+                applied and it will be the same as the return of ``get_active_probs``,
+                ``'mean'``: the sum of the gates' non-zero probabilities will be divided
+                by the number of gates, ``'sum'``: the gates' non-zero probabilities will
                 be summed.
-                Default: 'sum'
+                Default: ``'sum'``
         """

         super().__init__()
@@ -143,13 +143,13 @@ def get_gate_values(self, clamp: bool = True) -> Tensor:
         optionally clamped within 0 and 1.

         Args:
-            clamp (bool): whether to clamp the gate values or not. As smoothed Bernoulli
-                variables, gate values are clamped within 0 and 1 by default.
+            clamp (bool, optional): whether to clamp the gate values or not. As smoothed
+                Bernoulli variables, gate values are clamped within 0 and 1 by default.
                 Turn this off to get the raw means of the underneath distribution
                 (e.g., concrete, gaussian), which can be useful to differentiate the
                 gates' importance when multiple gate values are beyond 0 or 1.
-                Default: True
+                Default: ``True``

         Returns:
             Tensor:
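
To see how the pieces documented in this diff fit together, here is a minimal usage sketch against the documented API only (forward returning ``(gated_inputs, reg)``, ``mask`` grouping, and ``get_gate_values``). The import path follows the file layout in the diff; the mask shape, ``reg_weight`` value, and the placeholder loss are illustrative assumptions rather than part of this change::

    import torch

    from captum.module import BinaryConcreteStochasticGates

    # Gate a (batch, 4, 2) input so that each of the 4 rows shares one gate:
    # per the docstrings above, the mask is broadcastable to the input shape
    # (excluding the batch dim) and maps each element to an index in
    # [0, n_gates - 1].
    n_gates = 4
    mask = torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3]])

    stg = BinaryConcreteStochasticGates(
        n_gates,
        mask=mask,
        reg_weight=0.01,      # illustrative rescaling of the L0 term
        reg_reduction="sum",  # reg becomes a scalar, per the docs above
    )

    inputs = torch.randn(8, 4, 2)  # batch size of 8
    gated_inputs, reg = stg(inputs)

    # Add the L0 regularization term to the task loss so training drives
    # unneeded gates toward zero.
    task_loss = gated_inputs.sum()  # placeholder for a real objective
    (task_loss + reg).backward()

    # Inspect which gates stay open: values clamped to [0, 1] by default,
    # or the raw distribution means with clamp=False (see get_gate_values).
    print(stg.get_gate_values())
    print(stg.get_gate_values(clamp=False))

The same sketch applies to ``GaussianStochasticGates``, whose constructor swaps the concrete-distribution arguments for ``std``.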