Print doc #3032
Conversation
* Implement jax.ops.index_mul. * Add index_mul to documentation. * Fix RHS JVP rule for scatter_mul, fix test bug that meant it was not tested. * Fix typo in docstring.
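As a rough usage sketch of the `jax.ops` indexed-update API this commit extends (the array values are arbitrary), `index_mul` multiplies the selected slice out-of-place, returning a new array:

```python
import jax.numpy as jnp
from jax import ops

# Illustrative sketch: multiply rows 1: of an immutable array by 10,
# producing a new array rather than mutating in place.
x = jnp.ones((3, 4))
y = ops.index_mul(x, ops.index[1:, :], 10.0)
```

Differentiating through such an update with respect to the multiplier exercises the scatter_mul RHS JVP rule the commit fixes.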
add custom_jvp for logaddexp / logaddexp2
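A minimal sketch of what a `custom_jvp` for a logaddexp-style function can look like (illustrative only, not the actual rule added here):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def logaddexp(x, y):
  # Numerically stable log(exp(x) + exp(y)).
  amax = jnp.maximum(x, y)
  return amax + jnp.log1p(jnp.exp(-jnp.abs(x - y)))

@logaddexp.defjvp
def logaddexp_jvp(primals, tangents):
  x, y = primals
  dx, dy = tangents
  out = logaddexp(x, y)
  # d/dx logaddexp(x, y) = exp(x - out), and symmetrically for y.
  return out, dx * jnp.exp(x - out) + dy * jnp.exp(y - out)
```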
Add missing functions to autodoc lists
* Fix minor typo in cell. One of the arguments to `hvp` wasn't being used, which made the example slightly confusing. * Fix both definitions of hvp in the autodiff cookbook. Co-authored-by: Peter Hawkins <phawkins@google.com>
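For reference, the two Hessian-vector-product definitions in question have this shape (a sketch; `f` maps an array to a scalar):

```python
import jax.numpy as jnp
from jax import grad, jvp

# Reverse-over-reverse: differentiate the scalar jnp.vdot(grad(f)(x), v) again.
def hvp_rev(f, x, v):
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

# Forward-over-reverse: push the tangent v through grad(f).
def hvp_fwd(f, x, v):
    return jvp(grad(f), (x,), (v,))[1]
```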
to_dlpack now takes ownership of the original buffer, leaving it in an invalid state.
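In usage terms (a sketch; exact signatures may vary by version), the consequence is that the source array must not be reused after export:

```python
import jax.numpy as jnp
from jax import dlpack

x = jnp.arange(4.0)
capsule = dlpack.to_dlpack(x)    # x's device buffer is donated; x is now invalid
y = dlpack.from_dlpack(capsule)  # the new array owns the buffer
```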
* Add type annotations to optix. * Fix function signature for chain() and remove unused collections import. * Include Sequence[OptState] as possible output of Init.
* Update np.linalg docs with missing functions
* Implement numpy fmin() & fmax()
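These follow NumPy's NaN-ignoring semantics, unlike `jnp.minimum`/`jnp.maximum`, which propagate NaNs; a quick sketch:

```python
import jax.numpy as jnp

a = jnp.array([1.0, jnp.nan, 3.0])
b = jnp.array([2.0, 2.0, jnp.nan])

jnp.fmin(a, b)     # [1., 2., 3.]   -- NaN ignored when the other operand is valid
jnp.minimum(a, b)  # [1., nan, nan] -- NaN propagates
```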
Add some type annotations to jax.random and jnp.ndarray.
jet of pow using comp with exp, mul, log
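The decomposition behind this rule is the usual identity for positive bases, so the Taylor-mode (jet) coefficients of pow can be composed from those of exp, mul, and log; a sketch of the identity:

```python
import jax.numpy as jnp

# For x > 0: x ** y == exp(y * log(x)), so pow's jet rule can reuse the
# existing rules for exp, mul, and log.
def pow_via_exp_mul_log(x, y):
    return jnp.exp(y * jnp.log(x))
```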
…ax-ml#2717) * Fixup complex values and tol in tests for jax.scipy.linalg.sparse.cg

The tests for CG were failing on TPUs:
- `test_cg_pytree` is fixed by requiring slightly less precision than the unit-test default.
- `test_cg_against_scipy` is fixed for complex values in two independent ways:
  1. We don't set both `tol=0` and `atol=0`, which made the termination behavior of CG (convergence or NaN) dependent on exactly how XLA handles arithmetic with denormals.
  2. We make use of *real valued* inner products inside `cg`, even for complex values. It turns out that all these inner products are mathematically guaranteed to yield a real number anyway, so we can save some flops and avoid ill-defined comparisons of complex values (see numpy/numpy#15981) by ignoring the complex part of the result from `jnp.vdot`. (Real numbers also happen to have the desired rounding behavior for denormals on TPUs, so this on its own would also fix these failures.)

* comment fixup * fix my comment
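The real-valued inner product mentioned in point 2 can be sketched as a tiny helper (the name is illustrative):

```python
import jax.numpy as jnp

# The inner products appearing inside CG (e.g. r^H r and p^H A p for a
# Hermitian positive-definite A) are mathematically real, so keeping only the
# real part of vdot saves flops and avoids ill-defined comparisons of complex
# scalars.
def _vdot_real_part(x, y):
    return jnp.vdot(x, y).real
```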
We choose the same set as TensorFlow (minus 3.7, which TF is apparently considering dropping anyway). This avoids a slow PTX -> SASS compilation on first-time startup.
…n rule is an XlaOp. (jax-ml#2723) Helps give a more understandable error on erroneous translation rules.
…-ml#2142) This change introduces ShardingSpec, a struct describing how an array should be sharded. This is integrated into ShardedDeviceArray to allow more flexible sharding. It supports partitioning (both "pmap-style", where an entire axis is decomposed into separate shards and doesn't appear in the on-device shape at all, and "sharded_jit-style", where an axis is chunked into shards but remains in the on-device shape) and replication. This removes the need for ChunkedDeviceArray, since a ShardedDeviceArray can now represent chunks. Here are pmap_benchmark times showing that the overall effect of this change is neutral to positive (integer indexing is much faster!).

**pmap_shard_args**
```
---------Benchmark summary for pmap_shard_args---------
 nargs  nshards       mean       %std  relative  mean/baseline
------  -------  ---------  ---------  --------  -------------
    10        8  0.041855    4.15223     1           1.01466
   100        8  0.129884    4.85321     3.1032      0.988543
   101        8  0.136347    6.20233     3.2576      0.967138
   500        8  0.533207    3.6815     12.7394      1.0294
  1000        8  1.10338     0.525193   26.362       0.960435
  5000        8  5.33911     0         127.562       0.963319
   100        2  0.0638619  10.7069      1.52579     1.0362
   100        4  0.0868253   6.76701     2.07443     0.967323
   100        8  0.128151    6.46004     3.06177     0.979742
   100      100  1.22631     1.94885    29.299       1.00371
   100      500  6.60746     0         157.865       0.956657
```

**pmap_shard_outputs**
```
 nouts  nshards       mean      %std  relative  mean/baseline
------  -------  ---------  --------  --------  -------------
    10        8  0.0664526   9.49251    1            0.938466
   100        8  0.195711    2.19429    2.94512      1.04239
   500        8  0.82577     0.330864  12.4265       0.994669
  1000        8  1.68323     1.0516    25.3298       0.966915
  5000        8  8.89032     0        133.784        0.998038
   100        2  0.074806   10.1734     1.12571      0.980254
   100        4  0.121334    5.76774    1.82588      1.02033
   100        8  0.185253    5.45068    2.78775      1.01666
   100      100  2.37076     0         35.6759       1.08629
   100      500  17.0832     0        257.074        0.976879
```

**ShardedDeviceArray_indexing**
```
indices_fn                mean     %std  relative  mean/baseline
------------------  ----------  -------  --------  -------------
integer_indices      0.0603473  8.29159    1            0.359496
integer_2D_indices  18.0241     0        298.672        1.00583
```

This is how I ran the benchmark:
```
TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py --baseline_dir=<results as of a3cc9a7>
```
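For orientation, the "pmap-style" case described above is the familiar one where each device holds one shard of the mapped axis (a usage sketch, not the ShardingSpec API itself):

```python
import jax
import jax.numpy as jnp

# Each device gets one slice along the mapped axis; the resulting
# ShardedDeviceArray records that layout (via ShardingSpec after this change).
n = jax.local_device_count()
x = jnp.arange(n, dtype=jnp.float32)
y = jax.pmap(lambda v: v * 2)(x)
```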
Without this, pytype (correctly) points out that AbstractValues do not have shape/type information.
* Implement numpy.linalg.multi_dot * Thread precision through multi_dot
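Usage mirrors NumPy's multi_dot, plus the precision argument the second commit threads through (a sketch; the keyword name follows the usual JAX convention):

```python
import jax
import jax.numpy as jnp

A = jnp.ones((10, 100))
B = jnp.ones((100, 5))
C = jnp.ones((5, 50))

# multi_dot chooses the cheapest association order for the chain A @ B @ C;
# precision is forwarded to the underlying dot operations.
out = jnp.linalg.multi_dot([A, B, C], precision=jax.lax.Precision.HIGHEST)
```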
Enabled outfeed for all arrays as a tuple
…ting the outfeed receiver.
Added error checking for outfeed_receiver not started to primitive computations
An implementation of id_tap and id_print using outfeed
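A minimal usage sketch of the tap-style API this adds (the exact keyword arguments are assumptions, not a spec):

```python
import jax
from jax.experimental import host_callback as hcb

# id_print acts as the identity on its argument while printing it on the host
# via the outfeed mechanism, even from inside jit-compiled code.
def f(x):
    y = hcb.id_print(x * 2.0, what="x * 2")
    return y + 1.0

jax.jit(f)(3.0)
```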
…ax-ml#2978) * Use a whitelist to limit visibility of exported names in jax.numpy. Prevents unintentional exports of non-public names in the API.
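The mechanism can be sketched roughly as follows (names and structure are illustrative, not the actual jax.numpy internals):

```python
# Illustrative sketch: only names on an explicit whitelist remain visible on
# the module; anything else that leaked in via `import *` or helpers is removed.
_PUBLIC_API = {"add", "dot", "zeros", "ones"}  # hypothetical whitelist

def _restrict_module_exports(module):
    for name in list(vars(module)):
        if not name.startswith("_") and name not in _PUBLIC_API:
            delattr(module, name)
```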
Crashes on Travis with the latest 0.1.46. Need to figure out what is going on
Undo the id_print/id_tap feature (PR jax-ml#2791)
…ax-ml#2982) * Use a whitelist to restrict visibility in top-level jax namespace. The goal of this change is to capture the way the world is (i.e., not break users), and separately we will work on fixing users to avoid accidentally-exported APIs.
This was already merged as jax-ml#2791 but reverted due to XLA crashes. This reverts commit 769d703.
* support replica groups in allreduce collectives
* add test and fix jaxpr in docs
* switch from XLA replica IDs to JAX axis indices
* fix psum transpose rule
* test other nesting order + imperfect nesting
* update jaxpr.rst
* handle None case
* add note+check that groups cover the index space
* switch split_axis assert to NotImplementedError
* update CHANGELOG
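With axis_index_groups, a collective reduces within each listed group of axis indices instead of across the whole mapped axis; a sketch (needs at least four devices):

```python
import jax
import jax.numpy as jnp

# Devices {0, 1} reduce together, and devices {2, 3} reduce together.
x = jnp.arange(4.0)
out = jax.pmap(
    lambda v: jax.lax.psum(v, "i", axis_index_groups=[[0, 1], [2, 3]]),
    axis_name="i",
)(x)
# out == [1., 1., 5., 5.]
```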
* sort supported dtypes in host_callback_test.py This fixes issues I ran into with running `pytest -n auto tests/host_callback_test.py` or similar. * remove unused import
The previous versions weren't valid RST. Ironically, this was in the section with instructions on how to preview changes to our documentation!
* Implement nanvar & nanstd; add tests for nanvar & nanstd * Clean up bfloat16 tests for np.nanvar and np.nanstd * Add nanvar & nanstd to the whitelist; ignore numpy ddof warnings
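These mirror NumPy's NaN-ignoring moments, including the ddof argument; a quick sketch:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, jnp.nan, 4.0])

jnp.nanvar(x)          # variance over the non-NaN entries
jnp.nanstd(x, ddof=1)  # sample standard deviation, ignoring the NaN
```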
…#3027) * Use a new variable for static_broadcasted_argnums as a tuple. This works around a bug in pytype (b/156151503).
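In usage terms, static_broadcasted_argnums marks positional arguments that are treated as compile-time constants and broadcast to every device rather than mapped; internally it is now normalized to a tuple (a sketch):

```python
import jax
import jax.numpy as jnp

# Argument 1 (`n`) is a static Python value broadcast to all devices.
f = jax.pmap(lambda x, n: x * n, static_broadcasted_argnums=(1,))
out = f(jnp.ones(jax.local_device_count()), 3)
```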
Testing readthedocs