Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More refactoring #71

Merged
merged 8 commits into from
Oct 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions src/integration_tests/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def solve(a, b, c):


solve_latex = (
r"\mathrm{solve}(a, b, c) \triangleq " r"\frac{-b + \sqrt{b^{2} - 4ac}}{2a}"
r"\mathrm{solve}(a, b, c) \triangleq " r"\frac{-b + \sqrt{b^{{2}} - {4}ac}}{{2}a}"
)


Expand All @@ -37,8 +37,8 @@ def sinc(x):


sinc_latex = (
r"\mathrm{sinc}(x) \triangleq \left\{ \begin{array}{ll} 1, & \mathrm{if} \ "
r"{x = 0} \\ \frac{\sin{\left({x}\right)}}{x}, & \mathrm{otherwise} \end{array}"
r"\mathrm{sinc}(x) \triangleq \left\{ \begin{array}{ll} {1}, & \mathrm{if} \ "
r"{x = {0}} \\ \frac{\sin{\left({x}\right)}}{x}, & \mathrm{otherwise} \end{array}"
r" \right."
)

Expand All @@ -56,7 +56,7 @@ def sum_with_limit(n):


sum_with_limit_latex = (
r"\mathrm{sum_with_limit}(n) \triangleq \sum_{i=0}^{n-1} \left({i^{2}}\right)"
r"\mathrm{sum_with_limit}(n) \triangleq \sum_{i=0}^{n-1} \left({i^{{2}}}\right)"
)


Expand All @@ -66,7 +66,7 @@ def sum_with_limit_two_args(a, n):

sum_with_limit_two_args_latex = (
r"\mathrm{sum_with_limit_two_args}(a, n) "
r"\triangleq \sum_{i=a}^{n-1} \left({i^{2}}\right)"
r"\triangleq \sum_{i=a}^{n-1} \left({i^{{2}}}\right)"
)


Expand Down Expand Up @@ -96,7 +96,7 @@ def test_nested_function():
def nested(x):
return 3 * x

assert get_latex(nested) == r"\mathrm{nested}(x) \triangleq 3x"
assert get_latex(nested) == r"\mathrm{nested}(x) \triangleq {3}x"


def test_double_nested_function():
Expand All @@ -113,14 +113,14 @@ def test_use_raw_function_name():
def foo_bar():
return 42

assert str(with_latex(foo_bar)) == r"\mathrm{foo_bar}() \triangleq 42"
assert str(with_latex(foo_bar)) == r"\mathrm{foo_bar}() \triangleq {42}"
assert (
str(with_latex(foo_bar, use_raw_function_name=True))
== r"\mathrm{foo\_bar}() \triangleq 42"
== r"\mathrm{foo\_bar}() \triangleq {42}"
)
assert (
str(with_latex(use_raw_function_name=True)(foo_bar))
== r"\mathrm{foo\_bar}() \triangleq 42"
== r"\mathrm{foo\_bar}() \triangleq {42}"
)


Expand All @@ -129,9 +129,9 @@ def f(x):
a = x + x
return 3 * a

assert str(with_latex(f)) == r"a \triangleq x + x \\ \mathrm{f}(x) \triangleq 3a"
assert str(with_latex(f)) == r"a \triangleq x + x \\ \mathrm{f}(x) \triangleq {3}a"

latex_with_option = r"\mathrm{f}(x) \triangleq 3\left( x + x \right)"
latex_with_option = r"\mathrm{f}(x) \triangleq {3}\left( x + x \right)"
assert str(with_latex(f, reduce_assignments=True)) == latex_with_option
assert str(with_latex(reduce_assignments=True)(f)) == latex_with_option

Expand All @@ -143,12 +143,12 @@ def f(x):
return 3 * b

assert str(with_latex(f)) == (
r"a \triangleq x^{2} \\ b \triangleq a + a \\ \mathrm{f}(x) \triangleq 3b"
r"a \triangleq x^{{2}} \\ b \triangleq a + a \\ \mathrm{f}(x) \triangleq {3}b"
)

