Skip to content

Commit

Permalink
fix doc-string
Browse files Browse the repository at this point in the history
  • Loading branch information
Pedro Eduardo Mercado Lopez committed Aug 22, 2023
1 parent 0e65431 commit c10f145
Showing 1 changed file with 21 additions and 25 deletions.
46 changes: 21 additions & 25 deletions src/gluonts/torch/distributions/truncated_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

# The implementation is strongly inspired from:
# - https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py
# - https://github.com/toshas/torch_truncnorm

import math
from numbers import Number
from typing import Dict, Optional, Tuple, Union
Expand All @@ -36,8 +32,6 @@


class TruncatedNormal(Distribution):
"""Truncated Normal distribution."""

"""Implements a Truncated Normal distribution with location scaling.
Location scaling prevents the location to be "too far" from 0, which ultimately
Expand All @@ -49,26 +43,28 @@ class TruncatedNormal(Distribution):
This behaviour can be disabled by switching off the tanh_loc parameter (see below).
Args:
loc (torch.Tensor): normal distribution location parameter
scale (torch.Tensor): normal distribution sigma parameter (squared root of variance)
upscale (torch.Tensor or number, optional): 'a' scaling factor in the formula:
.. math::
loc = tanh(loc / upscale) * upscale.
Default is 5.0
min (torch.Tensor or number, optional): minimum value of the distribution. Default = -1.0;
max (torch.Tensor or number, optional): maximum value of the distribution. Default = 1.0;
tanh_loc (bool, optional): if ``True``, the above formula is used for
the location scaling, otherwise the raw value is kept.
Default is ``False``;
References:
Parameters
----------
loc (torch.Tensor):
normal distribution location parameter
scale (torch.Tensor):
normal distribution sigma parameter (squared root of variance)
min (torch.Tensor or number, optional):
minimum value of the distribution. Default = -1.0
max (torch.Tensor or number, optional):
maximum value of the distribution. Default = 1.0
upscale (torch.Tensor or number, optional):
scaling factor. Default = 5.0
tanh_loc (bool, optional): if ``True``, the above formula is used for
the location scaling, otherwise the raw value is kept.
Default is ``False``
References
----------
- https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
Notes
-----
This implementation is strongly based on:
- https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py
- https://github.com/toshas/torch_truncnorm
Expand Down

0 comments on commit c10f145

Please sign in to comment.