JAX: High-performance array computing
=====================================

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.

.. grid:: 3
   :margin: 0
   :padding: 0
   :gutter: 0

   .. grid-item-card:: Familiar API
      :columns: 12 6 6 4
      :class-card: sd-border-0
      :shadow: None

      JAX provides a familiar NumPy-style API for ease of adoption by researchers and engineers.

   .. grid-item-card:: Transformations
      :columns: 12 6 6 4
      :class-card: sd-border-0
      :shadow: None

      JAX includes composable function transformations for compilation, batching, automatic differentiation, and parallelization.

   .. grid-item-card:: Run anywhere
      :columns: 12 6 6 4
      :class-card: sd-border-0
      :shadow: None

      The same code executes on multiple backends, including CPU, GPU, and TPU.
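
As a rough illustration of the three points above, here is a minimal sketch that assumes a recent JAX installation. The ``loss`` function, weights, and data are hypothetical examples, not part of JAX. Only ``jax.numpy``, ``jax.grad``, ``jax.vmap``, ``jax.jit``, ``jax.default_backend``, and ``jax.devices`` come from the library. The sketch computes a mean-squared-error gradient, then composes ``grad``, ``vmap``, and ``jit`` to get compiled per-example gradients.

.. code-block:: python

   import jax
   import jax.numpy as jnp

   # Hypothetical example function written with the NumPy-style jax.numpy API:
   # mean squared error of a linear model, pred = x @ w.
   def loss(w, x, y):
       pred = jnp.dot(x, w)
       return jnp.mean((pred - y) ** 2)

   # jax.grad differentiates with respect to the first argument (w).
   grad_loss = jax.grad(loss)

   # Transformations compose: vmap maps grad_loss over the leading axis of
   # x and y (w is shared via in_axes=None), and jit compiles the result.
   per_example_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0)))

   w = jnp.array([1.0, -2.0, 0.5])
   x = jnp.arange(12.0).reshape(4, 3)  # 4 examples, 3 features
   y = jnp.array([0.0, 1.0, 2.0, 3.0])

   g = jax.jit(grad_loss)(w, x, y)   # full-batch gradient, shape (3,)
   pg = per_example_grads(w, x, y)   # per-example gradients, shape (4, 3)

   # Averaging the per-example gradients recovers the full-batch gradient.
   print(jnp.allclose(pg.mean(axis=0), g, rtol=1e-5, atol=1e-4))

   # The same code runs unchanged on whichever backend is available.
   print(jax.default_backend(), jax.devices())

Nothing in the sketch is device-specific. JAX places the arrays and compiled computation on the default backend it selects at startup.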

.. grid:: 3

    .. grid-item-card:: :material-regular:`rocket_launch;2em` Getting started
      :columns: 12 6 6 4
      :link: beginner-guide
      :link-type: ref
      :class-card: getting-started

    .. grid-item-card:: :material-regular:`library_books;2em` User guides
      :columns: 12 6 6 4
      :link: user-guides
      :link-type: ref
      :class-card: user-guides

    .. grid-item-card:: :material-regular:`laptop_chromebook;2em` Developer notes
      :columns: 12 6 6 4
      :link: contributor-guide
      :link-type: ref
      :class-card: developer-docs

If you're looking to train neural networks, use Flax and start with its tutorials. For an end-to-end transformer library built on JAX, see MaxText.

Ecosystem
---------

JAX itself is narrowly scoped and focuses on efficient array operations and program transformations. Built around JAX is an evolving ecosystem of machine learning and numerical computing tools; the following is only a small sample of what is out there:

.. grid:: 4
    :class-container: ecosystem-grid

    .. grid-item:: :material-outlined:`hub;2em` **Neural networks**

       - Flax_
       - NNX_
       - Equinox_
       - Keras_

    .. grid-item:: :material-regular:`show_chart;2em` **Optimizers & solvers**

       - Optax_
       - Optimistix_
       - Lineax_
       - Diffrax_

    .. grid-item:: :material-outlined:`storage;2em` **Data loading**

       - Grain_
       - `TensorFlow Datasets`_
       - `Hugging Face datasets`_

    .. grid-item:: :material-regular:`construction;2em` **Miscellaneous tools**

       - Orbax_
       - Chex_

    .. grid-item:: :material-regular:`lan;2em` **Probabilistic programming**

       - BlackJAX_
       - NumPyro_
       - PyMC_

    .. grid-item:: :material-regular:`bar_chart;2em` **Probabilistic modeling**

       - `TensorFlow Probability <Tensorflow probabilty_>`_
       - Distrax_

    .. grid-item:: :material-outlined:`animation;2em` **Physics & simulation**

       - `JAX MD`_
       - Brax_

    .. grid-item:: :material-regular:`language;2em` **LLMs**

       - MaxText_
       - AXLearn_
       - Levanter_
       - EasyLM_


Many more JAX-based libraries have been developed; the community-run Awesome JAX page maintains an up-to-date list.

.. toctree::
   :hidden:
   :maxdepth: 1
   :caption: Getting started

   installation
   quickstart

.. toctree::
   :hidden:
   :maxdepth: 1

   tutorials

   notebooks/Common_Gotchas_in_JAX

   faq

.. toctree::
   :hidden:
   :maxdepth: 2
   :caption: More guides/resources

   user_guides
   advanced_guide
   contributor_guide
   extensions
   notes
   jax


.. toctree::
   :hidden:
   :maxdepth: 1

   changelog
   glossary