diff --git a/src/integration_tests/regression_test.py b/src/integration_tests/regression_test.py index 6ca41a7..981d8f9 100644 --- a/src/integration_tests/regression_test.py +++ b/src/integration_tests/regression_test.py @@ -70,24 +70,23 @@ def sum_with_limit_two_args(a, n): ) -func_and_latex_str_list = [ - (solve, solve_latex, None), - (sinc, sinc_latex, None), - (xtimesbeta, xtimesbeta_latex, True), - (xtimesbeta, xtimesbeta_latex_no_symbols, False), - (sum_with_limit, sum_with_limit_latex, None), - (sum_with_limit_two_args, sum_with_limit_two_args_latex, None), -] - - -@pytest.mark.parametrize("func, expected_latex, math_symbol", func_and_latex_str_list) -def test_with_latex_to_str(func, expected_latex, math_symbol): +@pytest.mark.parametrize( + "func, expected_latex, use_math_symbols", + [ + (solve, solve_latex, None), + (sinc, sinc_latex, None), + (xtimesbeta, xtimesbeta_latex, True), + (xtimesbeta, xtimesbeta_latex_no_symbols, False), + (sum_with_limit, sum_with_limit_latex, None), + (sum_with_limit_two_args, sum_with_limit_two_args_latex, None), + ], +) +def test_with_latex_to_str(func, expected_latex, use_math_symbols): """Test with_latex to str.""" - # pylint: disable=protected-access - if math_symbol is None: + if use_math_symbols is None: latexified_function = with_latex(func) else: - latexified_function = with_latex(math_symbol=math_symbol)(func) + latexified_function = with_latex(use_math_symbols=use_math_symbols)(func) assert str(latexified_function) == expected_latex expected_repr = r"$$ \displaystyle %s $$" % expected_latex assert latexified_function._repr_latex_() == expected_repr @@ -110,57 +109,46 @@ def inner(y): assert get_latex(nested(3)) == r"\mathrm{inner}(y) \triangleq xy" -def test_assign_feature(): - @with_latex - def f(x): - return abs(x) * math.exp(math.sqrt(x)) - - @with_latex - def g(x): - a = abs(x) - b = math.exp(math.sqrt(x)) - return a * b - - @with_latex(reduce_assignments=False) - def h(x): - a = abs(x) - b = 
math.exp(math.sqrt(x)) - return a * b - - assert str(f) == ( - r"\mathrm{f}(x) \triangleq \left|{x}\right|\exp{\left({\sqrt{x}}\right)}" - ) - assert str(g) == ( - r"\mathrm{g}(x) \triangleq " - r"\left( " - r"\left|{x}\right| \right)\left( \exp{\left({\sqrt{x}}\right)} " - r"\right)" +def test_use_raw_function_name(): + def foo_bar(): + return 42 + + assert str(with_latex(foo_bar)) == r"\mathrm{foo_bar}() \triangleq 42" + assert ( + str(with_latex(foo_bar, use_raw_function_name=True)) + == r"\mathrm{foo\_bar}() \triangleq 42" ) - assert str(h) == ( - r"a \triangleq " - r"\left|{x}\right| \\ " - r"b \triangleq \exp{\left({\sqrt{x}}\right)} \\ " - r"\mathrm{h}(x) \triangleq ab" + assert ( + str(with_latex(use_raw_function_name=True)(foo_bar)) + == r"\mathrm{foo\_bar}() \triangleq 42" ) - @with_latex(reduce_assignments=True) + +def test_reduce_assignments(): def f(x): - a = math.sqrt(math.exp(x)) - return abs(x) * math.log10(a) + a = x + x + return 3 * a + + assert str(with_latex(f)) == r"a \triangleq x + x \\ \mathrm{f}(x) \triangleq 3a" + + latex_with_option = r"\mathrm{f}(x) \triangleq 3\left( x + x \right)" + assert str(with_latex(f, reduce_assignments=True)) == latex_with_option + assert str(with_latex(reduce_assignments=True)(f)) == latex_with_option - assert str(f) == ( - r"\mathrm{f}(x) \triangleq " - r"\left|{x}\right|" - r"\log_{10}{\left({\left( \sqrt{\exp{\left({x}\right)}} \right)}\right)}" - ) - @with_latex(reduce_assignments=False) +def test_reduce_assignments_double(): def f(x): - a = math.sqrt(math.exp(x)) - return abs(x) * math.log10(a) + a = x**2 + b = a + a + return 3 * b - assert str(f) == ( - r"a \triangleq " - r"\sqrt{\exp{\left({x}\right)}} \\ " - r"\mathrm{f}(x) \triangleq \left|{x}\right|\log_{10}{\left({a}\right)}" + assert str(with_latex(f)) == ( + r"a \triangleq x^{2} \\ b \triangleq a + a \\ \mathrm{f}(x) \triangleq 3b" + ) + + latex_with_option = ( + r"\mathrm{f}(x) \triangleq " + r"3\left( \left( x^{2} \right) + \left( x^{2} \right) 
\right)" ) + assert str(with_latex(f, reduce_assignments=True)) == latex_with_option + assert str(with_latex(reduce_assignments=True)(f)) == latex_with_option diff --git a/src/latexify/__init__.py b/src/latexify/__init__.py index a5cbea1..282c9b4 100644 --- a/src/latexify/__init__.py +++ b/src/latexify/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. """Latexify toplevel module.""" -from latexify import core +from latexify import frontend -get_latex = core.get_latex -with_latex = core.with_latex +get_latex = frontend.get_latex +with_latex = frontend.with_latex diff --git a/src/latexify/constants.py b/src/latexify/constants.py index 6013532..b918a09 100644 --- a/src/latexify/constants.py +++ b/src/latexify/constants.py @@ -25,54 +25,6 @@ class Actions(NamedTuple): actions = Actions() -MATH_SYMBOLS = { - "aleph", - "alpha", - "beta", - "beth", - "chi", - "daleth", - "delta", - "digamma", - "epsilon", - "eta", - "gamma", - "gimel", - "iota", - "kappa", - "lambda", - "mu", - "nu", - "omega", - "phi", - "pi", - "psi", - "rho", - "sigma", - "tau", - "theta", - "upsilon", - "varepsilon", - "varkappa", - "varphi", - "varpi", - "varrho", - "varsigma", - "vartheta", - "xi", - "zeta", - "Delta", - "Gamma", - "Lambda", - "Omega", - "Phi", - "Pi", - "Sigma", - "Theta", - "Upsilon", - "Xi", -} - PREFIXES = ["math", "numpy", "np"] BUILTIN_CALLEES = { diff --git a/src/latexify/core_test.py b/src/latexify/core_test.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/latexify/frontend.py b/src/latexify/frontend.py new file mode 100644 index 0000000..380a200 --- /dev/null +++ b/src/latexify/frontend.py @@ -0,0 +1,113 @@ +"""Frontend interfaces of latexify.""" + +from __future__ import annotations + +import ast +from collections.abc import Callable +import inspect +import textwrap +from typing import Any + +import dill + +from latexify import latexify_visitor + + +def get_latex( + fn: Callable[..., Any], + *, + use_math_symbols: bool = False, + 
use_raw_function_name: bool = False, + reduce_assignments: bool = False, +) -> str: + """Obtains LaTeX description from the function's source. + + Args: + fn: Reference to a function to analyze. + use_math_symbols: Whether to convert identifiers with a math symbol surface + (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha"). + use_raw_function_name: Whether to keep underscores "_" in the function name, + or convert it to subscript. + reduce_assignments: If True, assignment statements are used to synthesize + the final expression. + + Returns: + Generated LaTeX description. + """ + try: + source = inspect.getsource(fn) + except Exception: + # Maybe running on console. + source = dill.source.getsource(fn) + + # Remove extra indentation so that ast.parse runs correctly. + source = textwrap.dedent(source) + + tree = ast.parse(source) + + visitor = latexify_visitor.LatexifyVisitor( + use_math_symbols=use_math_symbols, + use_raw_function_name=use_raw_function_name, + reduce_assignments=reduce_assignments, + ) + + return visitor.visit(tree) + + +class LatexifiedFunction: + """Function with latex representation.""" + + def __init__(self, fn, **kwargs): + self._fn = fn + self._str = get_latex(fn, **kwargs) + + @property + def __doc__(self): + return self._fn.__doc__ + + @__doc__.setter + def __doc__(self, val): + self._fn.__doc__ = val + + @property + def __name__(self): + return self._fn.__name__ + + @__name__.setter + def __name__(self, val): + self._fn.__name__ = val + + def __call__(self, *args): + return self._fn(*args) + + def __str__(self): + return self._str + + def _repr_latex_(self): + """IPython hook to display LaTeX visualization.""" + return r"$$ \displaystyle " + self._str + " $$" + + +def with_latex(*args, **kwargs) -> Callable[[Callable[..., Any]], LatexifiedFunction]: + """Translate a function with latex representation. + + This function works with or without specifying the target function as the positional + argument. 
The following two syntaxes work similarly. + - with_latex(fn, **kwargs) + - with_latex(**kwargs)(fn) + + Args: + *args: No argument, or a callable. + **kwargs: Arguments to control behavior. See also get_latex(). + + Returns: + - If the target function is passed directly, returns the wrapped function. + - Otherwise, returns the wrapper function with given settings. + """ + if len(args) == 1 and isinstance(args[0], Callable): + return LatexifiedFunction(args[0], **kwargs) + + def wrapper(fn): + return LatexifiedFunction(fn, **kwargs) + + return wrapper diff --git a/src/latexify/core.py b/src/latexify/latexify_visitor.py similarity index 74% rename from src/latexify/core.py rename to src/latexify/latexify_visitor.py index b60a117..3094a3d 100644 --- a/src/latexify/core.py +++ b/src/latexify/latexify_visitor.py @@ -1,72 +1,70 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This is very scratchy and supports only limited portion of Python functions. 
-"""Latexify core module.""" +"""Latexify core AST visitor.""" -import ast -import inspect -import textwrap +from __future__ import annotations -import dill +import ast from latexify import constants +from latexify import math_symbols from latexify import node_visitor_base class LatexifyVisitor(node_visitor_base.NodeVisitorBase): """Latexify AST visitor.""" - def __init__(self, math_symbol=False, raw_func_name=False, reduce_assignments=True): - self.math_symbol = math_symbol - self.raw_func_name = ( - raw_func_name # True:do not treat underline as label of subscript(#31) + _math_symbol_converter: math_symbols.MathSymbolConverter + _use_raw_function_name: bool + _reduce_assignments: bool + + # TODO(odashi): This variable can be function-level. Remove it from the object. + _assign_var: dict[str, str] + + def __init__( + self, + *, + use_math_symbols: bool = False, + use_raw_function_name: bool = False, + reduce_assignments: bool = True, + ): + """Initializer. + + Args: + use_math_symbols: Whether to convert identifiers with a math symbol surface + (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha"). + use_raw_function_name: Whether to keep underscores "_" in the function name, + or convert it to subscript. + reduce_assignments: If True, assignment statements are used to synthesize + the final expression. 
+ """ + self._math_symbol_converter = math_symbols.MathSymbolConverter( + enabled=use_math_symbols ) - self.reduce_assignments = reduce_assignments - self.assign_var = {} - super().__init__() + self._use_raw_function_name = use_raw_function_name + self._reduce_assignments = reduce_assignments - def _parse_math_symbols(self, val: str) -> str: - if not self.math_symbol: - return val - if val in constants.MATH_SYMBOLS: - return "{\\" + val + "}" - return val + self.assign_var = {} def generic_visit(self, node, action): - del action - return str(node) def visit_Module(self, node, action): # pylint: disable=invalid-name - del action - return self.visit(node.body[0], "multi_lines") def visit_FunctionDef(self, node, action): # pylint: disable=invalid-name - del action + name_str = str(node.name) + if self._use_raw_function_name: + name_str = name_str.replace(r"_", r"\_") + name_str = r"\mathrm{" + name_str + "}" - name_str = r"\mathrm{" + str(node.name) + "}" - if self.raw_func_name: - name_str = name_str.replace(r"_", r"\_") # fix #31 - arg_strs = [self._parse_math_symbols(str(arg.arg)) for arg in node.args.args] + arg_strs = [ + self._math_symbol_converter.convert(str(arg.arg)) for arg in node.args.args + ] body_str = "" assign_vars = [] for el in node.body: if isinstance(el, ast.FunctionDef): - if self.reduce_assignments: + if self._reduce_assignments: body_str = self.visit(el, "in_line") self.assign_var[el.name] = rf"\left( {body_str} \right)" else: @@ -74,7 +72,7 @@ def visit_FunctionDef(self, node, action): # pylint: disable=invalid-name assign_vars.append(body_str + r" \\ ") else: body_str = self.visit(el) - if not self.reduce_assignments and isinstance(el, ast.Assign): + if not self._reduce_assignments and isinstance(el, ast.Assign): assign_vars.append(body_str) elif isinstance(el, ast.Return): break @@ -85,7 +83,6 @@ def visit_FunctionDef(self, node, action): # pylint: disable=invalid-name def visit_FunctionDef_multi_lines(self, node): name_str, arg_strs, 
assign_vars, body_str = self.visit_FunctionDef(node, None) - print(name_str, arg_strs, assign_vars, body_str) return ( "".join(assign_vars) + name_str @@ -100,35 +97,25 @@ def visit_FunctionDef_in_line(self, node): return "".join(assign_vars) + body_str def visit_Assign(self, node, action): - del action - var = self.visit(node.value) - if self.reduce_assignments: + if self._reduce_assignments: self.assign_var[node.targets[0].id] = rf"\left( {var} \right)" return None else: return rf"{node.targets[0].id} \triangleq {var} \\ " def visit_Return(self, node, action): # pylint: disable=invalid-name - del action - return self.visit(node.value) def visit_Tuple(self, node, action): # pylint: disable=invalid-name - del action - elts = [self.visit(i) for i in node.elts] return r"\left( " + r"\space,\space ".join(elts) + r"\right) " def visit_List(self, node, action): # pylint: disable=invalid-name - del action - elts = [self.visit(i) for i in node.elts] return r"\left[ " + r"\space,\space ".join(elts) + r"\right] " def visit_Set(self, node, action): # pylint: disable=invalid-name - del action - elts = [self.visit(i) for i in node.elts] return r"\left\{ " + r"\space,\space ".join(elts) + r"\right\} " @@ -160,10 +147,8 @@ def _decorated_lstr_and_arg(node, callee_str, lstr): return lstr, arg_str - del action - callee_str = self.visit(node.func) - if self.reduce_assignments and ( + if self._reduce_assignments and ( getattr(node.func, "id", None) in self.assign_var.keys() or getattr(node.func, "attr", None) in self.assign_var.keys() ): @@ -183,35 +168,26 @@ def _decorated_lstr_and_arg(node, callee_str, lstr): return lstr + arg_str + rstr def visit_Attribute(self, node, action): # pylint: disable=invalid-name - del action - vstr = self.visit(node.value) astr = str(node.attr) return vstr + "." 
+ astr def visit_Name(self, node, action): # pylint: disable=invalid-name - del action - - if self.reduce_assignments and node.id in self.assign_var.keys(): + if self._reduce_assignments and node.id in self.assign_var.keys(): return self.assign_var[node.id] - return self._parse_math_symbols(str(node.id)) + return self._math_symbol_converter.convert(str(node.id)) def visit_Constant(self, node, action): # pylint: disable=invalid-name - del action - # for python >= 3.8 return str(node.n) def visit_Num(self, node, action): # pylint: disable=invalid-name - del action - # for python < 3.8 return str(node.n) def visit_UnaryOp(self, node, action): # pylint: disable=invalid-name """Visit a unary op node.""" - del action def _wrap(child): latex = self.visit(child) @@ -233,8 +209,6 @@ def _wrap(child): def visit_BinOp(self, node, action): # pylint: disable=invalid-name """Visit a binary op node.""" - del action - priority = constants.BIN_OP_PRIORITY def _unwrap(child): @@ -274,8 +248,6 @@ def _wrap(child): def visit_Compare(self, node, action): # pylint: disable=invalid-name """Visit a compare node.""" - del action - lstr = self.visit(node.left) rstr = self.visit(node.comparators[0]) @@ -297,8 +269,6 @@ def visit_Compare(self, node, action): # pylint: disable=invalid-name return r"\mathrm{unknown\_comparator}(" + lstr + ", " + rstr + ")" def visit_BoolOp(self, node, action): # pylint: disable=invalid-name - del action - logic_operator = ( r"\lor " if isinstance(node.op, ast.Or) @@ -319,8 +289,6 @@ def visit_BoolOp(self, node, action): # pylint: disable=invalid-name def visit_If(self, node, action): # pylint: disable=invalid-name """Visit an if node.""" - del action - latex = r"\left\{ \begin{array}{ll} " while isinstance(node, ast.If): @@ -357,63 +325,3 @@ def visit_comprehension_set_bounds(self, node): # pylint: disable=invalid-name raise TypeError( "Comprehension for sum only supports range func " "with 1 or 2 args" ) - - -def get_latex(fn, *args, **kwargs): - try: - source 
= inspect.getsource(fn) - # pylint: disable=broad-except - except Exception: - # Maybe running on console. - source = dill.source.getsource(fn) - - source = textwrap.dedent(source) - - return LatexifyVisitor(*args, **kwargs).visit(ast.parse(source)) - - -def with_latex(*args, **kwargs): - """Translate a function with latex representation.""" - - class _LatexifiedFunction: - """Function with latex representation.""" - - def __init__(self, fn): - self._fn = fn - self._str = get_latex(fn, *args, **kwargs) - - @property - def __doc__(self): - return self._fn.__doc__ - - @__doc__.setter - def __doc__(self, val): - self._fn.__doc__ = val - - @property - def __name__(self): - return self._fn.__name__ - - @__name__.setter - def __name__(self, val): - self._fn.__name__ = val - - def __call__(self, *args): - return self._fn(*args) - - def __str__(self): - return self._str - - def _repr_latex_(self): - """ - Hooks into Jupyter notebook's display system. - """ - return r"$$ \displaystyle " + self._str + " $$" - - if len(args) == 1 and callable(args[0]): - return _LatexifiedFunction(args[0]) - - def ret(fn): - return _LatexifiedFunction(fn) - - return ret diff --git a/src/latexify/math_symbols.py b/src/latexify/math_symbols.py new file mode 100644 index 0000000..9b70a5c --- /dev/null +++ b/src/latexify/math_symbols.py @@ -0,0 +1,84 @@ +"""Utilities to manipulate math symbols.""" + +from __future__ import annotations + +_MATH_SYMBOLS = { + "aleph", + "alpha", + "beta", + "beth", + "chi", + "daleth", + "delta", + "digamma", + "epsilon", + "eta", + "gamma", + "gimel", + "iota", + "kappa", + "lambda", + "mu", + "nu", + "omega", + "phi", + "pi", + "psi", + "rho", + "sigma", + "tau", + "theta", + "upsilon", + "varepsilon", + "varkappa", + "varphi", + "varpi", + "varrho", + "varsigma", + "vartheta", + "xi", + "zeta", + "Delta", + "Gamma", + "Lambda", + "Omega", + "Phi", + "Pi", + "Psi", + "Sigma", + "Theta", + "Upsilon", + "Xi", +} + + +class MathSymbolConverter: + """Strategy to 
convert identifier name to LaTeX math symbols.""" + + _enabled: bool + + def __init__(self, enabled: bool): + """Initializer. + + Args: + enabled: Whether to enable every conversion. If True, all conversions will be + performed. If False, the given string is returned as-is. + """ + self._enabled = enabled + + def convert(self, name: str) -> str: + """Converts given identifier to the specified form. + + Args: + name: Name of the identifier to be converted. + + Returns: + Converted LaTeX string. + """ + if not self._enabled: + return name + + if name in _MATH_SYMBOLS: + return "{\\" + name + "}" + + return name diff --git a/src/latexify/math_symbols_test.py b/src/latexify/math_symbols_test.py new file mode 100644 index 0000000..ff64148 --- /dev/null +++ b/src/latexify/math_symbols_test.py @@ -0,0 +1,20 @@ +"""Tests for latexify.math_symbols.""" + +from __future__ import annotations + +import pytest + +from latexify import math_symbols + + +@pytest.mark.parametrize( + "name,converted,enabled", + [ + ("foo", "foo", False), + ("foo", "foo", True), + ("alpha", "alpha", False), + ("alpha", "{\\alpha}", True), + ], +) +def test_math_symbol_converter_convert(name: str, converted: str, enabled: bool): + assert math_symbols.MathSymbolConverter(enabled=enabled).convert(name) == converted