Skip to content

Commit

Permalink
[math] Update einsum (#20)
Browse files Browse the repository at this point in the history
* Update _compat_numpy_misc.py

* Update _compat_numpy_misc.py
  • Loading branch information
Routhleck authored Jun 20, 2024
1 parent 1942e0f commit df0cc35
Showing 1 changed file with 25 additions and 8 deletions.
33 changes: 25 additions & 8 deletions brainunit/math/_compat_numpy_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from __future__ import annotations

import collections
from collections.abc import Sequence
from typing import (Callable, Union, Tuple, Any, Optional)

Expand Down Expand Up @@ -121,6 +122,16 @@ def broadcast_shapes(*shapes):
return jnp.broadcast_shapes(*shapes)


def _default_poly_einsum_handler(*operands, **kwargs):
dummy = collections.namedtuple('dummy', ['shape', 'dtype'])
dummies = [dummy(tuple(d if type(d) is int else 8 for d in x.shape), x.dtype)
if hasattr(x, 'dtype') else x for x in operands]
mapping = {id(d): i for i, d in enumerate(dummies)}
out_dummies, contractions = opt_einsum.contract_path(*dummies, **kwargs)
contract_operands = [operands[mapping[id(d)]] for d in out_dummies]
return contract_operands, contractions


def einsum(
subscripts: str,
/,
Expand Down Expand Up @@ -176,7 +187,6 @@ def einsum(
if not non_constant_dim_types:
contract_path = opt_einsum.contract_path
else:
from jax._src.numpy.lax_numpy import _default_poly_einsum_handler
contract_path = _default_poly_einsum_handler

operands, contractions = contract_path(*operands, einsum_call=True, use_blas=True, optimize=optimize)
Expand All @@ -201,15 +211,22 @@ def einsum(
else:
if isinstance(operands[i + 1], Quantity):
unit = unit * operands[i + 1].dim
operands = [op.value if isinstance(op, Quantity) else op for op in operands]

contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
r = jnp.einsum(subscripts,
*operands,
precision=precision,
preferred_element_type=preferred_element_type,
_dot_general=_dot_general)

einsum = jax.jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True)
if spec is not None:
einsum = jax.named_call(einsum, name=spec)
operands = [op.value if isinstance(op, Quantity) else op for op in operands]
r = einsum(operands, contractions, precision, # type: ignore[operator]
preferred_element_type, _dot_general)
# contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
#
# einsum = jax.jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True)
# if spec is not None:
# einsum = jax.named_call(einsum, name=spec)

# r = einsum(operands, contractions, precision, # type: ignore[operator]
# preferred_element_type, _dot_general)
if unit is not None:
return Quantity(r, dim=unit)
else:
Expand Down

0 comments on commit df0cc35

Please sign in to comment.