Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring of the library #58

Merged
merged 3 commits into from
Oct 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 48 additions & 60 deletions src/integration_tests/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,24 +70,23 @@ def sum_with_limit_two_args(a, n):
)


func_and_latex_str_list = [
(solve, solve_latex, None),
(sinc, sinc_latex, None),
(xtimesbeta, xtimesbeta_latex, True),
(xtimesbeta, xtimesbeta_latex_no_symbols, False),
(sum_with_limit, sum_with_limit_latex, None),
(sum_with_limit_two_args, sum_with_limit_two_args_latex, None),
]


@pytest.mark.parametrize("func, expected_latex, math_symbol", func_and_latex_str_list)
def test_with_latex_to_str(func, expected_latex, math_symbol):
@pytest.mark.parametrize(
"func, expected_latex, use_math_symbols",
[
(solve, solve_latex, None),
(sinc, sinc_latex, None),
(xtimesbeta, xtimesbeta_latex, True),
(xtimesbeta, xtimesbeta_latex_no_symbols, False),
(sum_with_limit, sum_with_limit_latex, None),
(sum_with_limit_two_args, sum_with_limit_two_args_latex, None),
],
)
def test_with_latex_to_str(func, expected_latex, use_math_symbols):
"""Test with_latex to str."""
# pylint: disable=protected-access
if math_symbol is None:
if use_math_symbols is None:
latexified_function = with_latex(func)
else:
latexified_function = with_latex(math_symbol=math_symbol)(func)
latexified_function = with_latex(use_math_symbols=use_math_symbols)(func)
assert str(latexified_function) == expected_latex
expected_repr = r"$$ \displaystyle %s $$" % expected_latex
assert latexified_function._repr_latex_() == expected_repr
Expand All @@ -110,57 +109,46 @@ def inner(y):
assert get_latex(nested(3)) == r"\mathrm{inner}(y) \triangleq xy"


def test_assign_feature():
@with_latex
def f(x):
return abs(x) * math.exp(math.sqrt(x))

@with_latex
def g(x):
a = abs(x)
b = math.exp(math.sqrt(x))
return a * b

@with_latex(reduce_assignments=False)
def h(x):
a = abs(x)
b = math.exp(math.sqrt(x))
return a * b