latex_with_option = (
r"\mathrm{f}(x) \triangleq "
r"3\left( \left( x^{2} \right) + \left( x^{2} \right) \right)"
r"{3}\left( \left( x^{{2}} \right) + \left( x^{{2}} \right) \right)"
)
assert str(with_latex(f, reduce_assignments=True)) == latex_with_option
assert str(with_latex(reduce_assignments=True)(f)) == latex_with_option
10 changes: 5 additions & 5 deletions src/latexify/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
class LatexifyError(Exception):
"""Base class of all Latexify exceptions.

Subclasses of this exception does not mean incorrect use of the library by the user,
but informs users that Latexify went into something wrong during compiling the given
functions.
These functions are usually captured by the frontend functions (e.g., `with_latex`)
Subclasses of this exception does not mean incorrect use of the library by the user
at the interface level. These exceptions inform users that Latexify went into
something wrong during processing the given functions.
These exceptions are usually captured by the frontend functions (e.g., `with_latex`)
to prevent destroying the entire program.
Errors caused by the wrong inputs should raise built-in exceptions.
Errors caused by wrong inputs should raise built-in exceptions.
"""

...
Expand Down
24 changes: 6 additions & 18 deletions src/latexify/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@

from __future__ import annotations

import ast
from collections.abc import Callable
import inspect
import textwrap
from typing import Any

import dill

from latexify import exceptions, latexify_visitor
from latexify.transformers.identifier_replacer import IdentifierReplacer
from latexify import exceptions
from latexify import latexify_visitor
from latexify import parser
from latexify import transformers


