Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing PyTorch/JAX export for logical_or, logical_and, and relu #433

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
98cca40
Refactor utility functions
MilesCranmer Sep 17, 2023
4713607
Move denoising functionality to separate file
MilesCranmer Sep 17, 2023
3ae241a
Move feature selection functionality to separate file
MilesCranmer Sep 17, 2023
ff2ef42
Mypy compatibility
MilesCranmer Sep 17, 2023
135a464
Move all deprecated functions to deprecated.py
MilesCranmer Sep 17, 2023
6c92e1c
Store `sr_options_` and rename state to `sr_state_`
MilesCranmer Sep 17, 2023
ff2f93a
Add missing sympy operators for boolean logic
MilesCranmer Sep 19, 2023
d5787b2
Add missing sympy operators for relu
MilesCranmer Sep 19, 2023
47823ba
Add functionality for piecewise export to torch
MilesCranmer Sep 22, 2023
73d0f8a
Clean up error message in exports
MilesCranmer Sep 22, 2023
f92a935
Implement relu, logical_or, logical_and
MilesCranmer Sep 22, 2023
2a20447
Remove unnecessary as_bool
MilesCranmer Sep 22, 2023
22b047a
Merge tag 'v0.16.4' into sympy-or
MilesCranmer Dec 14, 2023
11dea32
Replace Heaviside with piecewise
MilesCranmer Dec 14, 2023
208307d
Merge tag 'v0.16.4' into store-options
MilesCranmer Dec 14, 2023
f21e3d6
Merge branch 'master' into store-options
MilesCranmer Dec 14, 2023
50c1407
Merge branch 'store-options' into sympy-or
MilesCranmer Dec 14, 2023
cff611a
Apply suggestions from code review
MilesCranmer Jun 3, 2024
3f1524b
Update pysr/export_torch.py
MilesCranmer Jun 3, 2024
01e1a15
Update pysr/export_torch.py
MilesCranmer Jun 3, 2024
5c0a49a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2024
0f47a59
Merge tag 'v0.18.4' into sympy-or
MilesCranmer Jun 3, 2024
c008678
Merge branch 'master' into sympy-or
MilesCranmer Jun 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pysr/export_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None):
_func = {**_jnp_func_lookup, **extra_jax_mappings}[expr.func]
except KeyError:
raise KeyError(
f"Function {expr.func} was not found in JAX function mappings."
f"Function {expr.func} was not found in JAX function mappings. "
"Please add it to extra_jax_mappings in the format, e.g., "
"{sympy.sqrt: 'jnp.sqrt'}."
)
Expand Down
80 changes: 78 additions & 2 deletions pysr/export_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,70 @@ def _initialize_torch():

torch = _torch

# Allows PyTorch to map Piecewise functions:
def expr_cond_pair(expr, cond):
if isinstance(cond, torch.Tensor) and not isinstance(expr, torch.Tensor):
expr = torch.tensor(expr, dtype=cond.dtype, device=cond.device)
elif isinstance(expr, torch.Tensor) and not isinstance(cond, torch.Tensor):
cond = torch.tensor(cond, dtype=expr.dtype, device=expr.device)
else:
return expr, cond

# First, make sure expr and cond are same size:
if expr.shape != cond.shape:
if len(expr.shape) == 0:
expr = expr.expand(cond.shape)
elif len(cond.shape) == 0:
cond = cond.expand(expr.shape)
else:
raise ValueError(
"expr and cond must have same shape, or one must be a scalar."
)
return expr, cond

MilesCranmer marked this conversation as resolved.
Show resolved Hide resolved
def if_then_else(*conds):
a, b, c = conds
return torch.where(
a, torch.where(b, True, False), torch.where(c, True, False)
)

def piecewise(*expr_conds):
output = None
already_used = None
for expr, cond in expr_conds:
if not isinstance(cond, torch.Tensor) and not isinstance(
expr, torch.Tensor
):
# When we just have scalars, have to do this a bit more complicated
# due to the fact that we need to evaluate on the correct device.
if output is None:
already_used = cond
output = expr if cond else 0.0
else:
if not isinstance(output, torch.Tensor):
output += expr if cond and not already_used else 0.0
already_used = already_used or cond
else:
expr = torch.tensor(
expr, dtype=output.dtype, device=output.device
).expand(output.shape)
output += torch.where(
cond & ~already_used, expr, torch.zeros_like(expr)
)
already_used = already_used | cond
else:
if output is None:
already_used = cond
output = torch.where(cond, expr, torch.zeros_like(expr))
else:
output += torch.where(
cond.bool() & ~already_used, expr, torch.zeros_like(expr)
)
already_used = already_used | cond.bool()
return output

# TODO: Add test that makes sure tensors are on the same device

_global_func_lookup = {
sympy.Mul: _reduce(torch.mul),
sympy.Add: _reduce(torch.add),
Expand Down Expand Up @@ -81,6 +145,12 @@ def _initialize_torch():
sympy.Heaviside: torch.heaviside,
sympy.core.numbers.Half: (lambda: 0.5),
sympy.core.numbers.One: (lambda: 1.0),
sympy.logic.boolalg.Boolean: lambda x: x,
sympy.logic.boolalg.BooleanTrue: (lambda: True),
sympy.logic.boolalg.BooleanFalse: (lambda: False),
sympy.functions.elementary.piecewise.ExprCondPair: expr_cond_pair,
sympy.Piecewise: piecewise,
MilesCranmer marked this conversation as resolved.
Show resolved Hide resolved
sympy.logic.boolalg.ITE: if_then_else,
}

class _Node(torch.nn.Module):
Expand Down Expand Up @@ -125,7 +195,7 @@ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
self._torch_func = _func_lookup[expr.func]
except KeyError:
raise KeyError(
f"Function {expr.func} was not found in Torch function mappings."
f"Function {expr.func} was not found in Torch function mappings. "
"Please add it to extra_torch_mappings in the format, e.g., "
"{sympy.sqrt: torch.sqrt}."
)
Expand Down Expand Up @@ -153,7 +223,13 @@ def forward(self, memodict):
arg_ = arg(memodict)
memodict[arg] = arg_
args.append(arg_)
return self._torch_func(*args)
try:
return self._torch_func(*args)
except Exception as err:
# Add information about the current node to the error:
raise type(err)(
f"Error occurred in node {self._sympy_func} with args {args}"
)

class _SingleSymPyModule(torch.nn.Module):
"""SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
Expand Down
Loading