Skip to content

Commit

Permalink
Merge pull request #5074 from jakevdp:multigammaln-check
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 345144611
  • Loading branch information
jax authors committed Dec 2, 2020
2 parents 982fd35 + c43cfbd commit 621f34b
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
import numpy as np
import scipy.special as osp_special

from jax import lax
from jax import api
from jax import api, lax, core
from jax.interpreters import ad
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.lax_numpy import (asarray, _reduction_dims, _constant_like,
Expand Down Expand Up @@ -153,8 +152,9 @@ def entr(x):

@_wraps(osp_special.multigammaln, update_doc=False)
def multigammaln(a, d):
a, = _promote_args_inexact("multigammaln", a)
d = lax.convert_element_type(d, lax.dtype(a))
d = core.concrete_or_error(int, d, "d argument of multigammaln")
a, d = _promote_args_inexact("multigammaln", a, d)

constant = lax.mul(lax.mul(lax.mul(_constant_like(a, 0.25), d),
lax.sub(d, _constant_like(a, 1))),
lax.log(_constant_like(a, np.pi)))
Expand Down

0 comments on commit 621f34b

Please sign in to comment.