Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add assign feature #51

Merged
merged 13 commits into from
Oct 20, 2022
36 changes: 36 additions & 0 deletions src/integration_tests/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,39 @@ def inner(y):
return inner

assert get_latex(nested(3)) == r"\mathrm{inner}(y) \triangleq xy"


def test_assign_feature():
@with_latex
def f(x):
return abs(x) * math.exp(math.sqrt(x))

@with_latex
def g(x):
a = abs(x)
b = math.exp(math.sqrt(x))
return a * b

@with_latex(reduce_assignments=False)
def h(x):
a = abs(x)
b = math.exp(math.sqrt(x))
return a * b

assert str(f) == r"\mathrm{f}(x) \triangleq \left|{x}\right|\exp{\left({\sqrt{x}}\right)}"
assert str(g) == r"\mathrm{g}(x) \triangleq \left( \left|{x}\right| \right)\left( \exp{\left({\sqrt{x}}\right)} \right)"
assert str(h) == r"a \triangleq \left|{x}\right| \\ b \triangleq \exp{\left({\sqrt{x}}\right)} \\ \mathrm{h}(x) \triangleq ab"

@with_latex(reduce_assignments=True)
def f(x):
a = math.sqrt(math.exp(x))
return abs(x) * math.log10(a)

assert str(f) == r"\mathrm{f}(x) \triangleq \left|{x}\right|\log_{10}{\left({\left( \sqrt{\exp{\left({x}\right)}} \right)}\right)}"

@with_latex(reduce_assignments=False)
def f(x):
a = math.sqrt(math.exp(x))
return abs(x) * math.log10(a)

assert str(f) == r"a \triangleq \sqrt{\exp{\left({x}\right)}} \\ \mathrm{f}(x) \triangleq \left|{x}\right|\log_{10}{\left({a}\right)}"
82 changes: 65 additions & 17 deletions src/latexify/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@
class LatexifyVisitor(node_visitor_base.NodeVisitorBase):
"""Latexify AST visitor."""

def __init__(self, math_symbol=False, raw_func_name=False):
def __init__(self, math_symbol=False, raw_func_name=False,
reduce_assignments=True):
self.math_symbol = math_symbol
self.raw_func_name = (
raw_func_name # True:do not treat underline as label of subscript(#31)
)
self.reduce_assignments = reduce_assignments
self.assign_var = {}
super().__init__()

def _parse_math_symbols(self, val: str) -> str:
Expand All @@ -50,7 +53,7 @@ def generic_visit(self, node, action):
def visit_Module(self, node, action): # pylint: disable=invalid-name
del action

return self.visit(node.body[0])
return self.visit(node.body[0], 'multi_lines')

def visit_FunctionDef(self, node, action): # pylint: disable=invalid-name
del action
Expand All @@ -59,8 +62,46 @@ def visit_FunctionDef(self, node, action): # pylint: disable=invalid-name
if self.raw_func_name:
name_str = name_str.replace(r"_", r"\_") # fix #31
arg_strs = [self._parse_math_symbols(str(arg.arg)) for arg in node.args.args]
body_str = self.visit(node.body[0])
return name_str + "(" + ", ".join(arg_strs) + r") \triangleq " + body_str

body_str = ''
assign_vars = []
for el in node.body:
if isinstance(el, ast.FunctionDef):
if self.reduce_assignments:
body_str = self.visit(el, 'in_line')
self.assign_var[el.name] = rf'\left( {body_str} \right)'
else:
body_str = self.visit(el, 'multi_lines')
assign_vars.append(body_str + r' \\ ')
else:
body_str = self.visit(el)
if not self.reduce_assignments and isinstance(el, ast.Assign):
assign_vars.append(body_str)
elif isinstance(el, ast.Return):
break
if body_str == '':
raise ValueError('`return` missing')

return name_str, arg_strs, assign_vars, body_str

def visit_FunctionDef_multi_lines(self, node):
name_str, arg_strs, assign_vars, body_str = self.visit_FunctionDef(node, None)
print(name_str, arg_strs, assign_vars, body_str)
return "".join(assign_vars) + name_str + "(" + ", ".join(arg_strs) + r") \triangleq " + body_str

def visit_FunctionDef_in_line(self, node):
name_str, arg_strs, assign_vars, body_str = self.visit_FunctionDef(node, None)
return "".join(assign_vars) + body_str

def visit_Assign(self, node, action):
del action

var = self.visit(node.value)
if self.reduce_assignments:
self.assign_var[node.targets[0].id] = rf'\left( {var} \right)'
return None
else:
return rf"{node.targets[0].id} \triangleq {var} \\ "

def visit_Return(self, node, action): # pylint: disable=invalid-name
del action
Expand Down Expand Up @@ -116,19 +157,23 @@ def _decorated_lstr_and_arg(node, callee_str, lstr):
del action

callee_str = self.visit(node.func)

for prefix in constants.PREFIXES:
if callee_str.startswith(f"{prefix}."):
callee_str = callee_str[len(prefix) + 1:]
break

lstr, rstr = constants.BUILTIN_CALLEES.get(callee_str, (None, None))
if lstr is None:
lstr = r"\mathrm{" + callee_str + r"}\left("
rstr = r"\right)"

lstr, arg_str = _decorated_lstr_and_arg(node, callee_str, lstr)
return lstr + arg_str + rstr
if self.reduce_assignments \
and (getattr(node.func, 'id', None) in self.assign_var.keys()
or getattr(node.func, 'attr', None) in self.assign_var.keys()):
return callee_str
else:
for prefix in constants.PREFIXES:
if callee_str.startswith(f"{prefix}."):
callee_str = callee_str[len(prefix) + 1:]
break

lstr, rstr = constants.BUILTIN_CALLEES.get(callee_str, (None, None))
if lstr is None:
lstr = r"\mathrm{" + callee_str + r"}\left("
rstr = r"\right)"

lstr, arg_str = _decorated_lstr_and_arg(node, callee_str, lstr)
return lstr + arg_str + rstr

def visit_Attribute(self, node, action): # pylint: disable=invalid-name
del action
Expand All @@ -140,6 +185,9 @@ def visit_Attribute(self, node, action): # pylint: disable=invalid-name
def visit_Name(self, node, action): # pylint: disable=invalid-name
del action

if self.reduce_assignments and node.id in self.assign_var.keys():
return self.assign_var[node.id]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some concerns around this:

  • It may cause exponential expansion. An extreme case:

    def fn(x):
        a = x + x
        b = a + a
        c = b + b
        return c

    This may return $x+x+x+x+x+x+x+x$.

  • If the assigned string has some structure, we need to add appropriate parentheses around the string.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid exponential expansion, I suggest to add an option to activate or not the assignment.
For example:

@latexify.with_latex(assign_mode=False)
def h(x):
    a = abs(x)
    b = math.exp(math.sqrt(x))
    return a * b

will return a = \left|{x}\right| \\ b = \exp{\left({\sqrt{x}}\right)} \\ \mathrm{h}(x) \triangleq ab

image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is also a problem when assignment refers to several numbers with a product like:

@with_latex
def f(x):
    a = 1
    b = 2
    return a * b

return \mathrm{f}(x) \triangleq 12

image

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. I think reduce_assignments is good for the option name (assign_mode boolean option is somewhat confusing to users)


return self._parse_math_symbols(str(node.id))

def visit_Constant(self, node, action): # pylint: disable=invalid-name
Expand Down