diff --git a/sympy2jax/__init__.py b/sympy2jax/__init__.py index 2422095..f271964 100644 --- a/sympy2jax/__init__.py +++ b/sympy2jax/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .sympy_module import SymbolicModule +from .sympy_module import concatenate, stack, SymbolicModule __version__ = "0.0.4" diff --git a/sympy2jax/sympy_module.py b/sympy2jax/sympy_module.py index 7ab66fd..6e4dcd5 100644 --- a/sympy2jax/sympy_module.py +++ b/sympy2jax/sympy_module.py @@ -26,6 +26,9 @@ PyTree = Any +concatenate = sympy.Function("concatenate") +stack = sympy.Function("stack") + def _reduce(fn): def fn_(*args): @@ -34,7 +37,16 @@ def fn_(*args): return fn_ +def _single_args(fn): + def fn_(*args): + return fn(args) + + return fn_ + + _lookup = { + concatenate: _single_args(jnp.concatenate), + stack: _single_args(jnp.stack), sympy.Mul: _reduce(jnp.multiply), sympy.Add: _reduce(jnp.add), sympy.div: jnp.divide, diff --git a/tests/test_symbolic_module.py b/tests/test_symbolic_module.py index 7c823ef..70b3608 100644 --- a/tests/test_symbolic_module.py +++ b/tests/test_symbolic_module.py @@ -143,3 +143,23 @@ def _get_params(module): return {id(x) for x in jax.tree_leaves(module) if eqx.is_array(x)} assert _get_params(mod).issuperset(_get_params(mlp)) + + +def test_concatenate(): + x, y, z = sympy.symbols("x y z") + cat = sympy2jax.concatenate(x, y, z) + mod = sympy2jax.SymbolicModule(expressions=cat) + assert_equal( + mod(x=jnp.array([0.4, 0.5]), y=jnp.array([0.6, 0.7]), z=jnp.array([0.8, 0.9])), + jnp.array([0.4, 0.5, 0.6, 0.7, 0.8, 0.9]), + ) + + +def test_stack(): + x, y, z = sympy.symbols("x y z") + stack = sympy2jax.stack(x, y, z) + mod = sympy2jax.SymbolicModule(expressions=stack) + assert_equal( + mod(x=jnp.array(0.4), y=jnp.array(0.5), z=jnp.array(0.6)), + jnp.array([0.4, 0.5, 0.6]), + )