Add a better error message for grad of while not supported #2497

Merged
gnecula merged 1 commit into jax-ml:master from the while_partial_eval branch on Apr 17, 2020

Conversation

@gnecula (Collaborator) commented Mar 24, 2020

The issue I wanted to fix is that when running grad(lax.while_loop), the error was a cryptic assertion failure (that all primals are known after linearization, in ad.py:linearize). I could not figure out how to detect, before that assertion, that we are doing reverse-mode AD of a while_loop. If I turn off that assertion, the code fails a bit further down.

So I implemented a simple form of partial evaluation that allows the primals to be known after linearization; the code then proceeds and can fail gracefully when it tries to transpose the while.

This is not a proper implementation of partial evaluation. The known outputs are computed early, properly. But the unknown outputs are computed by running the whole loop, including the known parts. There are no residuals saved from the known computation. Perhaps one can do better, but it is not obvious. In any case, this code is only useful for a better error message :-(

Fixes issue: #2129
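
For context, a minimal sketch of the failing pattern (illustrative example, not from this PR's tests; before this change grad hit the assertion described above, and the goal here is a clear error when transposing the while):

```python
import jax
from jax import lax

def pow11(x):
  # Multiply an accumulator by x while a counter is below 10.
  def body(carry):
    i, acc = carry
    return i + 1, acc * x

  _, acc = lax.while_loop(lambda c: c[0] < 10, body, (0, x))
  return acc  # equals x ** 11

jax.grad(pow11)(1.5)  # reverse-mode AD of lax.while_loop is not supported
```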

@gnecula requested review from dougalm and mattjj on March 24, 2020 09:41
@mattjj (Collaborator) commented Mar 25, 2020

I am behind in getting to this because I'm still trying to fix bugs from #2026. Sorry! I hope to get to it sometime this week. LMK if that doesn't work.

@gnecula (Collaborator, author) commented Mar 25, 2020

That will work. Another one to look at is #2455; cjfj@ is waiting on that one as well.

@gnecula (Collaborator, author) commented Apr 6, 2020

@mattjj and @dougalm: I'd like your thoughts on this.

@mattjj (Collaborator) left a comment

Only after looking at and thinking about this change did I have an idea for an alternative, simpler way to get a good error message here.

There's a tag on JaxprTraces (from #1848) that indicates whether we're doing partial evaluation for the purpose of staging things out (in which case we likely want to stage out the whole while_loop primitive, see #2638) or for the purpose of reverse-mode autodiff. I think we can use that tag to raise the error we want here.

Concretely, I think we can do this:

```python
def _while_loop_partial_eval(trace, *tracers, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr):
  if trace.master.trace_type is pe.StagingJaxprTrace:
    params = dict(cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr,
                  body_nconsts=body_nconsts, body_jaxpr=body_jaxpr)
    return trace.default_process_primitive(while_p, tracers, params)
  else:
    raise ValueError("Reverse-mode differentiation does not work for lax.while_loop.")
```

I haven't tried this yet though...

EDIT: actually, this precise idea won't work: I forgot that control flow primitives trace without that tag. So, the tag right now indicates whether the JaxprTrace is for (1) reverse-mode autodiff or control flow, versus (2) jit/pmap. We need to further distinguish between the first two sub-cases. Still, I think this tag-based direction, where we just tag JaxprTraces based on their purpose and check the tag, is probably a good one because it's simpler... WDYT?

```diff
@@ -99,6 +99,8 @@ def process_primitive(self, primitive, tracers, params):
     return self.default_process_primitive(primitive, tracers, params)

   def default_process_primitive(self, primitive, tracers, params):
+    """By default, if all the input tracers are known, then execute the primitive
+    and all the outputs are known. Otherwise, all the outputs are abstract."""
```
@mattjj (Collaborator) commented Apr 8, 2020

Nice! How about replacing "abstract" with "unknown"? Also perhaps remove "By default".

@gnecula (Collaborator, author) commented Apr 9, 2020

> We need to further distinguish between the first two sub-cases. Still, I think this tag-based direction, where we just tag JaxprTraces based on their purpose and check the tag, is probably a good one because it's simpler... WDYT?

I'd love for this to work; it's what I was looking for. But how can it be enough to check the master trace type? Don't we have to somehow inspect the entire trace stack?

@mattjj (Collaborator) commented Apr 9, 2020

We'll end up checking all the partial eval traces on the trace stack because for each one we'll call into the _while_loop_partial_eval rule. If any one of them fails, i.e. if any one of them finds it's responsible for a reverse-mode autodiff partial eval, then we'll get an exception (but if all of them are instead either for outer control flow primitives or for jit/pmap staging then it'll go through).

We can actually attach any info to the master trace; for example, we could just set trace.master.some_property = True anywhere. We don't need to do it via a type tag on master.trace_type; that just seemed like the smallest change at the time of #1848. But let's put our heads together and think about how best to attach this metadata to JaxprTraces!

Should we start with an enum that just indicates whether it's for reverse mode autodiff, control flow, or staging out of the system (i.e. jit/pmap)?
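
For concreteness, a hypothetical sketch of such an enum (the names are illustrative, not actual JAX internals):

```python
import enum

class JaxprTracePurpose(enum.Enum):
  REVERSE_AD = "reverse-mode autodiff (linearize/grad); transposition will follow"
  CONTROL_FLOW = "partial eval driven by an outer control-flow primitive"
  STAGING = "staging out of the system (jit/pmap)"

# A rule like _while_loop_partial_eval could then branch on a tag attached to
# the master trace (attribute name hypothetical):
#
#   if trace.master.purpose is JaxprTracePurpose.REVERSE_AD:
#     raise ValueError("Reverse-mode differentiation does not work for lax.while_loop.")
```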

@gnecula force-pushed the while_partial_eval branch from 9d52ece to b7db174 on April 10, 2020 09:44
@gnecula (Collaborator, author) commented Apr 10, 2020

I have tried something along the lines you suggested, but I think I still don't understand.

First, please check the comment that I wrote at the top of JaxprTrace. It is a refinement of a comment you had in JaxprTrace.process_custom_jvp_call, adding the case for linearization.

For now I decided to try a simple change where we tag the linearization case, to separate it from the handling of control flow. It seems easier to tag the former, because control-flow handling is quite pervasive. I added LinearizeJaxprTrace, a sibling of StagingJaxprTrace.

The problem is with grad(cond(while)). By the time I get to partial evaluation of the nested while, the master that was created for linearize is not in the current trace. That nested case is the only one that does not work.
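
A sketch of that nested case (illustrative; note that lax.cond's calling convention has changed across JAX versions, and this snippet uses the newer true_fun/false_fun-then-operands form):

```python
import jax
from jax import lax

def f(x):
  def with_loop(x):
    def body(carry):
      i, acc = carry
      return i + 1, acc * x
    return lax.while_loop(lambda c: c[0] < 5, body, (0, x))[1]

  # The while_loop is only reached through cond's partial-eval rule, which
  # starts a fresh master trace; unless the "we are linearizing" tag is
  # forwarded to that trace, a check on trace.master never fires.
  return lax.cond(x > 0., with_loop, lambda y: y, x)

jax.grad(f)(2.0)  # the error should still point at the while_loop
```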

@gnecula (Collaborator, author) commented Apr 10, 2020

BTW, if I scan core.trace_stack I do see the LinearizeJaxprTrace. Can't a MasterTrace have parents?

@mattjj (Collaborator) commented Apr 11, 2020

Hmm, what's a parent in this context?

@mattjj (Collaborator) commented Apr 11, 2020

The comments LGTM!

@mattjj (Collaborator) commented Apr 11, 2020

Tricky... I think the issue is that when we do partial evaluation of control flow primitives, with things like partial_eval_jaxpr, we need to propagate this tag information. That is, we're kicking off a new master when we call partial_eval_jaxpr inside the partial eval rule of cond; the fact that the cond partial eval is happening because of linearization is not being forwarded through partial_eval_jaxpr.

@mattjj (Collaborator) commented Apr 11, 2020

Yes, that's the issue. The test passes after

  1. adding a trace_type argument to partial_eval_jaxpr (or we could do the for_linearize / stage_out signature, though maybe a single argument is easier)
  2. changing partial_eval_jaxpr by adding the line for_linearize = trace_type is LinearizeJaxprTrace at the top, then making the two calls to trace_to_jaxpr in its body pass that value

These initial-style control flow primitives are the trickiest, at least the way we process them by round-tripping jaxprs back into traceables via jaxpr_as_fun/eval_jaxpr and tracing them. That caused me some headaches earlier today with the PRNG reuse checker I'm prototyping. Maybe we should brainstorm how we might process jaxprs without involving eval_jaxpr. It could be as simple as writing a couple of separate jaxpr interpreters. (In effect, right now we are reusing the same eval_jaxpr interpreter, but transforming it with our usual Python transforms.) To be honest, though, I'm not sure whether that would actually help much; it's just an alternative that would eliminate one level of tracing.
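
For reference, such a standalone jaxpr interpreter is essentially a loop over jaxpr.eqns, in the shape of jax.core.eval_jaxpr; a minimal sketch (not part of this PR):

```python
from jax import core

def eval_jaxpr_by_hand(jaxpr, consts, *args):
  """Evaluate a jaxpr equation by equation, without re-tracing eval_jaxpr."""
  env = {}

  def read(v):
    return v.val if isinstance(v, core.Literal) else env[v]

  def write(v, val):
    env[v] = val

  for v, c in zip(jaxpr.constvars, consts):
    write(v, c)
  for v, a in zip(jaxpr.invars, args):
    write(v, a)
  for eqn in jaxpr.eqns:
    invals = [read(v) for v in eqn.invars]
    outvals = eqn.primitive.bind(*invals, **eqn.params)
    if not eqn.primitive.multiple_results:
      outvals = [outvals]
    for v, val in zip(eqn.outvars, outvals):
      write(v, val)
  return [read(v) for v in jaxpr.outvars]
```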

@gnecula (Collaborator, author) commented Apr 11, 2020

> Hmm, what's a parent in this context?

A better name is previous: the trace that was on top of the stack when you pushed this one. In essence, this would allow one to scan the trace stack from any particular tracer.

@gnecula force-pushed the while_partial_eval branch from de76e77 to a12c5ba on April 11, 2020 11:17
@gnecula changed the title from "Add a simple form of partial evaluation for while_loop." to "Add a better error message for grad of while not supported" on Apr 11, 2020
@gnecula (Collaborator, author) commented Apr 11, 2020

I have implemented your suggestion. The one difference is that instead of passing a for_linearize flag to trace_to_jaxpr, I just pass trace_type. This now subsumes stage_out, but I left that alone.

BTW, the previous implementation had the minor advantage of actually implementing linearize for while_loop. Now we get the same error as for grad.

@gnecula (Collaborator, author) commented Apr 11, 2020

Looks like random.gamma does not work: its VJP uses a while_loop (only superficially; it could be removed, see here).

But the problem still stands: the current version gives spurious errors if a while_loop appears in the VJP. The error is deserved if you want to do higher-order differentiation, but not for the first differentiation. Looks like I need to adjust this check further.

@mattjj (Collaborator) commented Apr 11, 2020

> A better name is previous: the trace that was on top of the stack when you pushed this one. In essence, this would allow one to scan the trace stack from any particular tracer.

Ah I see, thanks. Personally I try not to think in terms of the trace stack; every trace just has to worry about itself, and the core handles the fact that there are multiple traces going on.

> BTW, the previous implementation had the minor advantage of actually implementing linearize for while_loop. Now we get the same error as for grad.

Oh, I didn't realize that. That would be a nice property to have. Should we revive the other solution then? Especially given the gamma issue.

These refined "jaxpr trace type" tags are probably nice to keep around for other reasons, but making while_loop work with linearize is a big plus of your original approach.

@gnecula force-pushed the while_partial_eval branch from a12c5ba to 3ec15d1 on April 12, 2020 07:34
@gnecula (Collaborator, author) commented Apr 12, 2020

I have revived the partial evaluation of while_loop. I kept passing trace_type along, but I dropped the LinearizeJaxprTrace type tag.

I test that linearize works but grad errors gracefully.
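
A sketch of that behavior (illustrative; the exact error text may differ):

```python
import jax
from jax import lax

def f(x):
  def body(carry):
    i, acc = carry
    return i + 1, acc * x
  return lax.while_loop(lambda c: c[0] < 10, body, (0, x))[1]

y, f_lin = jax.linearize(f, 2.0)  # works: partial eval of while_loop succeeds
print(f_lin(1.0))                 # JVP of f at x=2.0 applied to tangent 1.0
jax.grad(f)(2.0)                  # still raises, now with a clear message
```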

@gnecula force-pushed the while_partial_eval branch from 2ce39c4 to f153dd6 on April 12, 2020 12:25
@mattjj (Collaborator) left a comment

Awesome!

@gnecula force-pushed the while_partial_eval branch from f153dd6 to c2ef097 on April 15, 2020 15:04
@cghawthorne commented

Quick suggestion for the error message: I think it would be helpful if it suggested that the user try lax.scan instead. This wasn't obvious to me when I ran into the same problem, and I had to ask @jekbradbury for help.
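
When the trip count is fixed, rewriting the loop with lax.scan (which does have a transpose rule) supports reverse-mode differentiation. A minimal sketch (hypothetical example, not from this PR):

```python
import jax
from jax import lax

def f(x):
  def step(acc, _):
    return acc * x, None        # same kind of accumulating update as above
  acc, _ = lax.scan(step, x, None, length=10)
  return acc                    # equals x ** 11

print(jax.grad(f)(1.5))         # works: d/dx x**11 = 11 * x**10
```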

Commit: Add a simple form of partial evaluation for while_loop.

The issue that I wanted to fix was that when running grad(while_loop), the error was a cryptic assertion failure (that all primals are known after linearization, in ad.py:linearize). I could not figure out how to detect before that assertion that we are doing reverse-mode AD of while_loop. So I implemented a simple form of partial evaluation, to allow the primals after linearization to be known, so that the code proceeds and can then fail gracefully when trying to transpose the while.

This is not a proper implementation of partial evaluation. The known outputs are computed early, properly. But the unknown outputs are computed by a *whole* computation of the loop, including the known parts.

Fixes issue: jax-ml#2129
@gnecula force-pushed the while_partial_eval branch from c2ef097 to 23e8247 on April 17, 2020 16:11
@gnecula merged commit 7d716b8 into jax-ml:master on Apr 17, 2020
jacobjinkelly pushed a commit to jacobjinkelly/jax that referenced this pull request Apr 21, 2020
@gnecula deleted the while_partial_eval branch on May 10, 2020 15:41
@gnecula added a commit that referenced this pull request on May 11, 2020
@gnecula mentioned this pull request on May 26, 2020