Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Several refactoring and bugfix. #145

Merged
merged 10 commits into from
Dec 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 32 additions & 26 deletions src/integration_tests/function_expansion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,118 +3,124 @@
from integration_tests import utils


def test_expand_atan2_function() -> None:
def test_atan2() -> None:
def solve(x, y):
return math.atan2(y, x)

latex = r"\mathrm{solve}(x, y) = \arctan{\left({\frac{y}{x}}\right)}"
latex = (
r"\mathrm{solve}(x, y) ="
r" \arctan \mathopen{}\left( \frac{y}{x} \mathclose{}\right)"
)
utils.check_function(solve, latex, expand_functions={"atan2"})


def test_expand_atan2_nested_function() -> None:
def test_atan2_nested() -> None:
def solve(x, y):
return math.atan2(math.exp(y), math.exp(x))

latex = r"\mathrm{solve}(x, y) = \arctan{\left({\frac{e^{y}}{e^{x}}}\right)}"
latex = (
r"\mathrm{solve}(x, y) ="
r" \arctan \mathopen{}\left( \frac{e^{y}}{e^{x}} \mathclose{}\right)"
)
utils.check_function(solve, latex, expand_functions={"atan2", "exp"})


def test_expand_exp_function() -> None:
def test_exp() -> None:
def solve(x):
return math.exp(x)

latex = r"\mathrm{solve}(x) = e^{x}"
utils.check_function(solve, latex, expand_functions={"exp"})


def test_expand_exp_nested_function() -> None:
def test_exp_nested() -> None:
def solve(x):
return math.exp(math.exp(x))

latex = r"\mathrm{solve}(x) = e^{e^{x}}"
utils.check_function(solve, latex, expand_functions={"exp"})


def test_expand_exp2_function() -> None:
def test_exp2() -> None:
def solve(x):
return math.exp2(x)

latex = r"\mathrm{solve}(x) = {2}^{x}"
utils.check_function(solve, latex, expand_functions={"exp2"})


def test_expand_exp2_nested_function() -> None:
def test_exp2_nested() -> None:
def solve(x):
return math.exp2(math.exp2(x))

latex = r"\mathrm{solve}(x) = {2}^{{2}^{x}}"
utils.check_function(solve, latex, expand_functions={"exp2"})


def test_expand_expm1_function() -> None:
def test_expm1() -> None:
def solve(x):
return math.expm1(x)

latex = r"\mathrm{solve}(x) = \exp{\left({x}\right)} - {1}"
latex = r"\mathrm{solve}(x) = \exp x - {1}"
utils.check_function(solve, latex, expand_functions={"expm1"})


def test_expand_expm1_nested_function() -> None:
def test_expm1_nested() -> None:
def solve(x, y, z):
return math.expm1(math.pow(y, z))

latex = r"\mathrm{solve}(x, y, z) = e^{y^{z}} - {1}"
utils.check_function(solve, latex, expand_functions={"expm1", "exp", "pow"})


def test_expand_hypot_function_without_attribute_access() -> None:
def test_hypot_without_attribute() -> None:
from math import hypot

def solve(x, y, z):
return hypot(x, y, z)

latex = r"\mathrm{solve}(x, y, z) = \sqrt{x^{{2}} + y^{{2}} + z^{{2}}}"
latex = r"\mathrm{solve}(x, y, z) = \sqrt{ x^{{2}} + y^{{2}} + z^{{2}} }"
utils.check_function(solve, latex, expand_functions={"hypot"})


def test_expand_hypot_function() -> None:
def test_hypot() -> None:
def solve(x, y, z):
return math.hypot(x, y, z)

latex = r"\mathrm{solve}(x, y, z) = \sqrt{x^{{2}} + y^{{2}} + z^{{2}}}"
latex = r"\mathrm{solve}(x, y, z) = \sqrt{ x^{{2}} + y^{{2}} + z^{{2}} }"
utils.check_function(solve, latex, expand_functions={"hypot"})


def test_expand_hypot_nested_function() -> None:
def test_hypot_nested() -> None:
def solve(a, b, x, y):
return math.hypot(math.hypot(a, b), x, y)

latex = (
r"\mathrm{solve}(a, b, x, y) = "
r"\sqrt{"
r"\sqrt{a^{{2}} + b^{{2}}}^{{2}} + "
r"x^{{2}} + y^{{2}}}"
r"\mathrm{solve}(a, b, x, y) ="
r" \sqrt{ \sqrt{ a^{{2}} + b^{{2}} }^{{2}} + x^{{2}} + y^{{2}} }"
)
utils.check_function(solve, latex, expand_functions={"hypot"})


def test_expand_log1p_function() -> None:
def test_log1p() -> None:
def solve(x):
return math.log1p(x)

latex = r"\mathrm{solve}(x) = \log{\left({{1} + x}\right)}"
latex = r"\mathrm{solve}(x) = \log \mathopen{}\left( {1} + x \mathclose{}\right)"
utils.check_function(solve, latex, expand_functions={"log1p"})