def get_latex(
Expand Down Expand Up @@ -44,19 +41,10 @@ def get_latex(
Raises:
latexify.exceptions.LatexifyError: Something went wrong during conversion.
"""
try:
source = inspect.getsource(fn)
except Exception:
# Maybe running on console.
source = dill.source.getsource(fn)

# Remove extra indentation so that ast.parse runs correctly.
source = textwrap.dedent(source)

tree = ast.parse(source)
tree = parser.parse_function(fn)

if identifiers is not None:
tree = IdentifierReplacer(identifiers).visit(tree)
tree = transformers.IdentifierReplacer(identifiers).visit(tree)

visitor = latexify_visitor.LatexifyVisitor(
use_math_symbols=use_math_symbols,
Expand Down
2 changes: 1 addition & 1 deletion src/latexify/frontend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ def very_long_name_function(very_long_name_variable):
}

assert frontend.get_latex(very_long_name_function, identifiers=identifiers) == (
r"\mathrm{f}(x) \triangleq 3x"
r"\mathrm{f}(x) \triangleq {3}x"
)
85 changes: 62 additions & 23 deletions src/latexify/latexify_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import ast
from typing import ClassVar
from typing import Any, ClassVar

from latexify import constants
from latexify import math_symbols
Expand Down Expand Up @@ -51,10 +51,10 @@ def generic_visit(self, node, action) -> str:
f"Unsupported AST: {type(node).__name__}"
)

def visit_Module(self, node, action): # pylint: disable=invalid-name
def visit_Module(self, node, action):
return self.visit(node.body[0], "multi_lines")

def visit_FunctionDef(self, node, action): # pylint: disable=invalid-name
def visit_FunctionDef(self, node, action):
name_str = str(node.name)
if self._use_raw_function_name:
name_str = name_str.replace(r"_", r"\_")
Expand Down Expand Up @@ -108,22 +108,22 @@ def visit_Assign(self, node, action):
else:
return rf"{node.targets[0].id} \triangleq {var} \\ "

def visit_Return(self, node, action): # pylint: disable=invalid-name
def visit_Return(self, node, action):
return self.visit(node.value)

def visit_Tuple(self, node, action): # pylint: disable=invalid-name
def visit_Tuple(self, node, action):
elts = [self.visit(i) for i in node.elts]
return r"\left( " + r"\space,\space ".join(elts) + r"\right) "

def visit_List(self, node, action): # pylint: disable=invalid-name
def visit_List(self, node, action):
elts = [self.visit(i) for i in node.elts]
return r"\left[ " + r"\space,\space ".join(elts) + r"\right] "

def visit_Set(self, node, action): # pylint: disable=invalid-name
def visit_Set(self, node, action):
elts = [self.visit(i) for i in node.elts]
return r"\left\{ " + r"\space,\space ".join(elts) + r"\right\} "

def visit_Call(self, node, action): # pylint: disable=invalid-name
def visit_Call(self, node, action):
"""Visit a call node."""

def _decorated_lstr_and_arg(node, callee_str, lstr):
Expand Down Expand Up @@ -171,26 +171,65 @@ def _decorated_lstr_and_arg(node, callee_str, lstr):
lstr, arg_str = _decorated_lstr_and_arg(node, callee_str, lstr)
return lstr + arg_str + rstr

def visit_Attribute(self, node, action): # pylint: disable=invalid-name
def visit_Attribute(self, node, action):
vstr = self.visit(node.value)
astr = str(node.attr)
return vstr + "." + astr

def visit_Name(self, node, action): # pylint: disable=invalid-name
def visit_Name(self, node, action):
if self._reduce_assignments and node.id in self.assign_var.keys():
return self.assign_var[node.id]

return self._math_symbol_converter.convert(str(node.id))

def visit_Constant(self, node, action): # pylint: disable=invalid-name
# for python >= 3.8
return str(node.n)
def convert_constant(self, value: Any) -> str:
"""Helper to convert constant values to LaTeX.

def visit_Num(self, node, action): # pylint: disable=invalid-name
# for python < 3.8
return str(node.n)
Args:
value: A constant value.

Returns:
The LaTeX representation of `value`.
"""
if value is None or isinstance(value, bool):
return r"\mathrm{" + str(value) + "}"
if isinstance(value, (int, float, complex)):
return "{" + str(value) + "}"
if isinstance(value, str):
return r'\textrm{"' + value + '"}'
if isinstance(value, bytes):
return r"\textrm{" + str(value) + "}"
if value is ...:
return r"{\cdots}"
raise exceptions.LatexifyNotSupportedError(
f"Unrecognized constant: {type(value).__name__}"
)

# From Python 3.8
def visit_Constant(self, node: ast.Constant, action) -> str:
return self.convert_constant(node.value)

# Until Python 3.7
def visit_Num(self, node: ast.Num, action) -> str:
return self.convert_constant(node.n)

# Until Python 3.7
def visit_Str(self, node: ast.Str, action) -> str:
return self.convert_constant(node.s)

# Until Python 3.7
def visit_Bytes(self, node: ast.Bytes, action) -> str:
return self.convert_constant(node.s)

# Until Python 3.7
def visit_NameConstant(self, node: ast.NameConstant, action) -> str:
return self.convert_constant(node.value)

# Until Python 3.7
def visit_Ellipsis(self, node: ast.Ellipsis, action) -> str:
return self.convert_constant(...)

def visit_UnaryOp(self, node, action): # pylint: disable=invalid-name
def visit_UnaryOp(self, node, action):
"""Visit a unary op node."""

def _wrap(child):
Expand All @@ -211,7 +250,7 @@ def _wrap(child):
return reprs[type(node.op)]()
return r"\mathrm{unknown\_uniop}(" + self.visit(node.operand) + ")"

def visit_BinOp(self, node, action): # pylint: disable=invalid-name
def visit_BinOp(self, node, action):
"""Visit a binary op node."""
priority = constants.BIN_OP_PRIORITY

Expand Down Expand Up @@ -263,7 +302,7 @@ def _wrap(child):
ast.NotIn: r"\notin",
}

def visit_Compare(self, node: ast.Compare, action): # pylint: disable=invalid-name
def visit_Compare(self, node: ast.Compare, action):
"""Visit a compare node."""
lhs = self.visit(node.left)
ops = [self._compare_ops[type(x)] for x in node.ops]
Expand All @@ -276,13 +315,13 @@ def visit_Compare(self, node: ast.Compare, action): # pylint: disable=invalid-n
ast.Or: r"\lor",
}

def visit_BoolOp(self, node: ast.BoolOp, action): # pylint: disable=invalid-name
def visit_BoolOp(self, node: ast.BoolOp, action):
"""Visit a BoolOp node."""
values = [rf"\left( {self.visit(x)} \right)" for x in node.values]
op = f" {self._bool_ops[type(node.op)]} "
return "{" + op.join(values) + "}"

def visit_If(self, node, action): # pylint: disable=invalid-name
def visit_If(self, node, action):
"""Visit an if node."""
latex = r"\left\{ \begin{array}{ll} "

Expand All @@ -295,7 +334,7 @@ def visit_If(self, node, action): # pylint: disable=invalid-name
latex += self.visit(node)
return latex + r", & \mathrm{otherwise} \end{array} \right."

def visit_GeneratorExp_set_bounds(self, node): # pylint: disable=invalid-name
def visit_GeneratorExp_set_bounds(self, node):
output = self.visit(node.elt)
comprehensions = [
self.visit(generator, "set_bounds") for generator in node.generators
Expand Down Expand Up @@ -344,7 +383,7 @@ def visit_Subscript(self, node: ast.Subscript, action) -> str:

return f"{{{value}_{indices_str}}}"

def visit_comprehension_set_bounds(self, node): # pylint: disable=invalid-name
def visit_comprehension_set_bounds(self, node):
"""Visit a comprehension node, which represents a for clause"""
var = self.visit(node.target)
if isinstance(node.iter, ast.Call):
Expand Down
Loading