Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prefix trimmer #138

Merged
merged 4 commits into from
Nov 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ jobs:
python -m pip install --upgrade pip
python -m pip install isort
- name: Check
run: python -m isort -v src
run: python -m isort --check src
2 changes: 1 addition & 1 deletion checks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ set -eoux pipefail
python -m pytest src -vv
python -m black --check src
python -m pflake8 src
python -m isort -v src
python -m isort --check src
5 changes: 5 additions & 0 deletions src/integration_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Package integration_tests."""

import pytest

# Opt the shared helper module into pytest's assertion rewriting so that
# assert statements defined in integration_tests.utils report detailed
# failure introspection, just like asserts written directly in test files.
pytest.register_assert_rewrite("integration_tests.utils")
25 changes: 25 additions & 0 deletions src/latexify/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,31 @@
from typing import Any


def make_name(id: str) -> ast.Name:
    """Creates an ast.Name node in the Load context.

    Args:
        id: Identifier carried by the node.

    Returns:
        Generated ast.Name.
    """
    name_node = ast.Name(ctx=ast.Load(), id=id)
    return name_node


def make_attribute(value: ast.expr, attr: str) -> ast.Attribute:
    """Generates a new Attribute node.

    Args:
        value: Parent value (any expression node, e.g., ast.Name).
        attr: Attribute name.

    Returns:
        Generated ast.Attribute in the Load context.
    """
    # NOTE: the parameter is ast.expr (the expression base class), not ast.Expr
    # (the expression-statement wrapper): ast.Attribute.value holds expressions.
    return ast.Attribute(value=value, attr=attr, ctx=ast.Load())


def make_constant(value: Any) -> ast.expr:
"""Generates a new Constant node.

Expand Down
13 changes: 13 additions & 0 deletions src/latexify/ast_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@
from latexify import ast_utils, test_utils


def test_make_name() -> None:
    expected = ast.Name(id="foo", ctx=ast.Load())
    test_utils.assert_ast_equal(ast_utils.make_name("foo"), expected)


def test_make_attribute() -> None:
    expected = ast.Attribute(
        ast.Name(id="foo", ctx=ast.Load()), attr="bar", ctx=ast.Load()
    )
    actual = ast_utils.make_attribute(ast_utils.make_name("foo"), "bar")
    test_utils.assert_ast_equal(actual, expected)


@test_utils.require_at_most(7)
@pytest.mark.parametrize(
"value,expected",
Expand Down
7 changes: 0 additions & 7 deletions src/latexify/codegen/function_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,6 @@ def visit_Call(self, node: ast.Call) -> str:
# Function signature (possibly an expression).
func_str = self.visit(node.func)

# Removes common prefixes: math.sqrt -> sqrt
# TODO(odashi): This process can be implemented as a NodeTransformer.
for prefix in constants.PREFIXES:
if func_str.startswith(f"{prefix}."):
func_str = func_str[len(prefix) + 1 :]
break

# Obtains wrapper syntax: sqrt -> "\sqrt{" and "}"
lstr, rstr = constants.BUILTIN_FUNCS.get(
func_str,
Expand Down
14 changes: 7 additions & 7 deletions src/latexify/codegen/function_codegen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def test_visit_setcomp(code: str, latex: str) -> None:
],
)
def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
    cases = [("sum", r"\sum"), ("prod", r"\prod")]
    for fn_name, latex_command in cases:
        call_node = ast.parse(fn_name + src_suffix).body[0].value
        assert isinstance(call_node, ast.Call)
        assert FunctionCodegen().visit(call_node) == latex_command + dest_suffix
Expand All @@ -243,12 +243,12 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
),
# 3 clauses
(
"math.prod(i for y in x for i in y)",
"prod(i for y in x for i in y)",
r"\prod_{y \in x}^{} \prod_{i \in y}^{} "
r"\mathopen{}\left({i}\mathclose{}\right)",
),
(
"math.prod(i for y in x for z in y for i in z)",
"prod(i for y in x for z in y for i in z)",
r"\prod_{y \in x}^{} \prod_{z \in y}^{} \prod_{i \in z}^{} "
r"\mathopen{}\left({i}\mathclose{}\right)",
),
Expand All @@ -258,7 +258,7 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
r"\sum_{i = {0}}^{{n}} \mathopen{}\left({i}\mathclose{}\right)",
),
(
"math.prod(i for i in range(n-1))",
"prod(i for i in range(n-1))",
r"\prod_{i = {0}}^{{n - {2}}} \mathopen{}\left({i}\mathclose{}\right)",
),
# reduce stop parameter
Expand All @@ -267,7 +267,7 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
r"\sum_{i = {0}}^{{n}} \mathopen{}\left({i}\mathclose{}\right)",
),
(
"math.prod(i for i in range(n-1))",
"prod(i for i in range(n-1))",
r"\prod_{i = {0}}^{{n - {2}}} \mathopen{}\left({i}\mathclose{}\right)",
),
],
Expand Down Expand Up @@ -298,7 +298,7 @@ def test_visit_call_sum_prod_multiple_comprehension(code: str, latex: str) -> No
],
)
def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None:
    cases = [("sum", r"\sum"), ("prod", r"\prod")]
    for fn_name, latex_command in cases:
        call_node = ast.parse(fn_name + src_suffix).body[0].value
        assert isinstance(call_node, ast.Call)
        assert FunctionCodegen().visit(call_node) == latex_command + dest_suffix