def test_expand_log1p_nested_function() -> None:
def test_log1p_nested() -> None:
def solve(x):
return math.log1p(math.exp(x))

latex = r"\mathrm{solve}(x) = \log{\left({{1} + e^{x}}\right)}"
latex = (
r"\mathrm{solve}(x) = \log \mathopen{}\left( {1} + e^{x} \mathclose{}\right)"
)
utils.check_function(solve, latex, expand_functions={"log1p", "exp"})


def test_expand_pow_nested_function() -> None:
def test_pow_nested() -> None:
def solve(w, x, y, z):
return math.pow(math.pow(w, x), math.pow(y, z))

Expand All @@ -125,7 +131,7 @@ def solve(w, x, y, z):
utils.check_function(solve, latex, expand_functions={"pow"})


def test_expand_pow_function() -> None:
def test_pow() -> None:
def solve(x, y):
return math.pow(x, y)

Expand Down
8 changes: 4 additions & 4 deletions src/integration_tests/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_quadratic_solution() -> None:
def solve(a, b, c):
return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a)

latex = r"\mathrm{solve}(a, b, c) = \frac{-b + \sqrt{b^{{2}} - {4} a c}}{{2} a}"
latex = r"\mathrm{solve}(a, b, c) = \frac{-b + \sqrt{ b^{{2}} - {4} a c }}{{2} a}"
utils.check_function(solve, latex)


Expand All @@ -26,7 +26,7 @@ def sinc(x):
r"\mathrm{sinc}(x) = "
r"\left\{ \begin{array}{ll} "
r"{1}, & \mathrm{if} \ "
r"{x = {0}} \\ \frac{\sin{\left({x}\right)}}{x}, & \mathrm{otherwise} "
r"{x = {0}} \\ \frac{\sin x}{x}, & \mathrm{otherwise} "
r"\end{array} \right."
)
utils.check_function(sinc, latex)
Expand Down Expand Up @@ -201,9 +201,9 @@ def sigmoid(x):
sigmoid,
(
r"\mathrm{sigmoid}(x) = \left\{ \begin{array}{ll} "
r"\frac{{1}}{{1} + \exp{\left({-x}\right)}}, & "
r"\frac{{1}}{{1} + \exp \mathopen{}\left( -x \mathclose{}\right)}, & "
r"\mathrm{if} \ {x > {0}} \\ "
r"\frac{\exp{\left({x}\right)}}{\exp{\left({x}\right)} + {1}}, & "
r"\frac{\exp x}{\exp x + {1}}, & "
r"\mathrm{otherwise} "
r"\end{array} \right."
),
Expand Down
2 changes: 1 addition & 1 deletion src/latexify/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def make_name(id: str) -> ast.Name:
return ast.Name(id=id, ctx=ast.Load())


