
Implement gamma sampler using core.Primitive interface #1790

Merged: 6 commits merged into jax-ml:master on Jan 8, 2020

Conversation

@fehiepsi (Contributor) commented Dec 1, 2019

This PR tries to address #1789. While working on a fix, I realized that lax.map outperforms vmap in compile time for most usage cases, so I believe the tests for gamma/beta/dirichlet will be much faster than before. With this, #552 and #1188 probably no longer need to be addressed.
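
For illustration, here is the shape of the comparison I have in mind (a minimal sketch, not the actual NumPyro test code): lax.map is implemented with scan, so XLA compiles the loop body once, while vmap produces a fully batched program that can take longer to compile.

    from jax import lax, vmap, random
    import jax.numpy as np

    keys = random.split(random.PRNGKey(0), 1000)
    alphas = np.linspace(0.5, 2.0, 1000)

    # fully batched sampler: one large vectorized program to compile
    gamma_vmap = vmap(random.gamma)(keys, alphas)
    # sequential sampler: lowered through scan to an XLA loop, typically cheaper to compile
    gamma_map = lax.map(lambda ka: random.gamma(*ka), (keys, alphas))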

However, I am running into trouble with the translation rule. With the script

import jax; jax.config.update('jax_platform_name', 'cpu')
from jax import vmap, random, jit, grad
import jax.numpy as np

def f(a):
    return random.gamma(random.PRNGKey(0), a)

calling jit(f)(np.ones(3)) throws an error when the result is printed:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~/miniconda3/envs/pydata/lib/python3.6/site-packages/IPython/core/formatters.py in __call__(self, obj)
    700                 type_pprinters=self.type_printers,
    701                 deferred_pprinters=self.deferred_printers)
--> 702             printer.pretty(obj)
    703             printer.flush()
    704             return stream.getvalue()

~/miniconda3/envs/pydata/lib/python3.6/site-packages/IPython/lib/pretty.py in pretty(self, obj)
    400                         if cls is not object \
    401                                 and callable(cls.__dict__.get('__repr__')):
--> 402                             return _repr_pprint(obj, self, cycle)
    403 
    404             return _default_pprint(obj, self, cycle)

~/miniconda3/envs/pydata/lib/python3.6/site-packages/IPython/lib/pretty.py in _repr_pprint(obj, p, cycle)
    695     """A pprint that just redirects to the normal repr function."""
    696     # Find newlines and replace them with p.break_()
--> 697     output = repr(obj)
    698     for idx,output_line in enumerate(output.splitlines()):
    699         if idx:

~/jax/jax/interpreters/xla.py in __repr__(self)
    655     line_width = onp.get_printoptions()['linewidth']
    656     prefix = '{}('.format(self.__class__.__name__)
--> 657     s = onp.array2string(self._value, prefix=prefix, suffix=',',
    658                          separator=', ', max_line_width=line_width)
    659     dtype_str = 'dtype={})'.format(self.dtype.name)

~/jax/jax/interpreters/xla.py in _value(self)
    608     if self._npy_value is None:
    609       self._npy_value = self.device_buffer.to_py()
--> 610       self._npy_value.flags.writeable = False
    611     return self._npy_value
    612 

AttributeError: 'tuple' object has no attribute 'flags'

while jit(f)(np.ones(())) throws the error

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
~/jax/jax/abstract_arrays.py in __len__(self)
    137     try:
--> 138       return self.shape[0]
    139     except IndexError:

IndexError: tuple index out of range

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
<ipython-input-3-6929b861a3fc> in <module>
----> 1 jit(f)(np.ones(()))

~/jax/jax/api.py in f_jitted(*args, **kwargs)
    148     _check_args(args_flat)
    149     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 150     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
    151     return tree_unflatten(out_tree(), out)
    152 

~/jax/jax/core.py in call_bind(primitive, f, *args, **params)
    590   if top_trace is None:
    591     with new_sublevel():
--> 592       outs = primitive.impl(f, *args, **params)
    593   else:
    594     tracers = map(top_trace.full_raise, args)

~/jax/jax/interpreters/xla.py in _xla_call_impl(fun, *args, **params)
    373   device = params['device']
    374   backend = params.get('backend', None)
--> 375   compiled_fun = _xla_callable(fun, device, backend, *map(abstractify, args))
    376   try:
    377     return compiled_fun(*args)

~/jax/jax/linear_util.py in memoized_fun(fun, *args)
    207       fun.populate_stores(stores)
    208     else:
--> 209       ans = call(fun, *args)
    210       cache[key] = (ans, fun.stores)
    211     return ans

~/jax/jax/interpreters/xla.py in _xla_callable(fun, device, backend, *abstract_args)
    411   xla_consts = _map(c.Constant, consts)
    412   xla_args = _xla_callable_args(c, abstract_args, tuple_args)
--> 413   out_nodes = jaxpr_subcomp(c, jaxpr, backend, axis_env, xla_consts, (), *xla_args)
    414   built = c.Build(c.Tuple(*out_nodes))
    415 

~/jax/jax/interpreters/xla.py in jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, freevars, *args)
    260       ans = rule(c, *in_nodes, **eqn.params)
    261     elif eqn.primitive in translations:
--> 262       ans = translations[eqn.primitive](c, *in_nodes, **eqn.params)
    263     elif eqn.primitive in reduction_translations:
    264       new_params = check_backend_params(eqn.params, backend)

~/jax/jax/interpreters/xla.py in f(c, *args, **params)
    533     pvals = [pe.PartialVal((a, core.unit)) for a in avals]
    534     jaxpr, _, consts = pe.trace_to_jaxpr(
--> 535         lu.wrap_init(fun, params), pvals, instantiate=True)
    536     consts = _map(c.Constant, consts)
    537     outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, (), *xla_args)

~/jax/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, **kwargs)
    313   with new_master(JaxprTrace) as master:
    314     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 315     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    316     assert not env
    317     del master

~/jax/jax/linear_util.py in call_wrapped(***failed resolving arguments***)
    155     while stack:
    156       gen, out_store = stack.pop()
--> 157       ans = gen.send(ans)
    158       if out_store is not None:
    159         ans, side = ans

~/jax/jax/interpreters/partial_eval.py in trace_to_subjaxpr(master, instantiate, pvals)
    325   in_tracers = map(trace.new_arg, pvals)
    326   ans = yield in_tracers, {}
--> 327   instantiate = [instantiate] * len(ans) if type(instantiate) is bool else instantiate
    328   out_tracers = map(trace.full_raise, map(core.full_lower, ans))
    329   out_tracers = map(partial(instantiate_const_at, trace), instantiate, out_tracers)

~/jax/jax/core.py in __len__(self)
    292 
    293   def __len__(self):
--> 294     return self.aval._len(self)
    295 
    296   @property

~/jax/jax/abstract_arrays.py in _len(self, ignored_tracer)
    141 
    142   def _len(self, ignored_tracer):
--> 143     return len(self)
    144 
    145   def strip_weak_type(self):

~/jax/jax/abstract_arrays.py in __len__(self)
    138       return self.shape[0]
    139     except IndexError:
--> 140       raise TypeError("len() of unsized object")  # same as numpy error
    141 
    142   def _len(self, ignored_tracer):

TypeError: len() of unsized object

I guess the output of the translation rule is a tuple and I need to pack it into an array, but I am not sure how to proceed. I tried to mimic the custom_transforms implementation and use tree_flatten, flatten_fun_nokwargs, ... but I don't know how to write a compatible batching rule. Several months ago, the following code worked:

    jaxpr, out, consts = partial_eval.trace_unwrapped_to_jaxpr(_gamma_impl,
                                                               tuple(lax._abstractify(o) for o in (key, a)))
    aval, _ = out
    return random_gamma_p.bind(key, a, jaxpr=jaxpr, aval=aval, consts=consts)

def _random_gamma_translate(c, key, a, jaxpr, aval, consts):
    # build an XLA computation from the traced jaxpr and call it on the operands
    xla_computation = xla.jaxpr_computation(jaxpr, consts, (), c.GetShape(key), c.GetShape(a))
    return c.Call(xla_computation, (key, a))

but those utilities, jaxpr_computation and trace_unwrapped_to_jaxpr, are no longer available, and I am not sure whether similar code would work now.
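
For reference, this is roughly the core.Primitive registration pattern I am aiming for (a simplified sketch with placeholder bodies and a hand-wavy batching rule, not the code that ends up in this PR):

    from jax import core
    from jax.abstract_arrays import ShapedArray
    from jax.interpreters import batching

    # primitive wrapping a per-element gamma sampler
    random_gamma_p = core.Primitive('random_gamma')

    def _gamma_impl(key, a):
        # placeholder body; the real impl draws samples by rejection sampling
        return a

    random_gamma_p.def_impl(_gamma_impl)
    # the sample has the same shape and dtype as the concentration `a`
    random_gamma_p.def_abstract_eval(lambda key, a: ShapedArray(a.shape, a.dtype))

    # batching rule: the sampler is elementwise in `a`, so (assuming both operands
    # carry their batch axis at position 0) we can bind on the batched operands
    # directly and report axis 0 as the output batch axis
    def _gamma_batching_rule(batched_args, batch_dims):
        key, a = batched_args
        return random_gamma_p.bind(key, a), 0

    batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule

    # the XLA translation rule could then be derived from the impl itself
    # (e.g. with a lower_fun-style helper) instead of jaxpr_computation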

I appreciate any help!

cc @mattjj @shoyer @jekbradbury


Update: thanks to @jekbradbury, I was able to make this work. This PR resolves #1789, #1875 (comment), #1188, and #552 (actually, the fourth one is not directly relevant, but with the change in this PR, drawing samples from the gamma sampler now seems fast).

@mattjj It would be great to have this in the next JAX release. :)

@fehiepsi changed the title from "Not use custom transform for gamma sampler" to "Implement gamma sampler using core.Primitive interface" on Dec 24, 2019
@@ -616,7 +618,6 @@ def beta(key, a, b, shape=None, dtype=onp.float64):
  dtype = dtypes.canonicalize_dtype(dtype)
  return _beta(key, a, b, shape, dtype)

@partial(jit, static_argnums=(3, 4))
@fehiepsi (Contributor Author) commented on the diff:

no need to jit this because we don't need to compile gamma_b.

@fehiepsi (Contributor Author) commented Jan 5, 2020

@mattjj Can we have this in the next JAX release? We need this fix to unblock some issues related to the gamma sampler in NumPyro.

@mattjj (Collaborator) commented Jan 5, 2020

Thanks for the ping. We’ve been slow the last few weeks because of NeurIPS then the holidays.

I promise to follow up on this on Monday.

And yes we’ll get it in the next release! Sorry for blocking you.

@fehiepsi (Contributor Author) commented Jan 5, 2020

Please take your time! I saw a new release of JAX so I just wanted to make sure that you get notified about this PR. :)

@mattjj (Collaborator) left a comment:

LGTM, great work, and thanks!

My only suggestion is that I think we can get rid of random_gamma_p.multiple_results = True and then remove some singleton-tuplification in a few places. Let me know what you think of that, then let's merge!
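
To make that concrete, roughly the difference I have in mind (illustrative fragment only, reusing the names from the code earlier in the thread; not the exact diff):

    # with multiple_results = True, every rule returns a 1-tuple and callers unpack:
    random_gamma_p.multiple_results = True
    random_gamma_p.def_impl(lambda key, a: (_gamma_impl(key, a),))
    sample, = random_gamma_p.bind(key, a)

    # with the default single-result behaviour, the singleton tuples go away:
    random_gamma_p.def_impl(_gamma_impl)
    sample = random_gamma_p.bind(key, a)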

@mattjj self-assigned this on Jan 7, 2020
@mattjj (Collaborator) commented Jan 7, 2020

I suppose we also need to resolve conflicts with master before being able to merge. Let me know if you need any help with that!

@mattjj merged commit 9cd5df1 into jax-ml:master on Jan 8, 2020
@mattjj (Collaborator) commented Jan 8, 2020

@fehiepsi just uploaded jax==0.1.57 to PyPI with this change!

@fehiepsi (Contributor Author) commented Jan 8, 2020

Thanks a lot, @mattjj !!

Successfully merging this pull request may close these issues.

grad and vmap do not composable with gamma sampler