Expand Down Expand Up @@ -707,7 +707,7 @@ def test_visit_constant(code: str, latex: str) -> None:
("x[0][1]", "{x_{{0}, {1}}}"),
("x[0][1][2]", "{x_{{0}, {1}, {2}}}"),
("x[foo]", "{x_{foo}}"),
("x[math.floor(x)]", r"{x_{\left\lfloor{x}\right\rfloor}}"),
("x[floor(x)]", r"{x_{\left\lfloor{x}\right\rfloor}}"),
],
)
def test_visit_subscript(code: str, latex: str) -> None:
Expand Down
3 changes: 2 additions & 1 deletion src/latexify/codegen/latex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from __future__ import annotations

from latexify.codegen.latex import Latex # Ignores [22-imports] for convenience.
# Ignores [22-imports] for convenience.
from latexify.codegen.latex import Latex


def test_eq() -> None:
Expand Down
4 changes: 4 additions & 0 deletions src/latexify/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class Config:
and corresponding values are the replacements.
Both keys and values have to represent valid Python identifiers:
^[A-Za-z_][A-Za-z0-9_]*$
prefixes: Prefixes of identifiers to trim. E.g., if "foo.bar" is in prefixes, all
identifiers of the form "foo.bar.suffix" will be replaced with "suffix".
reduce_assignments: If True, assignment statements are used to synthesize
the final expression.
use_math_symbols: Whether to convert identifiers with a math symbol surface
Expand All @@ -30,6 +32,7 @@ class Config:

