Skip to content

Commit

Permalink
Merge pull request #5 from google/remove-iddict
Browse files Browse the repository at this point in the history
Removed _IdDict
  • Loading branch information
patrick-kidger authored Aug 3, 2022
2 parents 5c6e80e + 3015c18 commit 929ff75
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 29 deletions.
2 changes: 1 addition & 1 deletion sympy2jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
from .sympy_module import SymbolicModule


__version__ = "0.0.3"
__version__ = "0.0.4"
49 changes: 21 additions & 28 deletions sympy2jax/sympy_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,40 +84,33 @@ def fn_(*args):
assert len(_reverse_lookup) == len(_lookup)


class _IdDict:
def __init__(self, **values):
self._dict = {id(k): v for k, v in values.items()}

def __getitem__(self, item):
return self._dict[id(item)]

def __setitem__(self, item, value):
self._dict[id(item)] = value


class _AbstractNode(eqx.Module):
@abc.abstractmethod
def __call__(self, memodict: _IdDict):
def __call__(self, memodict: dict):
...

@abc.abstractmethod
def sympy(self, memodict: _IdDict, func_lookup: dict) -> sympy.Expr:
def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
...

# Comparisons based on identity
__hash__ = object.__hash__
__eq__ = object.__eq__


class _Symbol(_AbstractNode):
_name: str

def __init__(self, expr: sympy.Expr):
self._name = expr.name

def __call__(self, memodict: _IdDict):
def __call__(self, memodict: dict):
try:
return memodict[self._name]
except KeyError as e:
raise KeyError(f"Missing input for symbol {self._name}") from e

def sympy(self, memodict: _IdDict, func_lookup: dict) -> sympy.Expr:
def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
# memodict not needed as sympy deduplicates internally
return sympy.Symbol(self._name)

Expand All @@ -136,10 +129,10 @@ def __init__(self, expr: sympy.Expr, make_array: bool):
assert isinstance(expr, sympy.Integer)
self._value = _maybe_array(int(expr), make_array)

def __call__(self, memodict: _IdDict):
def __call__(self, memodict: dict):
return self._value

def sympy(self, memodict: _IdDict, func_lookup: dict) -> sympy.Expr:
def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
# memodict not needed as sympy deduplicates internally
return sympy.Integer(self._value.item())

Expand All @@ -151,10 +144,10 @@ def __init__(self, expr: sympy.Expr, make_array: bool):
assert isinstance(expr, sympy.Float)
self._value = _maybe_array(float(expr), make_array)

def __call__(self, memodict: _IdDict):
def __call__(self, memodict: dict):
return self._value

def sympy(self, memodict: _IdDict, func_lookup: dict) -> sympy.Expr:
def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
# memodict not needed as sympy deduplicates internally
return sympy.Float(self._value.item())

Expand All @@ -168,10 +161,10 @@ def __init__(self, expr: sympy.Expr, make_array: bool):
self._numerator = _maybe_array(int(expr.numerator), make_array)
self._denominator = _maybe_array(int(expr.denominator), make_array)

def __call__(self, memodict: _IdDict):
def __call__(self, memodict: dict):
return self._numerator / self._denominator

def sympy(self, memodict: _IdDict, func_lookup: dict) -> sympy.Expr:
def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
# memodict not needed as sympy deduplicates internally
return sympy.Integer(self._numerator) / sympy.Integer(self._denominator)

Expand All @@ -181,7 +174,7 @@ class _Func(_AbstractNode):
_args: list

def __init__(
self, expr: sympy.Expr, memodict: _IdDict, func_lookup: dict, make_array: bool
self, expr: sympy.Expr, memodict: dict, func_lookup: dict, make_array: bool
):
try:
self._func = func_lookup[expr.func]
Expand All @@ -191,7 +184,7 @@ def __init__(
_sympy_to_node(arg, memodict, func_lookup, make_array) for arg in expr.args
]

def __call__(self, memodict: _IdDict):
def __call__(self, memodict: dict):
args = []
for arg in self._args:
try:
Expand All @@ -202,7 +195,7 @@ def __call__(self, memodict: _IdDict):
args.append(arg_call)
return self._func(*args)

def sympy(self, memodict: _IdDict, func_lookup: dict) -> sympy.Expr:
def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
try:
return memodict[self]
except KeyError:
Expand All @@ -214,7 +207,7 @@ def sympy(self, memodict: _IdDict, func_lookup: dict) -> sympy.Expr:


def _sympy_to_node(
expr: sympy.Expr, memodict: _IdDict, func_lookup: dict, make_array: bool
expr: sympy.Expr, memodict: dict, func_lookup: dict, make_array: bool
) -> _AbstractNode:
try:
return memodict[expr]
Expand Down Expand Up @@ -257,7 +250,7 @@ def __init__(
self.has_extra_funcs = True
_convert = ft.partial(
_sympy_to_node,
memodict=_IdDict(),
memodict=dict(),
func_lookup=lookup,
make_array=make_array,
)
Expand All @@ -268,11 +261,11 @@ def sympy(self) -> sympy.Expr:
raise NotImplementedError(
"SymbolicModule cannot be converted back to SymPy if `extra_funcs` is passed"
)
memodict = _IdDict()
memodict = dict()
return jax.tree_map(
lambda n: n.sympy(memodict, _reverse_lookup), self.nodes, is_leaf=_is_node
)

def __call__(self, **symbols):
memodict = _IdDict(**symbols)
memodict = symbols
return jax.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)

0 comments on commit 929ff75

Please sign in to comment.