From 8591b280a4384f19ea3787cc6c52d1de5f321224 Mon Sep 17 00:00:00 2001 From: "Nathaniel D. Hoffman" Date: Sat, 21 Oct 2023 14:24:36 -0400 Subject: [PATCH 1/3] add numerical constants pi, e, Euler's gamma, and imaginary I --- sympy2jax/sympy_module.py | 24 ++++++++++++++++++++++++ tests/test_symbolic_module.py | 8 ++++++++ 2 files changed, 32 insertions(+) diff --git a/sympy2jax/sympy_module.py b/sympy2jax/sympy_module.py index 6e4dcd5..7a446ed 100644 --- a/sympy2jax/sympy_module.py +++ b/sympy2jax/sympy_module.py @@ -187,6 +187,28 @@ def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr: # memodict not needed as sympy deduplicates internally return sympy.Integer(self._numerator) / sympy.Integer(self._denominator) +class _Constant(_AbstractNode): + _value: jnp.ndarray + _expr: sympy.Expr + + _constant_lookup = { + sympy.E: jnp.e, + sympy.pi: jnp.pi, + sympy.EulerGamma: jnp.euler_gamma, + sympy.I: 1j, + } + def __init__(self, expr: sympy.Expr, make_array: bool): + assert expr in (sympy.E, sympy.pi, sympy.EulerGamma, sympy.I) + self._value = _maybe_array(self._constant_lookup[expr], make_array) + self._expr = expr + + def __call__(self, memodict: dict): + return self._value + + def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr: + # memodict not needed as sympy deduplicates internally + return self._expr + class _Func(_AbstractNode): _func: Callable @@ -239,6 +261,8 @@ def _sympy_to_node( out = _Float(expr, make_array) elif isinstance(expr, sympy.Rational): out = _Rational(expr, make_array) + elif expr in (sympy.E, sympy.pi, sympy.EulerGamma, sympy.I): + out = _Constant(expr, make_array) else: out = _Func(expr, memodict, func_lookup, make_array) memodict[expr] = out diff --git a/tests/test_symbolic_module.py b/tests/test_symbolic_module.py index 70b3608..a4f6b58 100644 --- a/tests/test_symbolic_module.py +++ b/tests/test_symbolic_module.py @@ -123,6 +123,14 @@ def test_rational(): assert mod.sympy() == y +def test_constants(): + x = sympy.symbols("x") + y = x + sympy.pi + sympy.E + sympy.EulerGamma + sympy.I + mod = sympy2jax.SymbolicModule(y) + assert jnp.isclose(mod(x=1.0), 1 + jnp.pi + jnp.e + jnp.euler_gamma + 1j) + assert mod.sympy() == y + + def test_extra_funcs(): class _MLP(eqx.Module): mlp: eqx.nn.MLP From 8e05d49a6f26423ca04eb847600ca054dd901da1 Mon Sep 17 00:00:00 2001 From: "Nathaniel D. Hoffman" Date: Sat, 21 Oct 2023 16:51:38 -0400 Subject: [PATCH 2/3] move _constant_lookup to global scope, update assertion --- sympy2jax/sympy_module.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sympy2jax/sympy_module.py b/sympy2jax/sympy_module.py index 2131c2d..50bdc05 100644 --- a/sympy2jax/sympy_module.py +++ b/sympy2jax/sympy_module.py @@ -93,6 +93,13 @@ def fn_(*args): sympy.Determinant: jnp.linalg.det, } +_constant_lookup = { + sympy.E: jnp.e, + sympy.pi: jnp.pi, + sympy.EulerGamma: jnp.euler_gamma, + sympy.I: 1j, +} + _reverse_lookup = {v: k for k, v in _lookup.items()} assert len(_reverse_lookup) == len(_lookup) @@ -201,22 +208,15 @@ class _Constant(_AbstractNode): _value: jnp.ndarray _expr: sympy.Expr - _constant_lookup = { - sympy.E: jnp.e, - sympy.pi: jnp.pi, - sympy.EulerGamma: jnp.euler_gamma, - sympy.I: 1j, - } def __init__(self, expr: sympy.Expr, make_array: bool): - assert expr in (sympy.E, sympy.pi, sympy.EulerGamma, sympy.I) - self._value = _maybe_array(self._constant_lookup[expr], make_array) + assert expr in _constant_lookup + self._value = _maybe_array(_constant_lookup[expr], make_array) self._expr = expr def __call__(self, memodict: dict): return self._value def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr: - # memodict not needed as sympy deduplicates internally return self._expr From 53528bb15fb9024989c2a31f737987381d9fe0cf Mon Sep 17 00:00:00 2001 From: "Nathaniel D. Hoffman" Date: Sat, 21 Oct 2023 17:08:08 -0400 Subject: [PATCH 3/3] formatter complained, added a line --- sympy2jax/sympy_module.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sympy2jax/sympy_module.py b/sympy2jax/sympy_module.py index 50bdc05..7d5d34e 100644 --- a/sympy2jax/sympy_module.py +++ b/sympy2jax/sympy_module.py @@ -204,6 +204,7 @@ def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr: _item(self._denominator) ) + class _Constant(_AbstractNode): _value: jnp.ndarray _expr: sympy.Expr