
Commit

simply classes
Pedro Eduardo Mercado Lopez committed Aug 18, 2023
1 parent 03c6cea commit 6c24985
Showing 1 changed file with 40 additions and 57 deletions.
97 changes: 40 additions & 57 deletions src/gluonts/torch/distributions/utils/truncated_normal.py
@@ -11,9 +11,9 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-# from https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py
-
-# from https://github.com/toshas/torch_truncnorm
+# The implementation is strongly inspired from:
+# - https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py
+# - https://github.com/toshas/torch_truncnorm
 
 import math
 from numbers import Number
@@ -27,12 +27,13 @@
 CONST_INV_SQRT_2 = 1 / math.sqrt(2)
 CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI)
 CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e)
+torch.manual_seed(0)
 
 
-class TruncatedStandardNormal(Distribution):
-    """Truncated Standard Normal distribution.
+class TruncatedNormal(Distribution):
+    """Truncated Normal distribution.
 
-    Source: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
     """
 
     arg_constraints = {
@@ -42,15 +43,20 @@ class TruncatedStandardNormal(Distribution):
     has_rsample = True
     eps = 1e-6
 
-    def __init__(self, a, b, validate_args=None):
-        self.a, self.b = broadcast_all(a, b)
+    def __init__(self, loc, scale, a, b):
+        scale = scale.clamp_min(self.eps)
+        self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b)
+        self._non_std_a = a
+        self._non_std_b = b
+        self.a = (a - self.loc) / self.scale
+        self.b = (b - self.loc) / self.scale
+
         if isinstance(a, Number) and isinstance(b, Number):
             batch_shape = torch.Size()
         else:
             batch_shape = self.a.size()
-        super(TruncatedStandardNormal, self).__init__(
-            batch_shape, validate_args=validate_args
-        )
+
+        super(TruncatedNormal, self).__init__(batch_shape)
         if self.a.dtype != self.b.dtype:
             raise ValueError("Truncation bounds types are different")
         if any(
@@ -61,6 +67,7 @@ def __init__(self, a, b, validate_args=None):
             .tolist()
         ):
             raise ValueError("Incorrect truncation range")
+
         eps = self.eps
         self._dtype_min_gt_0 = eps
         self._dtype_max_lt_1 = 1 - eps
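
After this change a single class covers the general case: the constructor takes loc and scale along with the untruncated bounds and standardizes the bounds internally. A minimal usage sketch, not part of the commit; the module path is inferred from the file header above and may differ from the package's public exports:

import torch

from gluonts.torch.distributions.utils.truncated_normal import TruncatedNormal

# Bounds are passed on the original scale; __init__ rescales them to
# (a - loc) / scale and (b - loc) / scale internally.
loc = torch.zeros(3)
scale = torch.ones(3)
a = torch.full((3,), -1.0)
b = torch.full((3,), 2.0)

dist = TruncatedNormal(loc, scale, a, b)
x = dist.rsample(torch.Size([100]))  # reparameterized samples, shape (100, 3)
print(dist.log_prob(x).shape)        # torch.Size([100, 3])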
@@ -85,26 +92,26 @@ def __init__(self, a, b, validate_args=None):
         self._entropy = (
             CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z
         )
+        self._log_scale = self.scale.log()
+        self._mean_non_std = self._mean * self.scale + self.loc
+        self._variance_non_std = self._variance * self.scale**2
+        self._entropy_non_std = self._entropy + self._log_scale
 
     @constraints.dependent_property
     def support(self):
         return constraints.interval(self.a, self.b)
 
     @property
     def mean(self):
-        return self._mean
+        return self._mean_non_std
 
     @property
     def variance(self):
-        return self._variance
+        return self._variance_non_std
 
     @property
     def entropy(self):
-        return self._entropy
-
-    @property
-    def auc(self):
-        return self._Z
+        return self._entropy_non_std
 
     @staticmethod
     def _little_phi(x):
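
The four new attributes are the location-scale versions of the standardized moments: E[X] = loc + scale * E[Z], Var[X] = scale**2 * Var[Z], and H[X] = H[Z] + log(scale). A quick cross-check sketch against scipy.stats.truncnorm, which standardizes its bounds the same way; scipy here is an assumption of the sketch, not a dependency of the file:

import torch
from scipy.stats import truncnorm

from gluonts.torch.distributions.utils.truncated_normal import TruncatedNormal

loc, scale, a, b = 1.0, 2.0, -1.0, 4.0
alpha, beta = (a - loc) / scale, (b - loc) / scale  # standardized bounds

ref = truncnorm(alpha, beta, loc=loc, scale=scale)
dist = TruncatedNormal(
    torch.tensor(loc), torch.tensor(scale), torch.tensor(a), torch.tensor(b)
)

assert abs(dist.mean.item() - ref.mean()) < 1e-4
assert abs(dist.variance.item() - ref.var()) < 1e-4
assert abs(dist.entropy.item() - ref.entropy()) < 1e-4  # entropy is a property here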
@@ -118,65 +125,32 @@ def _big_phi(self, x):
     @staticmethod
     def _inv_big_phi(x):
         return CONST_SQRT_2 * (2 * x - 1).erfinv()
 
-    def cdf(self, value):
+    def cdf_truncated_standard_normal(self, value):
         if self._validate_args:
             self._validate_sample(value)
         return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1)
 
-    def icdf(self, value):
+    def icdf_truncated_standard_normal(self, value):
         y = self._big_phi_a + value * self._Z
         y = y.clamp(self.eps, 1 - self.eps)
         return self._inv_big_phi(y)
 
-    def log_prob(self, value):
+    def log_prob_truncated_standard_normal(self, value):
         if self._validate_args:
             self._validate_sample(value)
         return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5
 
-    def rsample(self, sample_shape=None):
-        if sample_shape is None:
-            sample_shape = torch.Size([])
-        shape = self._extended_shape(sample_shape)
-        p = torch.empty(shape, device=self.a.device).uniform_(
-            self._dtype_min_gt_0, self._dtype_max_lt_1
-        )
-        return self.icdf(p)
-
-
-class TruncatedNormal(TruncatedStandardNormal):
-    """Truncated Normal distribution.
-    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
-    """
-
-    has_rsample = True
-
-    def __init__(self, loc, scale, a, b, validate_args=None):
-        scale = scale.clamp_min(self.eps)
-        self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b)
-        self._non_std_a = a
-        self._non_std_b = b
-        a = (a - self.loc) / self.scale
-        b = (b - self.loc) / self.scale
-        super(TruncatedNormal, self).__init__(
-            a, b, validate_args=validate_args
-        )
-        self._log_scale = self.scale.log()
-        self._mean = self._mean * self.scale + self.loc
-        self._variance = self._variance * self.scale**2
-        self._entropy += self._log_scale
-
     def _to_std_rv(self, value):
         return (value - self.loc) / self.scale
 
     def _from_std_rv(self, value):
         return value * self.scale + self.loc
 
     def cdf(self, value):
-        return super(TruncatedNormal, self).cdf(self._to_std_rv(value))
+        return self.cdf_truncated_standard_normal(self._to_std_rv(value))
 
     def icdf(self, value):
-        sample = self._from_std_rv(super().icdf(value))
+        sample = self._from_std_rv(self.icdf_truncated_standard_normal(value))
 
         # clamp data but keep gradients
         sample_clip = torch.stack(
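
With the standard-normal helpers renamed to *_truncated_standard_normal, cdf and log_prob now delegate by standardizing the input first. Spelled out, with xi = (x - loc) / scale and Z = Phi(beta) - Phi(alpha) the truncated mass:

    F_X(x)     = (Phi(xi) - Phi(alpha)) / Z
    log f_X(x) = -xi**2 / 2 - log(sqrt(2 * pi)) - log(Z) - log(scale)

The trailing -log(scale) is the change-of-variables correction for the rescaling, which is exactly the `- self._log_scale` term in log_prob in the next hunk.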
@@ -190,4 +164,13 @@ def icdf(self, value):
 
     def log_prob(self, value):
         value = self._to_std_rv(value)
-        return super(TruncatedNormal, self).log_prob(value) - self._log_scale
+        return self.log_prob_truncated_standard_normal(value) - self._log_scale
+
+    def rsample(self, sample_shape=None):
+        if sample_shape is None:
+            sample_shape = torch.Size([])
+        shape = self._extended_shape(sample_shape)
+        p = torch.empty(shape, device=self.a.device).uniform_(
+            self._dtype_min_gt_0, self._dtype_max_lt_1
+        )
+        return self.icdf(p)
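
rsample draws uniform noise on (eps, 1 - eps) and pushes it through icdf: inverse-transform sampling, which is what keeps the sample path differentiable in loc and scale (hence has_rsample = True). A standalone sketch of the same trick using only public torch ops; none of the helper names below come from the commit:

import math

import torch

loc = torch.tensor(0.5, requires_grad=True)
scale = torch.tensor(2.0, requires_grad=True)
a, b = torch.tensor(-1.0), torch.tensor(3.0)

alpha = (a - loc) / scale                 # standardized bounds, as in __init__
beta = (b - loc) / scale

def big_phi(x):                           # standard normal CDF
    return 0.5 * (1 + torch.erf(x / math.sqrt(2)))

def inv_big_phi(p):                       # standard normal quantile
    return math.sqrt(2) * torch.erfinv(2 * p - 1)

u = torch.rand(())                        # noise, independent of the parameters
z = inv_big_phi(big_phi(alpha) + u * (big_phi(beta) - big_phi(alpha)))
x = loc + scale * z                       # sample on the original scale

x.backward()
print(loc.grad, scale.grad)               # finite gradients: reparameterized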
