



  • Add support for with mlx (mlx now supports einsum).
  • Add support for einx.vmap and einx.{index|get_at|set_at|...} with mlx. Both use mx.vmap internally, which is still in development and does not support all operations that can be expressed in einx. If this causes an error during an einx call, the error can be reproduced without einx by copying and calling the Python function generated by einx directly (see Just-in-time compilation).


  • When initializing a backend, delay raising an exception until the backend is used in an operation.
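A minimal sketch of the deferred-error pattern, assuming a backend is a thin wrapper around an importable module. `LazyBackend` is a hypothetical name used only for illustration, not einx's actual backend class: the import failure is recorded during initialization and only raised when an operation is looked up.

```python
import importlib
from types import ModuleType
from typing import Optional


class LazyBackend:
    """Hypothetical backend wrapper that defers import errors until first use."""

    def __init__(self, module_name: str):
        self.name = module_name
        self._module: Optional[ModuleType] = None
        self._error: Optional[Exception] = None
        try:
            self._module = importlib.import_module(module_name)
        except Exception as e:  # remember the failure instead of raising now
            self._error = e

    def __getattr__(self, attr):
        # Only called for attributes not found normally, i.e. backend operations.
        if self._error is not None:
            raise RuntimeError(f"Backend '{self.name}' is unavailable") from self._error
        return getattr(self._module, attr)


ok = LazyBackend("math")
broken = LazyBackend("surely_not_installed_module")  # no exception here

print(ok.sqrt(16.0))  # 4.0
try:
    broken.sqrt(16.0)
except RuntimeError as e:
    print(e)  # Backend 'surely_not_installed_module' is unavailable
```

Creating `broken` succeeds; the error only surfaces when the backend is actually used, with the original import error attached as its cause.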


  • Fix exception that is raised when einx is run with torch<2.
  • Add workaround for torch._dynamo.exc.InternalTorchDynamoError that occurs with torch.compile on torch==2.4.



  • Add partial support for tinygrad.
    • Supported:
      • einx.rearrange
      • einx.{elementwise|add|multiply|where|...}
      • einx.{reduce|sum|mean|...}
      • einx.{vmap_with_axis|flip|softmax|...}
    • Not supported:
      • einx.vmap (no vmap in tinygrad)
      • einx.{index|get_at|set_at|...} (due to relying on einx.vmap)


  • Use tf.gather_nd instead of x[y] to implement einx.get_at for TensorFlow.
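In tf.gather_nd, the last axis of the index tensor holds one coordinate into the leading axes of the input, which is roughly the lookup that a pattern like "[h w], p [2] -> p" describes. The pure-Python sketch below reproduces those semantics on nested lists (batch_dims=0 only); it is an illustration, not TensorFlow's implementation.

```python
from typing import Any, Sequence


def gather_nd(params: Any, indices: Sequence) -> Any:
    """Pure-Python sketch of tf.gather_nd semantics (batch_dims=0) on nested lists.

    The innermost list of `indices` is one coordinate; it indexes the leading
    len(coordinate) axes of `params`.
    """
    # A flat list of ints is a single coordinate; anything else is a batch of them.
    if all(isinstance(i, int) for i in indices):
        value = params
        for i in indices:
            value = value[i]
        return value
    return [gather_nd(params, sub) for sub in indices]


image = [[10, 11, 12],
         [20, 21, 22]]             # shape (h=2, w=3)
points = [[0, 2], [1, 0], [1, 1]]  # shape (p=3, 2): (row, column) pairs
print(gather_nd(image, points))    # [12, 20, 21]
```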


  • Allow empty tuples and lists as constraints for ellipsis parameters.
  • Fix shorthand notation in




  • Fix bug when calling einx from multiple threads. (Unit tests are now also run in a multi-threaded context.)
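A common way to make a cache of traced functions safe to use from several threads is to guard lookup and insertion with a single lock. The sketch below uses hypothetical `get_traced` and `_trace` helpers keyed by the einx expression string; it is not einx's actual caching code.

```python
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict

_cache: Dict[str, Callable] = {}
_lock = threading.Lock()
trace_count = 0


def _trace(expr: str) -> Callable:
    """Hypothetical stand-in for tracing an einx expression into a function."""
    global trace_count
    trace_count += 1
    return lambda x: f"{expr}({x})"


def get_traced(expr: str) -> Callable:
    # Hold the lock across lookup and insertion so two threads never trace
    # the same expression concurrently or observe a half-built entry.
    with _lock:
        fn = _cache.get(expr)
        if fn is None:
            fn = _trace(expr)
            _cache[expr] = fn
        return fn


with ThreadPoolExecutor(max_workers=8) as pool:
    results = list(pool.map(lambda i: get_traced("b c -> c b")(i), range(100)))

print(trace_count)  # 1
```

Holding one global lock while tracing serializes all tracing; a per-key lock would allow different expressions to be traced in parallel at the cost of more bookkeeping.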



  • Remove einx dependency in compiled code: The code for a traced function now directly imports and uses the namespace of the backend (e.g. import torch). For example:
    >>> print("b q (h c), b k (h c) -> b q k h", x, y, h=16, graph=True))
    import torch
    def op0(i0, i1):
        x0 = torch.reshape(i0, (16, 768, 16, 64))
        x1 = torch.reshape(i1, (16, 768, 16, 64))
        x2 = torch.einsum("abcd,aecd->abec", x0, x1)
        return x2
    In most cases, compiled functions now contain no reference to other einx code.
  • Improve handling of Python scalars (see #7): einx now converts int, float and bool to tensor objects (e.g. via torch.asarray) only if the called backend function does not support Python scalars. Previously, all inputs were converted to tensor objects. When using PyTorch, the device argument is used to place the constructed tensor on the correct device.
    For example, torch.add supports Python scalars:
    >>> print(einx.add("a,", x, 1, graph=True))
    import torch
    def op0(i0, i1):
        x0 = torch.add(i0, i1)
        return x0
    while torch.maximum does not:
    >>> print(einx.maximum("a,", x, 1, graph=True))
    import torch
    def op0(i0, i1):
        x0 = torch.asarray(i1, device=i0.device)
        x1 = torch.maximum(i0, x0)
        return x1
  • Also run unit tests for PyTorch and JAX on the GPU, if one is available.
  • Run unit tests also with jax.jit and torch.compile.
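The compiled functions shown above are ordinary Python source that is generated and executed at runtime, so any module they need must be imported by the generated code itself. A minimal sketch of that mechanism, with the stdlib `math` module standing in for a tensor backend such as torch; `compile_graph` is a hypothetical helper, not part of einx's API.

```python
import textwrap


def compile_graph(source: str, name: str = "op0"):
    """Execute generated source in a fresh namespace and return the named function."""
    namespace: dict = {}
    exec(compile(source, "<generated>", "exec"), namespace)
    return namespace[name]


# The generated code imports the backend namespace directly (here `math`
# stands in for torch) and references no other helper code.
source = textwrap.dedent("""
    import math
    def op0(i0, i1):
        x0 = math.hypot(i0, i1)
        return x0
""")

op0 = compile_graph(source)
print(source)
print(op0(3.0, 4.0))  # 5.0
```

Because the function's globals contain only what its own source imported, the printed code can be copied and run without the code generator.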




  • Add partial support for Apple's mlx.
    • Supported:
      • einx.rearrange
      • einx.{elementwise|add|multiply|where|...}
      • einx.{reduce|sum|mean|...}
      • einx.{vmap_with_axis|flip|softmax|...}
    • Not supported yet:
      • (mx.einsum is not implemented yet)
      • einx.vmap (mx.vmap does not fully support all primitives yet)
      • einx.{index|get_at|set_at|...} (due to relying on einx.vmap)
  • Add partial support for dask.array.
    • Supported:
      • einx.rearrange
      • einx.{elementwise|add|multiply|where|...}
      • einx.{reduce|sum|mean|...}
      • einx.{vmap_with_axis|flip|softmax|...}
    • Not supported:
      • einx.vmap (vmap not implemented in dask)
      • einx.{index|get_at|set_at|...} (due to relying on einx.vmap)
  • Add environment variable EINX_WARN_ON_RETRACE to warn when excessive retracing is detected.


  • Allow -> and , to be composed with other operators. This deprecates the existing [|] notation, which should be replaced by the composable ->; the old notation is still supported for backwards compatibility. For example:
    • "b [c1->c2]", ...) expands to "b [c1] -> b [c2]", ...)
    • einx.get_at("b p [i,->]", ...) expands to einx.get_at("b p [i], b p -> b p", ...)
  • Allow einx.{set_at|add_at|...} to be called with zero-sized updates or coordinates (in which case the input tensor is returned as-is).
  • Remove which was not used anywhere but in the unit tests.
  • Improve error reporting:
    • Drop internal stack frames when raising exceptions.
    • Better error when passing invalid shape constraints to einx functions.
  • Reduce overhead of einx when using the PyTorch backend.
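The expansion rule in the composable -> examples above can be illustrated with a toy rewriter that handles a single bracket containing one ->. `expand_arrow` is a hypothetical helper for illustration only; einx's actual parser supports nesting and composition far beyond this.

```python
import re


def expand_arrow(pattern: str) -> str:
    """Toy expansion of one bracket with a composed '->', e.g. 'b [c1->c2]'.

    Handles a single [...] group only; other patterns are returned unchanged.
    """
    m = re.fullmatch(r"(.*?)\[([^\]]*)\](.*)", pattern)
    if m is None or "->" not in m.group(2):
        return pattern
    prefix, inner, suffix = m.groups()
    left, right = inner.split("->")

    def render(part: str) -> str:
        # Put one operand back into the surrounding pattern; an empty operand
        # drops the brackets entirely.
        part = part.strip()
        text = f"{prefix}[{part}]{suffix}" if part else prefix + suffix
        return " ".join(text.split())  # normalize whitespace

    inputs = ", ".join(render(p) for p in left.split(","))
    outputs = ", ".join(render(p) for p in right.split(","))
    return f"{inputs} -> {outputs}"


print(expand_arrow("b [c1->c2]"))  # b [c1] -> b [c2]
print(expand_arrow("b p [i,->]"))  # b p [i], b p -> b p
```

Each comma-separated operand inside the bracket becomes one full input or output expression, and an empty operand (as in `[i,->]`) yields the pattern without a bracketed axis.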


  • Fix compatibility of einx.nn.torch.Norm with PyTorch 2.2.
  • Fix parameters in einn.param being ignored.
  • Fix bug when using concatenations in einx.rearrange. See: #6
  • Fix broadcasting new axes in einx.vmap_with_axis.
  • Disable torch.compile during graph construction using torch.compiler.disable.



  • Add option to install einx via pip install einx[torch] or pip install einx[keras] to enforce version requirements on PyTorch or Keras.


  • Fail gracefully and report an error when run with an incompatible version of PyTorch or Keras.


  • Fix compatibility with 2.0 <= PyTorch < 2.1.



  • Add type annotations to public API.
  • Allow passing multiple coordinate tensors in einx.{get_at|set_at|...}.
  • Allow implicit output shape in einx.{set_at|add_at|...}.
  • Allow passing backend with string argument to einx.nn.norm.
  • Make backends accessible as einx.backend.{NAME} once they are loaded.


  • Refactor tracing:

    • Trace vmapped functions (previously kept a pointer to an untraced function).
    • Add shape assertion when calling unsafe functions.
    • Add comments for better inspection.
    • Remove pass_backend argument from einx.vmap.
    • Cache different functions for different backends.
    • Don't call backend.to_tensor if input already has correct type.

    For example, tracing einx.get_at now gives the following jit-compiled code:

    >>> print(einx.get_at("b [h w] c, b p [2] -> b p c", x, y, graph=True))
    # backend: einx.backend.numpy
    def op1(i0, i1):
        x1 = i1[:, 0]
        x2 = i1[:, 1]
        x0 = backend.get_at(i0, (x1, x2))
        return (x0,)
    def op0(i0, i1, op1=op1):
        op2 = backend.vmap(op1, in_axes=(0, 0), out_axes=(0,))
        op3 = backend.vmap(op2, in_axes=(3, None), out_axes=(2,))
        x0 = op3(i0, i1)
        return x0[0]
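The traced code above relies on backend.vmap with in_axes and out_axes in the style of jax.vmap. A pure-Python sketch of the in_axes semantics on nested lists, supporting only axis 0 or None and stacking results along a new first axis (out_axes=0); this `vmap` is hypothetical and much simpler than any real backend's.

```python
from typing import Any, Callable, Optional, Sequence


def vmap(fn: Callable, in_axes: Sequence[Optional[int]]) -> Callable:
    """Toy vmap over nested lists, supporting only in_axes entries of 0 or None.

    Arguments with in_axes 0 are sliced along their first axis; arguments with
    None are passed unchanged to every call. Results are stacked along a new
    first axis, which corresponds to out_axes=0.
    """
    if any(ax not in (0, None) for ax in in_axes):
        raise NotImplementedError("this sketch only supports in_axes of 0 or None")

    def mapped(*args: Any) -> list:
        sizes = {len(a) for a, ax in zip(args, in_axes) if ax == 0}
        if len(sizes) != 1:
            raise ValueError(f"mapped arguments need one common size, got {sizes}")
        (n,) = sizes
        return [
            fn(*(a[i] if ax == 0 else a for a, ax in zip(args, in_axes)))
            for i in range(n)
        ]

    return mapped


def lookup(image, point):
    # Unbatched op: image has shape (h, w), point has shape (2,).
    return image[point[0]][point[1]]


images = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]  # shape (b=2, h=2, w=2)
points = [[0, 1], [1, 0]]                      # shape (b=2, 2)
batched = vmap(lookup, in_axes=(0, 0))
print(batched(images, points))                              # [2, 7]
print(vmap(lookup, in_axes=(None, 0))(images[0], points))   # [2, 3]
```

Nesting such calls, as op0 does with op2 and op3, maps the unbatched lookup over one more axis per level.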


  • Fix bug when using "1" as coordinate axis in einx.index.
  • Add workaround for scalar indexing operations with torch.vmap (see
  • Fix support for list/tuple arguments as tensors with non-trivial shape.
  • Change einx.reduce to accept only a single tensor as argument (the API previously allowed multiple arguments, but this was never implemented).
  • Don't trace and jit functions if EINX_CACHE_SIZE=0.
  • Fix bug where some static code analysis tools fail to recognize function specializations.
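With a cache size of zero, a traced and compiled function would be used once and then discarded, so tracing only adds overhead. The sketch below shows one way to pick between a cached traced path and direct execution. `trace` and `run_eager` are hypothetical hooks, and reading EINX_CACHE_SIZE as an integer with a fallback of 256 is an assumption made for this sketch, not einx's documented default.

```python
import functools
import os
from typing import Any, Callable


def make_dispatcher(trace: Callable[[str], Callable], run_eager: Callable[..., Any]) -> Callable[..., Any]:
    """Choose between a cached traced path and direct eager execution.

    `trace(expr)` returns a compiled function; `run_eager(expr, *args)` runs the op
    directly. Both are hypothetical hooks. Assumption: EINX_CACHE_SIZE is read as an
    integer; the fallback of 256 is an arbitrary value for this sketch.
    """
    size = int(os.environ.get("EINX_CACHE_SIZE", "256"))
    if size == 0:
        # Without a cache, every traced function would be used once and discarded,
        # so tracing and compiling is pure overhead: run the operation directly.
        return run_eager
    cached_trace = functools.lru_cache(maxsize=size)(trace)

    def dispatch(expr: str, *args: Any) -> Any:
        return cached_trace(expr)(*args)

    return dispatch


traced = []


def trace(expr):
    traced.append(expr)
    return lambda a, b: a + b  # stand-in for a compiled function


def run_eager(expr, a, b):
    return a + b


os.environ["EINX_CACHE_SIZE"] = "0"
eager = make_dispatcher(trace, run_eager)
print(eager("a, a -> a", 1, 2), traced)  # 3 []

os.environ["EINX_CACHE_SIZE"] = "16"
jitted = make_dispatcher(trace, run_eager)
print(jitted("a, a -> a", 1, 2), jitted("a, a -> a", 3, 4), traced)  # 3 7 ['a, a -> a']
```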