Skip to content

Commit

Permalink
Merge pull request #670 from MilesCranmer/issue666
Browse files Browse the repository at this point in the history
fix: `from pysr import *`
  • Loading branch information
MilesCranmer authored Jul 15, 2024
2 parents db44938 + e84bed4 commit b658d24
Show file tree
Hide file tree
Showing 9 changed files with 21 additions and 13 deletions.
1 change: 0 additions & 1 deletion pysr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"sklearn_monkeypatch",
"sympy2jax",
"sympy2torch",
"Problem",
"install",
"PySRRegressor",
"best",
Expand Down
2 changes: 1 addition & 1 deletion pysr/export_jax.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np # noqa: F401
import sympy
import sympy # type: ignore

# Special since need to reduce arguments.
MUL = 0
Expand Down
4 changes: 2 additions & 2 deletions pysr/export_latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import List, Optional, Tuple

import pandas as pd
import sympy
from sympy.printing.latex import LatexPrinter
import sympy # type: ignore
from sympy.printing.latex import LatexPrinter # type: ignore


class PreciseLatexPrinter(LatexPrinter):
Expand Down
2 changes: 1 addition & 1 deletion pysr/export_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import pandas as pd
from numpy.typing import NDArray
from sympy import Expr, Symbol, lambdify
from sympy import Expr, Symbol, lambdify # type: ignore


def sympy2numpy(eqn, sympy_symbols, *, selection=None):
Expand Down
2 changes: 1 addition & 1 deletion pysr/export_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Callable, Dict, List, Optional

import sympy
import sympy # type: ignore
from sympy import sympify

from .utils import ArrayLike
Expand Down
2 changes: 1 addition & 1 deletion pysr/export_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import functools as ft

import numpy as np # noqa: F401
import sympy
import sympy # type: ignore


def _reduce(fn):
Expand Down
15 changes: 12 additions & 3 deletions pysr/test/test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import os
import pickle as pkl
import tempfile
Expand All @@ -8,7 +9,7 @@

import numpy as np
import pandas as pd
import sympy
import sympy # type: ignore
from sklearn.utils.estimator_checks import check_estimator

from pysr import PySRRegressor, install, jl
Expand Down Expand Up @@ -892,7 +893,7 @@ def test_suggest_keywords(self):

# More complex, and with error
with self.assertRaises(TypeError) as cm:
model = PySRRegressor(ncyclesperiterationn=5)
PySRRegressor(ncyclesperiterationn=5)

self.assertIn(
"`ncyclesperiterationn` is not a valid keyword", str(cm.exception)
Expand All @@ -903,10 +904,18 @@ def test_suggest_keywords(self):

# Farther matches (this might need to be changed)
with self.assertRaises(TypeError) as cm:
model = PySRRegressor(operators=["+", "-"])
PySRRegressor(operators=["+", "-"])

self.assertIn("`unary_operators`, `binary_operators`", str(cm.exception))

def test_issue_666(self):
# Try the equivalent of `from pysr import *`
pysr_module = importlib.import_module("pysr")
names_to_import = pysr_module.__all__

for name in names_to_import:
getattr(pysr_module, name)


TRUE_PREAMBLE = "\n".join(
[
Expand Down
4 changes: 2 additions & 2 deletions pysr/test/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
import pandas as pd
import sympy
import sympy # type: ignore

import pysr
from pysr import PySRRegressor, sympy2jax
Expand Down Expand Up @@ -102,7 +102,7 @@ def test_avoid_simplification(self):
)

def test_issue_656(self):
import sympy
import sympy # type: ignore

E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
f, params = pysr.export_jax.sympy2jax(E_plus_x1, [sympy.symbols("x1")])
Expand Down
2 changes: 1 addition & 1 deletion pysr/test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np
import pandas as pd
import sympy
import sympy # type: ignore

import pysr
from pysr import PySRRegressor, sympy2torch
Expand Down

0 comments on commit b658d24

Please sign in to comment.