Skip to content

Commit

Permalink
Merge pull request #780 from MilesCranmer/backend-1.4.0
Browse files Browse the repository at this point in the history
feat: add differential operators and input stream specification
  • Loading branch information
MilesCranmer authored Dec 14, 2024
2 parents 8c8695b + 13f1cb0 commit efc034f
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 6 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/CI_apptainer.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ jobs:
with:
apptainer-version: 1.3.0
- name: Build apptainer
run: apptainer build --notest pysr.sif Apptainer.def
run: sudo apptainer build --notest pysr.sif Apptainer.def
- name: Test apptainer
run: |
TMPDIR=$(mktemp -d)
cp pysr.sif $TMPDIR
cd $TMPDIR
apptainer test ./pysr.sif
sudo apptainer test ./pysr.sif
3 changes: 3 additions & 0 deletions pysr/julia_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,8 @@ def _import_juliacall():
jl.seval("using SymbolicRegression")
SymbolicRegression = jl.SymbolicRegression

# Expose `D` operator:
jl.seval("using SymbolicRegression: D")

jl.seval("using Pkg: Pkg")
Pkg = jl.Pkg
2 changes: 1 addition & 1 deletion pysr/juliapkg.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"packages": {
"SymbolicRegression": {
"uuid": "8254be44-1295-4e6a-a16d-46603ac705cb",
"version": "=1.2.0"
"version": "=1.4.0"
},
"Serialization": {
"uuid": "9e88b42a-f829-5b0c-bbe9-9e923198166b",
Expand Down
1 change: 1 addition & 0 deletions pysr/param_groupings.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
- print_precision
- progress
- logger_spec
- input_stream
- Environment:
- temp_equation_file
- tempdir
Expand Down
22 changes: 19 additions & 3 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def _maybe_create_inline_operators(
binary_operators: list[str],
unary_operators: list[str],
extra_sympy_mappings: dict[str, Callable] | None,
expression_spec: AbstractExpressionSpec,
) -> tuple[list[str], list[str]]:
binary_operators = binary_operators.copy()
unary_operators = unary_operators.copy()
Expand All @@ -132,9 +133,11 @@ def _maybe_create_inline_operators(
"Only alphanumeric characters, numbers, "
"and underscores are allowed."
)
if (extra_sympy_mappings is None) or (
function_name not in extra_sympy_mappings
):
missing_sympy_mapping = (
extra_sympy_mappings is None
or function_name not in extra_sympy_mappings
)
if missing_sympy_mapping and expression_spec.supports_sympy:
raise ValueError(
f"Custom function {function_name} is not defined in `extra_sympy_mappings`. "
"You can define it with, "
Expand Down Expand Up @@ -618,6 +621,12 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
Logger specification for the Julia backend. See, for example,
`TensorBoardLoggerSpec`.
Default is `None`.
input_stream : str
The stream to read user input from. By default, this is `"stdin"`.
If you encounter issues with reading from `stdin`, like a hang,
you can simply pass `"devnull"` to this argument. You can also
reference an arbitrary Julia object in the `Main` namespace.
Default is `"stdin"`.
run_id : str
A unique identifier for the run. Will be generated using the
current date and time if not provided.
Expand Down Expand Up @@ -863,6 +872,7 @@ def __init__(
print_precision: int = 5,
progress: bool = True,
logger_spec: AbstractLoggerSpec | None = None,
input_stream: str = "stdin",
run_id: str | None = None,
output_directory: str | None = None,
temp_equation_file: bool = False,
Expand Down Expand Up @@ -969,6 +979,7 @@ def __init__(
self.print_precision = print_precision
self.progress = progress
self.logger_spec = logger_spec
self.input_stream = input_stream
# - Project management
self.run_id = run_id
self.output_directory = output_directory
Expand Down Expand Up @@ -1217,6 +1228,7 @@ def __getstate__(self) -> dict[str, Any]:
f"`{state_key}` at runtime."
)
state_keys_to_clear = state_keys_containing_lambdas
state_keys_to_clear.append("logger_")
pickled_state = {
key: (None if key in state_keys_to_clear else value)
for key, value in state.items()
Expand Down Expand Up @@ -1837,6 +1849,7 @@ def _run(
binary_operators=binary_operators,
unary_operators=unary_operators,
extra_sympy_mappings=self.extra_sympy_mappings,
expression_spec=self.expression_spec_,
)
if constraints is not None:
_constraints = _process_constraints(
Expand Down Expand Up @@ -1888,6 +1901,8 @@ def _run(
else "nothing"
)

input_stream = jl.seval(self.input_stream)

load_required_packages(
turbo=self.turbo,
bumper=self.bumper,
Expand Down Expand Up @@ -2002,6 +2017,7 @@ def _run(
crossover_probability=self.crossover_probability,
skip_mutation_failures=self.skip_mutation_failures,
max_evals=self.max_evals,
input_stream=input_stream,
early_stop_condition=early_stop_condition,
seed=seed,
deterministic=self.deterministic,
Expand Down

0 comments on commit efc034f

Please sign in to comment.