expand_functions: set[str] | None
identifiers: dict[str, str] | None
prefixes: set[str] | None
reduce_assignments: bool
use_math_symbols: bool
use_raw_function_name: bool
Expand Down Expand Up @@ -70,6 +73,7 @@ def defaults() -> Config:
return Config(
expand_functions=None,
identifiers=None,
prefixes=None,
reduce_assignments=False,
use_math_symbols=False,
use_raw_function_name=False,
Expand Down
2 changes: 0 additions & 2 deletions src/latexify/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

import enum

PREFIXES = ["math", "numpy", "np"]


class BuiltinFnName(str, enum.Enum):
"""Built-in function name."""
Expand Down
10 changes: 10 additions & 0 deletions src/latexify/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
from latexify import config as cfg
from latexify import exceptions, parser, transformers

# NOTE(odashi):
# These prefixes are trimmed by default.
# This behavior shouldn't be controlled by users in the current implementation because
# some processes expect the absence of these prefixes.
_COMMON_PREFIXES = {"math", "numpy", "np"}


# TODO(odashi): move expand_functions to Config.
def get_latex(
Expand Down Expand Up @@ -39,6 +45,10 @@ def get_latex(
tree = parser.parse_function(fn)

# Applies AST transformations.

prefixes = _COMMON_PREFIXES | (merged_config.prefixes or set())
tree = transformers.PrefixTrimmer(prefixes).visit(tree)

if merged_config.identifiers is not None:
tree = transformers.IdentifierReplacer(merged_config.identifiers).visit(tree)
if merged_config.reduce_assignments:
Expand Down
21 changes: 21 additions & 0 deletions src/latexify/frontend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,27 @@ def myfn(myvar):
assert frontend.get_latex(myfn, identifiers=identifiers) == latex_with_flag


def test_get_latex_prefixes() -> None:
    math = numpy = np = abc = object()

    def f(x):
        return math.foo + numpy.bar + np.baz + abc.qux + x.y.z.quux

    # Common prefixes (math, numpy, np) are always trimmed; the rest depend
    # on the prefixes argument.
    default_latex = r"\mathrm{f}(x) = foo + bar + baz + abc.qux + x.y.z.quux"
    latex_abc = r"\mathrm{f}(x) = foo + bar + baz + qux + x.y.z.quux"
    latex_x = r"\mathrm{f}(x) = foo + bar + baz + abc.qux + y.z.quux"
    latex_xy = r"\mathrm{f}(x) = foo + bar + baz + abc.qux + z.quux"
    latex_all = r"\mathrm{f}(x) = foo + bar + baz + qux + quux"

    assert frontend.get_latex(f) == default_latex
    assert frontend.get_latex(f, prefixes=set()) == default_latex
    assert frontend.get_latex(f, prefixes={"abc"}) == latex_abc
    assert frontend.get_latex(f, prefixes={"x"}) == latex_x
    assert frontend.get_latex(f, prefixes={"x.y"}) == latex_xy
    assert frontend.get_latex(f, prefixes={"abc", "x.y.z"}) == latex_all
    assert frontend.get_latex(f, prefixes={"abc", "x", "x.y.z"}) == latex_all


def test_get_latex_reduce_assignments() -> None:
def f(x):
y = 3 * x
Expand Down
18 changes: 10 additions & 8 deletions src/latexify/transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Package latexify.transformers."""

from latexify.transformers import (
assignment_reducer,
function_expander,
identifier_replacer,
)
from latexify.transformers.assignment_reducer import AssignmentReducer
from latexify.transformers.function_expander import FunctionExpander
from latexify.transformers.identifier_replacer import IdentifierReplacer
from latexify.transformers.prefix_trimmer import PrefixTrimmer

AssignmentReducer = assignment_reducer.AssignmentReducer
FunctionExpander = function_expander.FunctionExpander
IdentifierReplacer = identifier_replacer.IdentifierReplacer
# __all__ must contain the *names* (strings) of the public symbols, not the
# objects themselves: `from latexify.transformers import *` raises TypeError
# on non-string entries.
__all__ = [
    "AssignmentReducer",
    "FunctionExpander",
    "IdentifierReplacer",
    "PrefixTrimmer",
]
97 changes: 97 additions & 0 deletions src/latexify/transformers/prefix_trimmer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""NodeTransformer to trim unnecessary prefixes."""

from __future__ import annotations

import ast
import re

from latexify import ast_utils

# Prefix values must be one or more Python identifiers joined by periods.
_PREFIX_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*$")


class PrefixTrimmer(ast.NodeTransformer):
    """NodeTransformer to trim unnecessary prefixes.

    This class investigates all Attribute subtrees, and replaces them if the prefix of
    the attribute matches the given set of prefixes.
    Prefix is searched in the manner of leftmost longest matching.

    Example:
        def f(x):
            return math.sqrt(x)

        PrefixTrimmer({"math"}) will modify the AST of the function above to below:

        def f(x):
            return sqrt(x)
    """

    # Each prefix stored as a tuple of identifiers for component-wise matching.
    _prefixes: list[tuple[str, ...]]

    def __init__(self, prefixes: set[str]) -> None:
        """Initializer.

        Args:
            prefixes: Set of prefixes to be trimmed. Nested prefix is allowed too.
                Each value must follow one of the following formats:
                - A Python identifier, e.g., "math"
                - Python identifiers joined by periods, e.g., "numpy.random"

        Raises:
            ValueError: Some value in prefixes has an invalid format.
        """
        for p in prefixes:
            if not _PREFIX_PATTERN.match(p):
                raise ValueError(f"Invalid prefix: {p}")

        self._prefixes = [tuple(p.split(".")) for p in prefixes]

    def _get_prefix(self, node: ast.expr) -> tuple[str, ...] | None:
        """Helper to obtain nested prefix.

        Args:
            node: Node to investigate.

        Returns:
            The prefix tuple, or None if the node has unsupported syntax
            (anything other than a pure Name/Attribute chain).
        """
        if isinstance(node, ast.Name):
            return (node.id,)

        if isinstance(node, ast.Attribute):
            parent = self._get_prefix(node.value)
            return parent + (node.attr,) if parent is not None else None

        return None

    def _make_attribute(self, prefix: tuple[str, ...], name: str) -> ast.expr:
        """Helper to generate a new Attribute or Name node.

        Args:
            prefix: List of prefixes.
            name: Attribute name.

        Returns:
            Name node if prefix == (), (possibly nested) Attribute node otherwise.
        """
        if not prefix:
            return ast_utils.make_name(name)

        parent = self._make_attribute(prefix[:-1], prefix[-1])
        return ast_utils.make_attribute(parent, name)

    def visit_Attribute(self, node: ast.Attribute) -> ast.expr:
        """Visitor of Attribute.

        Args:
            node: Node to visit.

        Returns:
            Transformed node with the leftmost longest matching prefix removed.
        """
        prefix = self._get_prefix(node.value)
        if prefix is None:
            # The value is not a pure identifier chain (e.g., a Call or
            # Subscript). Visit the children anyway so attribute chains nested
            # inside the value (e.g., `f(math.pi).attr`) are still trimmed.
            return self.generic_visit(node)

        # Performs leftmost longest match.
        # NOTE(odashi):
        # This implementation is very naive, but would work efficiently as long as the
        # number of patterns is small.
        matched_length = 0

        for p in self._prefixes:
            length = min(len(p), len(prefix))
            if prefix[:length] == p and length > matched_length:
                matched_length = length

        return self._make_attribute(prefix[matched_length:], node.attr)
Loading