Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement gamma sampler using core.Primitive interface #1790

Merged
merged 6 commits into from
Jan 8, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 41 additions & 11 deletions jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@
from jax import abstract_arrays
from jax.scipy.special import logit
from jax.scipy.linalg import cholesky
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla


Expand Down Expand Up @@ -878,18 +880,46 @@ def _case4(zagf):
def _gamma_grad(sample, a):
samples = np.reshape(sample, -1)
alphas = np.reshape(a, -1)
grads = vmap(_gamma_grad_one)(samples, alphas)
if np.size(alphas) == 1:
grads = _gamma_grad_one(samples[0], alphas[0])
else:
# TODO: benchmark execute time against grads = vmap(_gamma_grad_one)(samples, alphas)
grads = lax.map(lambda args: _gamma_grad_one(*args), (samples, alphas))
return grads.reshape(onp.shape(a))

@custom_transforms
def _gamma_impl(key, a):
alphas = np.reshape(a, -1)
keys = split(key, onp.size(alphas))
samples = vmap(_gamma_one)(keys, alphas)
return np.reshape(samples, onp.shape(a))

defjvp(_gamma_impl, None,
lambda tangent, ans, key, a, **kwargs: tangent * _gamma_grad(ans, a))
if key.ndim == 2: # batch of keys and alphas
size = np.size(a[0])
if size > 1:
key = lax.map(lambda k: split(k, size), key)
else:
size = np.size(a)
if size > 1:
key = split(key, size)
alphas = np.reshape(a, -1)
keys = np.reshape(key, (-1, 2))
if np.size(alphas) == 1:
samples = _gamma_one(keys[0], alphas[0])
else:
# XXX in GPU, using lax.map is slower than using vmap if alphas.size > 50000
# but that usage case is rare and can be resolved by vectorizing gamma sampler
samples = lax.map(lambda args: _gamma_one(*args), (keys, alphas))
return np.reshape(samples, np.shape(a))

def _gamma_batching_rule(batched_args, batch_dims):
k, a = batched_args
bk, ba = batch_dims
size = next(t.shape[i] for t, i in zip(batched_args, batch_dims) if i is not None)
k = batching.bdim_at_front(k, bk, size)
a = batching.bdim_at_front(a, ba, size)
return random_gamma_p.bind(k, a), 0

random_gamma_p = core.Primitive('random_gamma')
random_gamma_p.def_impl(_gamma_impl) # partial(xla.apply_primitive, random_gamma_p))
random_gamma_p.def_abstract_eval(lambda key, a: abstract_arrays.raise_to_shaped(a))
ad.defjvp2(random_gamma_p, None, lambda tangent, ans, key, a, **kwargs: tangent * _gamma_grad(ans, a))
xla.translations[random_gamma_p] = xla.lower_fun(_gamma_impl, instantiate=True)
batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule

def gamma(key, a, shape=None, dtype=onp.float64):
"""Sample Gamma random values with given shape and float dtype.
Expand All @@ -911,7 +941,7 @@ def gamma(key, a, shape=None, dtype=onp.float64):
dtype = dtypes.canonicalize_dtype(dtype)
return _gamma(key, a, shape, dtype)

@partial(jit, static_argnums=(2, 3))
# @partial(jit, static_argnums=(2, 3))
def _gamma(key, a, shape, dtype):
if shape is None:
shape = onp.shape(a)
Expand All @@ -921,7 +951,7 @@ def _gamma(key, a, shape, dtype):
a = lax.convert_element_type(a, dtype)
if onp.shape(a) != shape:
a = np.broadcast_to(a, shape)
return _gamma_impl(key, a)
return random_gamma_p.bind(key, a)


def gumbel(key, shape=(), dtype=onp.float64):
Expand Down
8 changes: 8 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
import scipy.stats

from jax import api
from jax import grad
from jax import lax
from jax import numpy as np
from jax import random
from jax import test_util as jtu
from jax import vmap
from jax.interpreters import xla

from jax.config import config
Expand Down Expand Up @@ -443,6 +445,12 @@ def testIssue756(self):
else:
self.assertEqual(onp.result_type(w), onp.float32)

def testIssue1789(self):
def f(x):
return random.gamma(random.PRNGKey(0), x)

grad(lambda x: np.sum(vmap(f)(x)))(np.ones(2))

def testNoOpByOpUnderHash(self):
def fail(*args, **kwargs): assert False
apply_primitive, xla.apply_primitive = xla.apply_primitive, fail
Expand Down