diff --git a/src/gluonts/torch/distributions/utils/truncated_normal.py b/src/gluonts/torch/distributions/utils/truncated_normal.py index 55edcbf10f..158ddc1763 100644 --- a/src/gluonts/torch/distributions/utils/truncated_normal.py +++ b/src/gluonts/torch/distributions/utils/truncated_normal.py @@ -11,9 +11,9 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -# from https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py - -# from https://github.com/toshas/torch_truncnorm +# The implementation is strongly inspired from: +# - https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py +# - https://github.com/toshas/torch_truncnorm import math from numbers import Number @@ -27,12 +27,13 @@ CONST_INV_SQRT_2 = 1 / math.sqrt(2) CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI) CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e) +torch.manual_seed(0) -class TruncatedStandardNormal(Distribution): - """Truncated Standard Normal distribution. +class TruncatedNormal(Distribution): + """Truncated Normal distribution. - Source: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf """ arg_constraints = { @@ -42,15 +43,20 @@ class TruncatedStandardNormal(Distribution): has_rsample = True eps = 1e-6 - def __init__(self, a, b, validate_args=None): - self.a, self.b = broadcast_all(a, b) + def __init__(self, loc, scale, a, b): + scale = scale.clamp_min(self.eps) + self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b) + self._non_std_a = a + self._non_std_b = b + self.a = (a - self.loc) / self.scale + self.b = (b - self.loc) / self.scale + if isinstance(a, Number) and isinstance(b, Number): batch_shape = torch.Size() else: batch_shape = self.a.size() - super(TruncatedStandardNormal, self).__init__( - batch_shape, validate_args=validate_args - ) + + super(TruncatedNormal, self).__init__(batch_shape) if self.a.dtype != self.b.dtype: raise ValueError("Truncation bounds types are different") if any( @@ -61,6 +67,7 @@ def __init__(self, a, b, validate_args=None): .tolist() ): raise ValueError("Incorrect truncation range") + eps = self.eps self._dtype_min_gt_0 = eps self._dtype_max_lt_1 = 1 - eps @@ -85,6 +92,10 @@ def __init__(self, a, b, validate_args=None): self._entropy = ( CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z ) + self._log_scale = self.scale.log() + self._mean_non_std = self._mean * self.scale + self.loc + self._variance_non_std = self._variance * self.scale**2 + self._entropy_non_std = self._entropy + self._log_scale @constraints.dependent_property def support(self): @@ -92,19 +103,15 @@ def support(self): @property def mean(self): - return self._mean + return self._mean_non_std @property def variance(self): - return self._variance + return self._variance_non_std @property def entropy(self): - return self._entropy - - @property - def auc(self): - return self._Z + return self._entropy_non_std @staticmethod def _little_phi(x): @@ -118,54 +125,21 @@ def _big_phi(self, x): def _inv_big_phi(x): return CONST_SQRT_2 * (2 * x - 1).erfinv() - def cdf(self, value): + def cdf_truncated_standard_normal(self, value): if self._validate_args: self._validate_sample(value) return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1) - def icdf(self, value): + def icdf_truncated_standard_normal(self, value): y = self._big_phi_a + value * self._Z y = y.clamp(self.eps, 1 - self.eps) return self._inv_big_phi(y) - def log_prob(self, value): + def log_prob_truncated_standard_normal(self, value): if self._validate_args: self._validate_sample(value) return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5 - def rsample(self, sample_shape=None): - if sample_shape is None: - sample_shape = torch.Size([]) - shape = self._extended_shape(sample_shape) - p = torch.empty(shape, device=self.a.device).uniform_( - self._dtype_min_gt_0, self._dtype_max_lt_1 - ) - return self.icdf(p) - - -class TruncatedNormal(TruncatedStandardNormal): - """Truncated Normal distribution. - - https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - """ - - has_rsample = True - - def __init__(self, loc, scale, a, b, validate_args=None): - scale = scale.clamp_min(self.eps) - self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b) - self._non_std_a = a - self._non_std_b = b - a = (a - self.loc) / self.scale - b = (b - self.loc) / self.scale - super(TruncatedNormal, self).__init__( - a, b, validate_args=validate_args - ) - self._log_scale = self.scale.log() - self._mean = self._mean * self.scale + self.loc - self._variance = self._variance * self.scale**2 - self._entropy += self._log_scale - def _to_std_rv(self, value): return (value - self.loc) / self.scale @@ -173,10 +147,10 @@ def _from_std_rv(self, value): return value * self.scale + self.loc def cdf(self, value): - return super(TruncatedNormal, self).cdf(self._to_std_rv(value)) + return self.cdf_truncated_standard_normal(self._to_std_rv(value)) def icdf(self, value): - sample = self._from_std_rv(super().icdf(value)) + sample = self._from_std_rv(self.icdf_truncated_standard_normal(value)) # clamp data but keep gradients sample_clip = torch.stack( @@ -190,4 +164,13 @@ def icdf(self, value): def log_prob(self, value): value = self._to_std_rv(value) - return super(TruncatedNormal, self).log_prob(value) - self._log_scale + return self.log_prob_truncated_standard_normal(value) - self._log_scale + + def rsample(self, sample_shape=None): + if sample_shape is None: + sample_shape = torch.Size([]) + shape = self._extended_shape(sample_shape) + p = torch.empty(shape, device=self.a.device).uniform_( + self._dtype_min_gt_0, self._dtype_max_lt_1 + ) + return self.icdf(p)