Skip to content

Commit

Permalink
add clapping to computation of log_prob
Browse files Browse the repository at this point in the history
  • Loading branch information
Pedro Eduardo Mercado Lopez committed Aug 21, 2023
1 parent ae9c46b commit e287242
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/gluonts/torch/distributions/truncated_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,12 @@ def icdf(self, value):
return sample

def log_prob(self, value):
a = self._non_std_a + self._dtype_min_gt_0
a = a.expand_as(value)
b = self._non_std_b - self._dtype_min_gt_0
b = b.expand_as(value)
value = torch.min(torch.stack([value, b], -1), dim=-1)[0]
value = torch.max(torch.stack([value, a], -1), dim=-1)[0]
value = self._to_std_rv(value)
return self.log_prob_truncated_standard_normal(value) - self._log_scale

Expand Down

0 comments on commit e287242

Please sign in to comment.