diff --git a/sympy2jax/sympy_module.py b/sympy2jax/sympy_module.py index 0c6232e..7d5d34e 100644 --- a/sympy2jax/sympy_module.py +++ b/sympy2jax/sympy_module.py @@ -93,6 +93,13 @@ def fn_(*args): sympy.Determinant: jnp.linalg.det, } +_constant_lookup = { + sympy.E: jnp.e, + sympy.pi: jnp.pi, + sympy.EulerGamma: jnp.euler_gamma, + sympy.I: 1j, +} + _reverse_lookup = {v: k for k, v in _lookup.items()} assert len(_reverse_lookup) == len(_lookup) @@ -198,6 +205,22 @@ def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr: ) +class _Constant(_AbstractNode): + _value: jnp.ndarray + _expr: sympy.Expr + + def __init__(self, expr: sympy.Expr, make_array: bool): + assert expr in _constant_lookup + self._value = _maybe_array(_constant_lookup[expr], make_array) + self._expr = expr + + def __call__(self, memodict: dict): + return self._value + + def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr: + return self._expr + + class _Func(_AbstractNode): _func: Callable _args: list @@ -250,6 +273,8 @@ def _sympy_to_node( out = _Float(expr, make_array) elif isinstance(expr, sympy.Rational): out = _Rational(expr, make_array) + elif expr in (sympy.E, sympy.pi, sympy.EulerGamma, sympy.I): + out = _Constant(expr, make_array) else: out = _Func(expr, memodict, func_lookup, make_array) memodict[expr] = out diff --git a/tests/test_symbolic_module.py b/tests/test_symbolic_module.py index e49bbc0..6aff574 100644 --- a/tests/test_symbolic_module.py +++ b/tests/test_symbolic_module.py @@ -124,6 +124,14 @@ def test_rational(): assert mod.sympy() == y +def test_constants(): + x = sympy.symbols("x") + y = x + sympy.pi + sympy.E + sympy.EulerGamma + sympy.I + mod = sympy2jax.SymbolicModule(y) + assert jnp.isclose(mod(x=1.0), 1 + jnp.pi + jnp.e + jnp.euler_gamma + 1j) + assert mod.sympy() == y + + def test_extra_funcs(): class _MLP(eqx.Module): mlp: eqx.nn.MLP