diff --git a/.flake8 b/.flake8 deleted file mode 100644 index ce6fdf7..0000000 --- a/.flake8 +++ /dev/null @@ -1,4 +0,0 @@ -[flake8] -max-line-length = 120 -ignore = W291,W293,W503,W504,E123,E126,E203,E402,E701,E731 -per-file-ignores = __init__.py: F401 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c36bf6a..a4e560c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -10,9 +10,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Release - uses: patrick-kidger/action_update_python_project@v1 + uses: patrick-kidger/action_update_python_project@v2 with: - python-version: "3.8" + python-version: "3.11" test-script: | python -m pip install pytest jax jaxlib sympy equinox cp -r ${{ github.workspace }}/tests ./tests diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 656f657..59d9ebd 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -7,7 +7,7 @@ jobs: run-tests: strategy: matrix: - python-version: [ 3.7, 3.8, 3.9 ] + python-version: [ 3.9 ] os: [ ubuntu-latest ] fail-fast: false runs-on: ${{ matrix.os }} diff --git a/.isort.cfg b/.isort.cfg deleted file mode 100644 index 64737cc..0000000 --- a/.isort.cfg +++ /dev/null @@ -1,5 +0,0 @@ -[settings] -force_alphabetical_sort_within_sections=true -lines_after_imports=2 -profile=black -treat_comments_as_code=true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3af767a..b6033ae 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,33 +1,23 @@ -# Copyright 2022 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - repos: - repo: https://github.com/ambv/black rev: 22.3.0 hooks: - id: black - - repo: https://github.com/nbQA-dev/nbQA - rev: 1.2.3 + - repo: https://github.com/charliermarsh/ruff-pre-commit + rev: 'v0.0.255' hooks: - - id: nbqa-black - - id: nbqa-isort - - id: nbqa-flake8 - - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + - id: ruff + args: ["--fix"] + - repo: https://github.com/RobertCraigie/pyright-python + rev: v1.1.315 hooks: - - id: isort - - repo: https://github.com/pycqa/flake8 - rev: 4.0.1 + - id: pyright + additional_dependencies: ["equinox", "jax", "sympy"] + - repo: https://github.com/nbQA-dev/nbQA + rev: 1.6.3 hooks: - - id: flake8 + - id: nbqa-black + additional_dependencies: [ipython==8.12, black] + - id: nbqa-ruff + args: ["--ignore=I001"] + additional_dependencies: [ipython==8.12, ruff] diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index dd2678e..0000000 --- a/MANIFEST.in +++ /dev/null @@ -1,2 +0,0 @@ -include LICENSE -prune tests diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..076a7a1 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,48 @@ +[project] +name = "sympy2jax" +version = "0.0.4" +description = "Turn SymPy expressions into trainable JAX expressions." +readme = "README.md" +requires-python ="~=3.9" +license = {file = "LICENSE"} +authors = [ + {name = "Patrick Kidger", email = "contact@kidger.site"}, +] +keywords = ["jax", "sympy", "equinox"] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Natural Language :: English", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Mathematics", +] +urls = {repository = "https://github.com/google/sympy2jax" } +dependencies = ["equinox>=0.5.3", "jax>=0.3.4", "sympy>=1.7.1"] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = ["sympy2jax/*"] + +[tool.pytest.ini_options] +addopts = "--jaxtyping-packages=symyp2jax,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))" + +[tool.ruff] +select = ["E", "F", "I001"] +ignore = ["E402", "E721", "E731", "E741", "F722"] +ignore-init-module-imports = true +fixable = ["I001", "F401"] + +[tool.ruff.isort] +combine-as-imports = true +lines-after-imports = 2 +extra-standard-library = ["typing_extensions"] +order-by-type = false + +[tool.pyright] +reportIncompatibleMethodOverride = true +include = ["sympy2jax", "tests"] diff --git a/setup.py b/setup.py deleted file mode 100644 index 28136b0..0000000 --- a/setup.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright 2022 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pathlib -import re - -import setuptools - - -_here = pathlib.Path(__file__).resolve().parent - -name = "sympy2jax" - -# for simplicity we actually store the version in the __version__ attribute in the source -with open(_here / name / "__init__.py") as f: - meta_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M) - if meta_match: - version = meta_match.group(1) - else: - raise RuntimeError("Unable to find __version__ string.") - -author = "Patrick Kidger" - -author_email = "contact@kidger.site" - -description = "Turn SymPy expressions into trainable JAX expressions." - -with open(_here / "README.md", "r") as f: - readme = f.read() - -url = "https://github.com/google/sympy2jax" - -license = "Apache-2.0" - -classifiers = [ - "Development Status :: 3 - Alpha", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: Apache Software License", - "Natural Language :: English", - "Programming Language :: Python :: 3", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Scientific/Engineering :: Mathematics", -] - -python_requires = "~=3.7" - -install_requires = ["equinox>=0.5.3", "jax>=0.3.4", "sympy>=1.7.1"] - -setuptools.setup( - name=name, - version=version, - author=author, - author_email=author_email, - maintainer=author, - maintainer_email=author_email, - description=description, - long_description=readme, - long_description_content_type="text/markdown", - url=url, - license=license, - classifiers=classifiers, - zip_safe=False, - python_requires=python_requires, - install_requires=install_requires, - packages=[name], -) diff --git a/sympy2jax/__init__.py b/sympy2jax/__init__.py index f271964..75b499a 100644 --- a/sympy2jax/__init__.py +++ b/sympy2jax/__init__.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .sympy_module import concatenate, stack, SymbolicModule +from .sympy_module import ( + concatenate as concatenate, + stack as stack, + SymbolicModule as SymbolicModule, +) __version__ = "0.0.4" diff --git a/sympy2jax/sympy_module.py b/sympy2jax/sympy_module.py index 6e4dcd5..0c6232e 100644 --- a/sympy2jax/sympy_module.py +++ b/sympy2jax/sympy_module.py @@ -15,7 +15,8 @@ import abc import collections as co import functools as ft -from typing import Any, Callable, Optional +from collections.abc import Callable, Mapping +from typing import Any, cast, Optional import equinox as eqx import jax @@ -26,8 +27,8 @@ PyTree = Any -concatenate = sympy.Function("concatenate") -stack = sympy.Function("stack") +concatenate: Callable = sympy.Function("concatenate") # pyright: ignore +stack: Callable = sympy.Function("stack") # pyright: ignore def _reduce(fn): @@ -96,9 +97,16 @@ def fn_(*args): assert len(_reverse_lookup) == len(_lookup) +def _item(x): + if eqx.is_array(x): + return x.item() + else: + return x + + class _AbstractNode(eqx.Module): @abc.abstractmethod - def __call__(self, memodict: dict): + def __call__(self, memodict: dict) -> jax.typing.ArrayLike: ... @abc.abstractmethod @@ -107,14 +115,14 @@ def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr: # Comparisons based on identity __hash__ = object.__hash__ - __eq__ = object.__eq__ + __eq__ = object.__eq__ # pyright: ignore class _Symbol(_AbstractNode): _name: str def __init__(self, expr: sympy.Expr): - self._name = expr.name + self._name = expr.name # pyright: ignore def __call__(self, memodict: dict): try: @@ -135,7 +143,7 @@ def _maybe_array(val, make_array): class _Integer(_AbstractNode): - _value: jnp.ndarray + _value: jax.typing.ArrayLike def __init__(self, expr: sympy.Expr, make_array: bool): assert isinstance(expr, sympy.Integer) @@ -146,11 +154,11 @@ def __call__(self, memodict: dict): def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr: # memodict not needed as sympy deduplicates internally - return sympy.Integer(self._value.item()) + return sympy.Integer(_item(self._value)) class _Float(_AbstractNode): - _value: jnp.ndarray + _value: jax.typing.ArrayLike def __init__(self, expr: sympy.Expr, make_array: bool): assert isinstance(expr, sympy.Float) @@ -161,12 +169,12 @@ def __call__(self, memodict: dict): def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr: # memodict not needed as sympy deduplicates internally - return sympy.Float(self._value.item()) + return sympy.Float(_item(self._value)) class _Rational(_AbstractNode): - _numerator: jnp.ndarray - _denominator: jnp.ndarray + _numerator: jax.typing.ArrayLike + _denominator: jax.typing.ArrayLike def __init__(self, expr: sympy.Expr, make_array: bool): assert isinstance(expr, sympy.Rational) @@ -185,7 +193,9 @@ def __call__(self, memodict: dict): def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr: # memodict not needed as sympy deduplicates internally - return sympy.Integer(self._numerator) / sympy.Integer(self._denominator) + return sympy.Integer(_item(self._numerator)) / sympy.Integer( + _item(self._denominator) + ) class _Func(_AbstractNode): @@ -193,14 +203,15 @@ class _Func(_AbstractNode): _args: list def __init__( - self, expr: sympy.Expr, memodict: dict, func_lookup: dict, make_array: bool + self, expr: sympy.Expr, memodict: dict, func_lookup: Mapping, make_array: bool ): try: self._func = func_lookup[expr.func] except KeyError as e: raise KeyError(f"Unsupported Sympy type {type(expr)}") from e self._args = [ - _sympy_to_node(arg, memodict, func_lookup, make_array) for arg in expr.args + _sympy_to_node(cast(sympy.Expr, arg), memodict, func_lookup, make_array) + for arg in expr.args ] def __call__(self, memodict: dict): @@ -226,7 +237,7 @@ def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr: def _sympy_to_node( - expr: sympy.Expr, memodict: dict, func_lookup: dict, make_array: bool + expr: sympy.Expr, memodict: dict, func_lookup: Mapping, make_array: bool ) -> _AbstractNode: try: return memodict[expr] @@ -258,9 +269,7 @@ def __init__( expressions: PyTree, extra_funcs: Optional[dict] = None, make_array: bool = True, - **kwargs, ): - super().__init__(**kwargs) if extra_funcs is None: lookup = _lookup self.has_extra_funcs = False @@ -278,7 +287,8 @@ def __init__( def sympy(self) -> sympy.Expr: if self.has_extra_funcs: raise NotImplementedError( - "SymbolicModule cannot be converted back to SymPy if `extra_funcs` is passed" + "SymbolicModule cannot be converted back to SymPy if `extra_funcs` " + "is passed." ) memodict = dict() return jax.tree_map( diff --git a/tests/test_symbolic_module.py b/tests/test_symbolic_module.py index 70b3608..e49bbc0 100644 --- a/tests/test_symbolic_module.py +++ b/tests/test_symbolic_module.py @@ -16,14 +16,15 @@ import jax import jax.numpy as jnp import jax.random as jr +import jax.tree_util as jtu import sympy import sympy2jax def assert_equal(x, y): - x_leaves, x_tree = jax.tree_flatten(x) - y_leaves, y_tree = jax.tree_flatten(y) + x_leaves, x_tree = jtu.tree_flatten(x) + y_leaves, y_tree = jtu.tree_flatten(y) assert x_tree == y_tree for xi, yi in zip(x_leaves, y_leaves): assert type(xi) is type(yi) @@ -44,10 +45,10 @@ def assert_sympy_allclose(x, y): elif isinstance(x, sympy.Integer): assert x == y elif isinstance(x, sympy.Rational): - assert x.numerator == y.numerator - assert x.denominator == y.denominator + assert x.numerator == y.numerator # pyright: ignore + assert x.denominator == y.denominator # pyright: ignore elif isinstance(x, sympy.Symbol): - assert x.name == y.name + assert x.name == y.name # pyright: ignore else: assert len(x.args) == len(y.args) for xarg, yarg in zip(x.args, y.args): @@ -62,7 +63,7 @@ def test_example(): x = jax.numpy.zeros(3) out = mod(x_sym=x) - params = jax.tree_leaves(mod) + params = jtu.tree_leaves(mod) assert_equal(out, [jnp.cos(x), 2 * jnp.sin(x)]) assert_equal( @@ -140,7 +141,7 @@ def __call__(self, x): mod(x=1.0, y=2.0) def _get_params(module): - return {id(x) for x in jax.tree_leaves(module) if eqx.is_array(x)} + return {id(x) for x in jtu.tree_leaves(module) if eqx.is_array(x)} assert _get_params(mod).issuperset(_get_params(mlp)) @@ -163,3 +164,8 @@ def test_stack(): mod(x=jnp.array(0.4), y=jnp.array(0.5), z=jnp.array(0.6)), jnp.array([0.4, 0.5, 0.6]), ) + + +def test_non_array_to_sympy(): + mod = sympy2jax.SymbolicModule(expressions=[sympy.Integer(1)], make_array=False) + assert mod.sympy() == [sympy.Integer(1)]