Skip to content

Commit

Permalink
Merge pull request #4 from google/make-array
Browse files Browse the repository at this point in the history
Made array conversion optional
  • Loading branch information
patrick-kidger authored Jul 25, 2022
2 parents 10d7751 + 6e8cbfd commit 5c6e80e
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 18 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@ params = jax.tree_leaves(mod) # 1.0 and 2.0 are parameters.
## Documentation

```python
sympytorch.SymbolicModule(expressions, extra_funcs=None)
sympytorch.SymbolicModule(expressions, extra_funcs=None, make_array=True)
```

Where:
- `expressions` is a PyTree of SymPy expressions.
- `extra_funcs` is an optional dictionary from SymPy functions to JAX operations, to extend the built-in translation rules.
- `make_array` is whether integers/floats/rationals should be stored as Python integers/etc., or as JAX arrays.

Instances can be called with key-value pairs of symbol-value, as in the above example.

Expand Down
2 changes: 1 addition & 1 deletion sympy2jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
from .sympy_module import SymbolicModule


__version__ = "0.0.2"
__version__ = "0.0.3"
52 changes: 36 additions & 16 deletions sympy2jax/sympy_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,19 @@ def sympy(self, memodict: _IdDict, func_lookup: dict) -> sympy.Expr:
return sympy.Symbol(self._name)


def _maybe_array(val, make_array):
if make_array:
return jnp.asarray(val)
else:
return val


class _Integer(_AbstractNode):
_value: jnp.ndarray

def __init__(self, expr: sympy.Expr):
def __init__(self, expr: sympy.Expr, make_array: bool):
assert isinstance(expr, sympy.Integer)
self._value = jnp.asarray(int(expr))
self._value = _maybe_array(int(expr), make_array)

def __call__(self, memodict: _IdDict):
return self._value
Expand All @@ -140,9 +147,9 @@ def sympy(self, memodict: _IdDict, func_lookup: dict) -> sympy.Expr:
class _Float(_AbstractNode):
_value: jnp.ndarray

def __init__(self, expr: sympy.Expr):
def __init__(self, expr: sympy.Expr, make_array: bool):
assert isinstance(expr, sympy.Float)
self._value = jnp.asarray(float(expr))
self._value = _maybe_array(float(expr), make_array)

def __call__(self, memodict: _IdDict):
return self._value
Expand All @@ -156,10 +163,10 @@ class _Rational(_AbstractNode):
_numerator: jnp.ndarray
_denominator: jnp.ndarray

def __init__(self, expr: sympy.Expr):
def __init__(self, expr: sympy.Expr, make_array: bool):
assert isinstance(expr, sympy.Rational)
self._numerator = jnp.asarray(int(expr.numerator))
self._denominator = jnp.asarray(int(expr.denominator))
self._numerator = _maybe_array(int(expr.numerator), make_array)
self._denominator = _maybe_array(int(expr.denominator), make_array)

def __call__(self, memodict: _IdDict):
return self._numerator / self._denominator
Expand All @@ -173,12 +180,16 @@ class _Func(_AbstractNode):
_func: Callable
_args: list

def __init__(self, expr: sympy.Expr, memodict: _IdDict, func_lookup: dict):
def __init__(
self, expr: sympy.Expr, memodict: _IdDict, func_lookup: dict, make_array: bool
):
try:
self._func = func_lookup[expr.func]
except KeyError as e:
raise KeyError(f"Unsupported Sympy type {type(expr)}") from e
self._args = [_sympy_to_node(arg, memodict, func_lookup) for arg in expr.args]
self._args = [
_sympy_to_node(arg, memodict, func_lookup, make_array) for arg in expr.args
]

def __call__(self, memodict: _IdDict):
args = []
Expand All @@ -203,21 +214,21 @@ def sympy(self, memodict: _IdDict, func_lookup: dict) -> sympy.Expr:


def _sympy_to_node(
expr: sympy.Expr, memodict: _IdDict, func_lookup: dict
expr: sympy.Expr, memodict: _IdDict, func_lookup: dict, make_array: bool
) -> _AbstractNode:
try:
return memodict[expr]
except KeyError:
if isinstance(expr, sympy.Symbol):
out = _Symbol(expr)
elif isinstance(expr, sympy.Integer):
out = _Integer(expr)
out = _Integer(expr, make_array)
elif isinstance(expr, sympy.Float):
out = _Float(expr)
out = _Float(expr, make_array)
elif isinstance(expr, sympy.Rational):
out = _Rational(expr)
out = _Rational(expr, make_array)
else:
out = _Func(expr, memodict, func_lookup)
out = _Func(expr, memodict, func_lookup, make_array)
memodict[expr] = out
return out

Expand All @@ -231,7 +242,11 @@ class SymbolicModule(eqx.Module):
has_extra_funcs: bool = eqx.static_field()

def __init__(
self, expressions: PyTree, extra_funcs: Optional[dict] = None, **kwargs
self,
expressions: PyTree,
extra_funcs: Optional[dict] = None,
make_array: bool = True,
**kwargs,
):
super().__init__(**kwargs)
if extra_funcs is None:
Expand All @@ -240,7 +255,12 @@ def __init__(
else:
lookup = co.ChainMap(extra_funcs, _lookup)
self.has_extra_funcs = True
_convert = ft.partial(_sympy_to_node, memodict=_IdDict(), func_lookup=lookup)
_convert = ft.partial(
_sympy_to_node,
memodict=_IdDict(),
func_lookup=lookup,
make_array=make_array,
)
self.nodes = jax.tree_map(_convert, expressions)

def sympy(self) -> sympy.Expr:
Expand Down

0 comments on commit 5c6e80e

Please sign in to comment.