assert str(f) == (
r"\mathrm{f}(x) \triangleq \left|{x}\right|\exp{\left({\sqrt{x}}\right)}"
)
assert str(g) == (
r"\mathrm{g}(x) \triangleq "
r"\left( "
r"\left|{x}\right| \right)\left( \exp{\left({\sqrt{x}}\right)} "
r"\right)"
def test_use_raw_function_name():
def foo_bar():
return 42

assert str(with_latex(foo_bar)) == r"\mathrm{foo_bar}() \triangleq 42"
assert (
str(with_latex(foo_bar, use_raw_function_name=True))
== r"\mathrm{foo\_bar}() \triangleq 42"
)
assert str(h) == (
r"a \triangleq "
r"\left|{x}\right| \\ "
r"b \triangleq \exp{\left({\sqrt{x}}\right)} \\ "
r"\mathrm{h}(x) \triangleq ab"
assert (
str(with_latex(use_raw_function_name=True)(foo_bar))
== r"\mathrm{foo\_bar}() \triangleq 42"
)

@with_latex(reduce_assignments=True)

def test_reduce_assignments():
def f(x):
a = math.sqrt(math.exp(x))
return abs(x) * math.log10(a)
a = x + x
return 3 * a

assert str(with_latex(f)) == r"a \triangleq x + x \\ \mathrm{f}(x) \triangleq 3a"

latex_with_option = r"\mathrm{f}(x) \triangleq 3\left( x + x \right)"
assert str(with_latex(f, reduce_assignments=True)) == latex_with_option
assert str(with_latex(reduce_assignments=True)(f)) == latex_with_option

assert str(f) == (
r"\mathrm{f}(x) \triangleq "
r"\left|{x}\right|"
r"\log_{10}{\left({\left( \sqrt{\exp{\left({x}\right)}} \right)}\right)}"
)

@with_latex(reduce_assignments=False)
def test_reduce_assignments_double():
def f(x):
a = math.sqrt(math.exp(x))
return abs(x) * math.log10(a)
a = x**2
b = a + a
return 3 * b

assert str(f) == (
r"a \triangleq "
r"\sqrt{\exp{\left({x}\right)}} \\ "
r"\mathrm{f}(x) \triangleq \left|{x}\right|\log_{10}{\left({a}\right)}"
assert str(with_latex(f)) == (
r"a \triangleq x^{2} \\ b \triangleq a + a \\ \mathrm{f}(x) \triangleq 3b"
)

latex_with_option = (
r"\mathrm{f}(x) \triangleq "
r"3\left( \left( x^{2} \right) + \left( x^{2} \right) \right)"
)
assert str(with_latex(f, reduce_assignments=True)) == latex_with_option
assert str(with_latex(reduce_assignments=True)(f)) == latex_with_option
6 changes: 3 additions & 3 deletions src/latexify/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Latexify toplevel module."""

from latexify import core
from latexify import frontend

get_latex = core.get_latex
with_latex = core.with_latex
get_latex = frontend.get_latex
with_latex = frontend.with_latex
48 changes: 0 additions & 48 deletions src/latexify/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,54 +25,6 @@ class Actions(NamedTuple):

actions = Actions()

MATH_SYMBOLS = {
"aleph",
"alpha",
"beta",
"beth",
"chi",
"daleth",
"delta",
"digamma",
"epsilon",
"eta",
"gamma",
"gimel",
"iota",
"kappa",
"lambda",
"mu",
"nu",
"omega",
"phi",
"pi",
"psi",
"rho",
"sigma",
"tau",
"theta",
"upsilon",
"varepsilon",
"varkappa",
"varphi",
"varpi",
"varrho",
"varsigma",
"vartheta",
"xi",
"zeta",
"Delta",
"Gamma",
"Lambda",
"Omega",
"Phi",
"Pi",
"Sigma",
"Theta",
"Upsilon",
"Xi",
}

PREFIXES = ["math", "numpy", "np"]

BUILTIN_CALLEES = {
Expand Down
Empty file removed src/latexify/core_test.py
Empty file.
113 changes: 113 additions & 0 deletions src/latexify/frontend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""Frontend interfaces of latexify."""

from __future__ import annotations

import ast
from collections.abc import Callable
import inspect
import textwrap
from typing import Any

import dill

from latexify import latexify_visitor


def get_latex(
fn: Callable[..., Any],
*,
use_math_symbols: bool = False,
use_raw_function_name: bool = False,
reduce_assignments: bool = False,
) -> str:
"""Obtains LaTeX description from the function's source.

Args:
fn: Reference to a function to analyze.
use_math_symbols: Whether to convert identifiers with a math symbol surface
(e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha").
use_raw_function_name: Whether to keep underscores "_" in the function name,
or convert it to subscript.
reduce_assignments: If True, assignment statements are used to synthesize
the final expression.

Returns:
Generatee LaTeX description.
"""
try:
source = inspect.getsource(fn)
except Exception:
# Maybe running on console.
source = dill.source.getsource(fn)

# Remove extra indentation so that ast.parse runs correctly.
source = textwrap.dedent(source)

tree = ast.parse(source)

visitor = latexify_visitor.LatexifyVisitor(
use_math_symbols=use_math_symbols,
use_raw_function_name=use_raw_function_name,
reduce_assignments=reduce_assignments,
)

return visitor.visit(tree)


class LatexifiedFunction:
"""Function with latex representation."""

def __init__(self, fn, **kwargs):
self._fn = fn
self._str = get_latex(fn, **kwargs)

@property
def __doc__(self):
return self._fn.__doc__

@__doc__.setter
def __doc__(self, val):
self._fn.__doc__ = val

@property
def __name__(self):
return self._fn.__name__

@__name__.setter
def __name__(self, val):
self._fn.__name__ = val

def __call__(self, *args):
return self._fn(*args)

def __str__(self):
return self._str

def _repr_latex_(self):
"""IPython hook to display LaTeX visualization."""
return r"$$ \displaystyle " + self._str + " $$"


def with_latex(*args, **kwargs) -> Callable[[Callable[..., Any]], LatexifiedFunction]:
"""Translate a function with latex representation.

This function works with or without specifying the target function as the positional
argument. The following two syntaxes works similarly.
- with_latex(fn, **kwargs)
- with_latex(**kwargs)(fn)

Args:
*args: No argument, or a callable.
**kwargs: Arguments to control behavior. See also get_latex().

Returns:
- If the target function is passed directly, returns the wrapped function.
- Otherwise, returns the wrapper function with given settings.
"""
if len(args) == 1 and isinstance(args[0], Callable):
return LatexifiedFunction(args[0], **kwargs)

def wrapper(fn):
return LatexifiedFunction(fn, **kwargs)

return wrapper
Loading