def make_attribute(value: ast.Expr, attr: str):
def make_attribute(value: ast.expr, attr: str):
"""Generates a new Attribute node.

Args:
Expand Down
103 changes: 75 additions & 28 deletions src/latexify/codegen/function_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@
ast.Or: 10,
}

# NOTE(odashi):
# Function invocation is treated as a unary operator with a higher precedence.
# This ensures that the argument with a unary operator is wrapped:
# exp(x) --> \exp x
# exp(-x) --> \exp (-x)
# -exp(x) --> - \exp x
_CALL_PRECEDENCE = _PRECEDENCES[ast.UAdd] + 1


def _get_precedence(node: ast.AST) -> int:
"""Obtains the precedence of the subtree.
Expand All @@ -63,6 +71,9 @@ def _get_precedence(node: ast.AST) -> int:
If `node` is a subtree with some operator, returns the precedence of the
operator. Otherwise, returns a number larger enough from other precedences.
"""
if isinstance(node, ast.Call):
return _CALL_PRECEDENCE

if isinstance(node, (ast.BoolOp, ast.BinOp, ast.UnaryOp)):
return _PRECEDENCES[type(node.op)]

Expand Down Expand Up @@ -289,38 +300,34 @@ def visit_Return(self, node: ast.Return) -> str:

def visit_Tuple(self, node: ast.Tuple) -> str:
elts = [self.visit(i) for i in node.elts]
return (
r"\mathopen{}\left( "
+ r"\space,\space ".join(elts)
+ r"\mathclose{}\right) "
)
return r"\mathopen{}\left( " + r", ".join(elts) + r" \mathclose{}\right)"

def visit_List(self, node: ast.List) -> str:
elts = [self.visit(i) for i in node.elts]
return r"\left[ " + r"\space,\space ".join(elts) + r"\right] "
return r"\mathopen{}\left[ " + r", ".join(elts) + r" \mathclose{}\right]"

def visit_Set(self, node: ast.Set) -> str:
elts = [self.visit(i) for i in node.elts]
return r"\left\{ " + r"\space,\space ".join(elts) + r"\right\} "
return r"\mathopen{}\left\{ " + r", ".join(elts) + r" \mathclose{}\right\}"

def visit_ListComp(self, node: ast.ListComp) -> str:
generators = [self.visit(comp) for comp in node.generators]
return (
r"\left[ "
r"\mathopen{}\left[ "
+ self.visit(node.elt)
+ r" \mid "
+ ", ".join(generators)
+ r" \right]"
+ r" \mathclose{}\right]"
)

def visit_SetComp(self, node: ast.SetComp) -> str:
generators = [self.visit(comp) for comp in node.generators]
return (
r"\left\{ "
r"\mathopen{}\left\{ "
+ self.visit(node.elt)
+ r" \mid "
+ ", ".join(generators)
+ r" \right\}"
+ r" \mathclose{}\right\}"
)

def visit_comprehension(self, node: ast.comprehension) -> str:
Expand All @@ -347,10 +354,16 @@ def _generate_sum_prod(self, node: ast.Call) -> str | None:
return None

name = ast_utils.extract_function_name_or_none(node)
assert name is not None
assert name in ("fsum", "sum", "prod")

command = {
"fsum": r"\sum",
"sum": r"\sum",
"prod": r"\prod",
}[name]

elt, scripts = self._get_sum_prod_info(node.args[0])
scripts_str = [rf"\{name}_{{{lo}}}^{{{up}}}" for lo, up in scripts]
scripts_str = [rf"{command}_{{{lo}}}^{{{up}}}" for lo, up in scripts]
return (
" ".join(scripts_str)
+ rf" \mathopen{{}}\left({{{elt}}}\mathclose{{}}\right)"
Expand Down Expand Up @@ -403,7 +416,7 @@ def visit_Call(self, node: ast.Call) -> str:
func_name = ast_utils.extract_function_name_or_none(node)

# Special treatments for some functions.
if func_name in ("sum", "prod"):
if func_name in ("fsum", "sum", "prod"):
special_latex = self._generate_sum_prod(node)
elif func_name in ("array", "ndarray"):
special_latex = self._generate_matrix(node)
Expand All @@ -413,17 +426,38 @@ def visit_Call(self, node: ast.Call) -> str:
if special_latex is not None:
return special_latex

# Function signature (possibly an expression).
default_func_str = self.visit(node.func)

# Obtains wrapper syntax: sqrt -> "\sqrt{" and "}"
lstr, rstr = constants.BUILTIN_FUNCS.get(
func_name,
(default_func_str + r"\mathopen{}\left(", r"\mathclose{}\right)"),
)
# Obtains the codegen rule.
rule = constants.BUILTIN_FUNCS.get(func_name)
if rule is None:
rule = constants.FunctionRule(self.visit(node.func))

if rule.is_unary and len(node.args) == 1:
# Unary function. Applies the same wrapping policy with the unary operators.
# NOTE(odashi):
# Factorial "x!" is treated as a special case: it requires both inner/outer
# parentheses for correct interpretation.
precedence = _get_precedence(node)
arg = node.args[0]
force_wrap = isinstance(arg, ast.Call) and (
func_name == "factorial"
or ast_utils.extract_function_name_or_none(arg) == "factorial"
)
arg_latex = self._wrap_operand(arg, precedence, force_wrap)
elements = [rule.left, arg_latex, rule.right]
else:
arg_latex = ", ".join(self.visit(arg) for arg in node.args)
if rule.is_wrapped:
elements = [rule.left, arg_latex, rule.right]
else:
elements = [
rule.left,
r"\mathopen{}\left(",
arg_latex,
r"\mathclose{}\right)",
rule.right,
]

arg_strs = [self.visit(arg) for arg in node.args]
return lstr + ", ".join(arg_strs) + rstr
return " ".join(x for x in elements if x)

def visit_Attribute(self, node: ast.Attribute) -> str:
vstr = self.visit(node.value)
Expand Down Expand Up @@ -481,20 +515,26 @@ def visit_NameConstant(self, node: ast.NameConstant) -> str:
def visit_Ellipsis(self, node: ast.Ellipsis) -> str:
return self._convert_constant(...)

def _wrap_operand(self, child: ast.expr, parent_prec: int) -> str:
def _wrap_operand(
self, child: ast.expr, parent_prec: int, force_wrap: bool = False
) -> str:
"""Wraps the operand subtree with parentheses.

Args:
child: Operand subtree.
parent_prec: Precedence of the parent operator.
force_wrap: Whether to wrap the operand or not when the precedence is equal.

Returns:
LaTeX form of `child`, with or without surrounding parentheses.
"""
latex = self.visit(child)
if _get_precedence(child) >= parent_prec:
return latex
return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)"
child_prec = _get_precedence(child)

if child_prec < parent_prec or force_wrap and child_prec == parent_prec:
return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)"

return latex

def _wrap_binop_operand(
self,
Expand All @@ -515,6 +555,13 @@ def _wrap_binop_operand(
if not operand_rule.wrap:
return self.visit(child)

if isinstance(child, ast.Call):
rule = constants.BUILTIN_FUNCS.get(
ast_utils.extract_function_name_or_none(child)
)
if rule is not None and rule.is_wrapped:
return self.visit(child)

if not isinstance(child, ast.BinOp):
return self._wrap_operand(child, parent_prec)

Expand Down
Loading