diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 5f9b8778bfdc..b980391193df 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2113,6 +2113,7 @@ def _naryop_weak_type_rule(name, *avals, **kwargs): return all(aval.weak_type for aval in avals) def naryop(result_dtype, accepted_dtypes, name, translation_rule=None): + # TODO(frostig,mattjj): only used with arity > 2 once, simplify dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name) shape_rule = partial(_broadcasting_shape_rule, name) weak_type_rule = partial(_naryop_weak_type_rule, name) diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py index 86f6e4553e21..c81edef7948f 100644 --- a/jax/interpreters/batching.py +++ b/jax/interpreters/batching.py @@ -331,7 +331,8 @@ def broadcast_batcher(prim, args, dims, **params): return (out, (d,) * len(out)) if prim.multiple_results else (out, d) else: size, = {shape[d] for shape, d in shapes if d is not not_mapped} - args = [bdim_at_front(x, d, size) for x, d in zip(args, dims)] + args = [bdim_at_front(x, d, size) if np.ndim(x) else x + for x, d in zip(args, dims)] ndim = max(np.ndim(x) for x in args) # special-case scalar broadcasting args = [_handle_scalar_broadcasting(ndim, x, d) for x, d in zip(args, dims)] out = prim.bind(*args, **params) diff --git a/tests/batching_test.py b/tests/batching_test.py index 9aec166b9a93..03f463f3d174 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -21,6 +21,7 @@ import jax import jax.numpy as jnp +import jax.scipy as jsp from jax import test_util as jtu from jax import lax from jax._src.lax import parallel @@ -1240,5 +1241,13 @@ def testNonJaxTypedOutput(self): TypeError, "Output from batched function.*is not a valid JAX type"): vmap(lambda x: "hello")(np.arange(5)) + def testIssue6096(self): + def f(x): + return jsp.special.betainc(jnp.ones(3), 1., x) + + self.assertEquals(f(jnp.ones(3)).shape, (3,)) + self.assertEquals(jax.vmap(f)(jnp.ones((2, 3))).shape, (2, 3)) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())