
Print doc #3032

Merged
merged 394 commits into jax-ml:test-docs from print_doc on May 11, 2020
Conversation

@gnecula (Collaborator) commented May 11, 2020

Testing readthedocs

hawkinsp and others added 30 commits April 13, 2020 16:16
* Implement jax.ops.index_mul.

* Add index_mul to documentation.

* Fix RHS JVP rule for scatter_mul, fix test bug that meant it was not tested.

* Fix typo in docstring.
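A minimal usage sketch of the new operator, assuming the `jax.ops` API of this era (these indexed-update helpers were later folded into `x.at[...]`-style updates):

```python
import jax.numpy as jnp
from jax import ops

x = jnp.ones((3, 4))
# Multiply the selected row by 10, leaving the rest unchanged (functional update).
y = ops.index_mul(x, ops.index[1, :], 10.0)
```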
add custom_jvp for logaddexp / logaddexp2
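A sketch of what a `custom_jvp` for a `logaddexp`-style function can look like; the rule in the actual commit may differ in detail:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def logaddexp(x, y):
  amax = jnp.maximum(x, y)
  return amax + jnp.log1p(jnp.exp(-jnp.abs(x - y)))

@logaddexp.defjvp
def logaddexp_jvp(primals, tangents):
  x, y = primals
  dx, dy = tangents
  ans = logaddexp(x, y)
  # d/dx logaddexp = exp(x - ans), d/dy logaddexp = exp(y - ans)
  return ans, jnp.exp(x - ans) * dx + jnp.exp(y - ans) * dy
```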
Add missing functions to autodoc lists
* Fix minor typo in cell

One of the arguments to `hvp` wasn't being used, which made the example slightly confusing.

* Fix both definitions of hvp in the autodiff cookbook.

Co-authored-by: Peter Hawkins <phawkins@google.com>
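For context, the two `hvp` formulations from the autodiff cookbook that this commit touches, as written in the cookbook of this era; the bug was that `v` went unused in one of them:

```python
import jax.numpy as jnp
from jax import grad, jvp

def hvp(f, x, v):
  # Reverse-over-reverse Hessian-vector product.
  return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

def hvp_fwd_over_rev(f, x, v):
  # Forward-mode differentiation of the gradient, evaluated along direction v.
  return jvp(grad(f), (x,), (v,))[1]
```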
to_dlpack now takes ownership of the original buffer, leaving it in an invalid state.
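A small sketch of the ownership semantics described above, assuming the `jax.dlpack` module of this era:

```python
import jax.numpy as jnp
from jax import dlpack as jax_dlpack

x = jnp.arange(4.0)
capsule = jax_dlpack.to_dlpack(x)   # x's buffer is donated; x must not be used afterwards
y = jax_dlpack.from_dlpack(capsule)
```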
* Add type annotations to optix.

* Fix function signature for chain() and remove unused collections import.

* Include Sequence[OptState] as possible output of Init.
* Update np.linalg docs with missing functions
* Implement numpy fmin() & fmax()
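These follow NumPy's NaN-ignoring semantics; a NaN is returned only when both operands are NaN:

```python
import jax.numpy as jnp

print(jnp.fmin(jnp.array([1.0, jnp.nan]), jnp.array([2.0, 3.0])))      # [1. 3.]
print(jnp.fmax(jnp.array([1.0, jnp.nan]), jnp.array([jnp.nan, 5.0])))  # [1. 5.]
```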
Add some type annotations to jax.random and jnp.ndarray.
jet of pow using comp with exp, mul, log
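The identity this relies on: for positive x, `x ** y == exp(y * log(x))`, so the Taylor-mode (jet) rule for `pow` can be composed from the existing rules for `exp`, `mul`, and `log` rather than derived from scratch:

```python
import jax.numpy as jnp

x, y = 2.0, 3.5
print(jnp.allclose(x ** y, jnp.exp(y * jnp.log(x))))  # True
```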
…ax-ml#2717)

* Fixup complex values and tol in tests for jax.scipy.linalg.sparse.cg

The tests for CG were failing on TPUs:

- `test_cg_pytree` is fixed by requiring slightly less precision than the
  unit-test default.
- `test_cg_against_scipy` is fixed for complex values in two independent ways:
  1. We don't set both `tol=0` and `atol=0`, which made the termination
     behavior of CG (convergence or NaN) dependent on exactly how XLA handles
     arithmetic with denormals.
  2. We make use of *real valued* inner products inside `cg`, even for complex
     values (see the sketch after this commit message). It turns out that all
     these inner products are mathematically guaranteed to yield a real number
     anyway, so we can save some flops and avoid ill-defined comparisons of
     complex values (see numpy/numpy#15981) by ignoring the complex part of
     the result from `jnp.vdot`. (Real numbers also happen to have the desired
     rounding behavior for denormals on TPUs, so this on its own would also
     fix these failures.)

* comment fixup

* fix my comment
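For reference, a minimal sketch of the real-valued inner product described in point 2 above (the helper name is illustrative):

```python
import jax.numpy as jnp

def _vdot_real_part(x, y):
  # For the inner products that arise inside CG this quantity is
  # mathematically real, so we can drop the imaginary part: this saves flops
  # and avoids comparing complex values against real-valued tolerances.
  return jnp.vdot(x, y).real
```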
We choose the same set as TensorFlow (minus 3.7, which TF is apparently considering dropping anyway).

This avoids a slow PTX -> SASS compilation on first-time startup.
…n rule is an XlaOp. (jax-ml#2723)

Helps give a more understandable error on erroneous translation rules.
…-ml#2142)

This change introduces ShardingSpec, a struct describing how an array should be sharded. This is integrated into ShardedDeviceArray to allow more flexible sharding. It supports partitioning (both "pmap-style", where an entire axis is decomposed into separate shards and doesn't appear in the on-device shape at all, and "sharded_jit-style", where an axis is chunked into shards but remains in the on-device shape) and replication.

This removes the need for ChunkedDeviceArray, since a ShardedDeviceArray can now represent chunks.
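As a small illustration of pmap-style sharding, each leading-axis slice below lives on its own device, and the result is a ShardedDeviceArray whose layout is now described by a ShardingSpec (the types shown are those of this era):

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(n * 4.0).reshape((n, 4))
y = jax.pmap(lambda a: a * 2.0)(x)   # one shard per device
print(type(y))   # ShardedDeviceArray in this era
print(y[0])      # integer indexing into the sharded array (the fast path benchmarked below)
```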

Here are pmap_benchmark times showing that the overall effect of this change is neutral to positive (integer indexing is much faster!).

**pmap_shard_args**
```
---------Benchmark summary for pmap_shard_args---------
  nargs    nshards       mean       %std    relative    mean/baseline
-------  ---------  ---------  ---------  ----------  ---------------
     10          8  0.041855    4.15223      1               1.01466
    100          8  0.129884    4.85321      3.1032          0.988543
    101          8  0.136347    6.20233      3.2576          0.967138
    500          8  0.533207    3.6815      12.7394          1.0294
   1000          8  1.10338     0.525193    26.362           0.960435
   5000          8  5.33911     0          127.562           0.963319
    100          2  0.0638619  10.7069       1.52579         1.0362
    100          4  0.0868253   6.76701      2.07443         0.967323
    100          8  0.128151    6.46004      3.06177         0.979742
    100        100  1.22631     1.94885     29.299           1.00371
    100        500  6.60746     0          157.865           0.956657
```
**pmap_shard_outputs**
```
  nouts    nshards        mean       %std    relative    mean/baseline
-------  ---------  ----------  ---------  ----------  ---------------
     10          8   0.0664526   9.49251      1               0.938466
    100          8   0.195711    2.19429      2.94512         1.04239
    500          8   0.82577     0.330864    12.4265          0.994669
   1000          8   1.68323     1.0516      25.3298          0.966915
   5000          8   8.89032     0          133.784           0.998038
    100          2   0.074806   10.1734       1.12571         0.980254
    100          4   0.121334    5.76774      1.82588         1.02033
    100          8   0.185253    5.45068      2.78775         1.01666
    100        100   2.37076     0           35.6759          1.08629
    100        500  17.0832      0          257.074           0.976879
```
**ShardedDeviceArray_indexing**
```
indices_fn                mean     %std    relative    mean/baseline
------------------  ----------  -------  ----------  ---------------
integer_indices      0.0603473  8.29159       1             0.359496
integer_2D_indices  18.0241     0           298.672         1.00583
```

This is how I ran the benchmark:
```
TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py --baseline_dir=<results as of a3cc9a7>
```
Without this, pytype (correctly) points out that AbstractValues do not have shape/type information.
* Implement numpy.linalg.multi_dot

* Thread precision through multi_dot
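A usage sketch; the commit also threads a `precision` argument through to the underlying dot operations:

```python
import jax.numpy as jnp
from jax import lax

A = jnp.ones((10, 100))
B = jnp.ones((100, 5))
C = jnp.ones((5, 50))

# multi_dot picks an association order that minimizes scalar multiplications.
out = jnp.linalg.multi_dot([A, B, C], precision=lax.Precision.HIGHEST)
print(out.shape)  # (10, 50)
```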
gnecula and others added 26 commits May 7, 2020 16:24
Enabled outfeed for all arrays as a tuple
Added error checking for outfeed_receiver not started to primitive computations
An implementation of id_tap and id_print using outfeed
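A minimal sketch of what `id_print` looks like from user code, assuming the `jax.experimental.host_callback` module these commits build on top of outfeed; as noted above, an outfeed receiver must be started for the prints to arrive, and the exact setup varies by version:

```python
from jax.experimental import host_callback as hcb

def f(x):
  x = x * 2
  x = hcb.id_print(x, what="after doubling")  # identity on x; prints on the host
  return x + 1
```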
…ax-ml#2978)

* Use a whitelist to limit visibility of exported names in jax.numpy.

Prevents unintentional exports of non-public names in the API.
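An illustrative sketch of the whitelisting idea (names and mechanism here are hypothetical; the PR's actual implementation may differ): only approved public names remain reachable on the module.

```python
# Hypothetical helper: drop any public attribute that is not on the whitelist.
_WHITELIST = frozenset({'add', 'dot', 'linalg', 'sum'})

def _restrict_exports(module_dict):
  for name in list(module_dict):
    if not name.startswith('_') and name not in _WHITELIST:
      del module_dict[name]
```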
Crashes on Travis with the latest 0.1.46. Need to figure out what is going on.
…ax-ml#2982)

* Use a whitelist to restrict visibility in top-level jax namespace.

The goal of this change is to capture the way the world is (i.e., not break users), and separately we will work on fixing users to avoid accidentally-exported APIs.
This was already merged as jax-ml#2791 but reverted due to XLA crashes.

This reverts commit 769d703.
* support replica groups in allreduce collectives

* add test and fix jaxpr in docs

* switch from XLA replica IDs to JAX axis indices

* fix psum transpose rule

* test other nesting order + imperfect nesting

* update jaxpr.rst

* handle None case

* add note+check that groups cover the index space

* switch split_axis assert to NotImplementedError

* update CHANGELOG
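A small example of the new replica groups, expressed as JAX axis indices (requires at least four local devices):

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)
out = jax.pmap(
    lambda v: jax.lax.psum(v, 'i', axis_index_groups=[[0, 1], [2, 3]]),
    axis_name='i')(x)
print(out)  # [1. 1. 5. 5.] -- each group reduces independently
```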
* sort supported dtypes in host_callback_test.py

This fixes issues I ran into with running `pytest -n auto
tests/host_callback_test.py` or similar.

* remove unused import
The previous versions weren't valid RST.

Ironically, this was in the section with instructions on how to preview
changes to our documentation!
* Implement nanvar & nanstd

Add tests for nanvar & nanstd

* Clean up bfloat16 tests for np.nanvar and np.nanstd

* add nanvar & nanstd to the whitelist

ignore numpy ddof warnings
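These match NumPy's NaN-ignoring semantics:

```python
import jax.numpy as jnp

a = jnp.array([1.0, jnp.nan, 3.0])
print(jnp.nanvar(a))  # 1.0 -- variance of [1., 3.] with the NaN dropped
print(jnp.nanstd(a))  # 1.0
```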
…#3027)

* Use a new variable for static_broadcasted_argnums as a tuple.

This works around a bug in pytype (b/156151503).
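A minimal sketch of the pattern described above (names are illustrative): bind the argument to a separate, tuple-typed variable so pytype can follow the type.

```python
from typing import Iterable, Tuple, Union

def _ensure_tuple(static_broadcasted_argnums: Union[int, Iterable[int]]) -> Tuple[int, ...]:
  # Assigning the result to a new tuple-typed name sidesteps the pytype bug.
  if isinstance(static_broadcasted_argnums, int):
    return (static_broadcasted_argnums,)
  return tuple(static_broadcasted_argnums)
```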
@gnecula gnecula merged commit 574db2a into jax-ml:test-docs May 11, 2020
@gnecula gnecula deleted the print_doc branch May 11, 2020 11:00
@googlebot (Collaborator) commented:
All (the pull request submitter and all commit authors) CLAs are signed, but one or more commits were authored or co-authored by someone other than the pull request submitter.

We need to confirm that all authors are OK with their commits being contributed to this project. Please have them confirm that by leaving a comment that contains only `@googlebot I consent.` in this pull request.

Note to project maintainer: There may be cases where the author cannot leave a comment, or the comment is not properly detected as consent. In those cases, you can manually confirm consent of the commit author(s), and set the cla label to yes (if enabled on your project).

