Skip to content

Commit

Permalink
fix: Replace unsafe under vectorize pt.zeros(value.shape) with zeros_…
Browse files Browse the repository at this point in the history
…like
  • Loading branch information
ferrine authored and ricardoV94 committed Jul 19, 2024
1 parent 8fd4f1c commit f6d1b33
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ def log_jac_det(self, value, *inputs):
N = N.astype(value.dtype)
sum_value = pt.sum(value, -1, keepdims=True)
value_sum_expanded = value + sum_value
value_sum_expanded = pt.concatenate([value_sum_expanded, pt.zeros(sum_value.shape)], -1)
value_sum_expanded = pt.concatenate([value_sum_expanded, pt.zeros_like(sum_value)], -1)
logsumexp_value_expanded = pt.logsumexp(value_sum_expanded, -1, keepdims=True)
res = pt.log(N) + (N * sum_value) - (N * logsumexp_value_expanded)
return pt.sum(res, -1)
Expand All @@ -977,7 +977,7 @@ def forward(self, value, *inputs):
return pt.as_tensor_variable(value)

def log_jac_det(self, value, *inputs):
return pt.zeros(value.shape)
return pt.zeros_like(value)


class ChainedTransform(Transform):
Expand Down

0 comments on commit f6d1b33

Please sign in to comment.