Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Factor out expression codegen from function codegen #155

Merged
merged 18 commits into from
Dec 10, 2022
41 changes: 41 additions & 0 deletions src/latexify/analyzers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import ast
import dataclasses
import sys

from latexify import ast_utils, exceptions

Expand Down Expand Up @@ -62,3 +63,43 @@ def analyze_range(node: ast.Call) -> RangeInfo:
stop_int=ast_utils.extract_int_or_none(stop),
step_int=ast_utils.extract_int_or_none(step),
)


def reduce_stop_parameter(node: ast.expr) -> ast.expr:
"""Adjusts the stop expression of the range.

This function tries to convert the syntax as follows:
* n + 1 --> n
* n + 2 --> n + 1
* n - 1 --> n - 2

Args:
node: The target expression.

Returns:
Converted expression.
"""
if not (isinstance(node, ast.BinOp) and isinstance(node.op, (ast.Add, ast.Sub))):
return ast.BinOp(left=node, op=ast.Sub(), right=ast_utils.make_constant(1))

# Treatment for Python 3.7.
rhs = (
ast.Constant(value=node.right.n)
if sys.version_info.minor < 8 and isinstance(node.right, ast.Num)
else node.right
)

if not isinstance(rhs, ast.Constant):
return ast.BinOp(left=node, op=ast.Sub(), right=ast_utils.make_constant(1))

shift = 1 if isinstance(node.op, ast.Add) else -1

return (
node.left
if rhs.value == shift
else ast.BinOp(
left=node.left,
op=node.op,
right=ast_utils.make_constant(value=rhs.value - shift),
)
)
17 changes: 17 additions & 0 deletions src/latexify/analyzers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,20 @@ def test_analyze_range_invalid(code: str) -> None:
exceptions.LatexifySyntaxError, match=r"^Unsupported AST for analyze_range\.$"
):
analyzers.analyze_range(node)


@pytest.mark.parametrize(
"before,after",
[
("n + 1", "n"),
("n + 2", "n + 1"),
("n - (-1)", "n - (-1) - 1"),
("n - 1", "n - 2"),
("1 * 2", "1 * 2 - 1"),
],
)
def test_reduce_stop_parameter(before: str, after: str) -> None:
test_utils.assert_ast_equal(
analyzers.reduce_stop_parameter(ast_utils.parse_expr(before)),
ast_utils.parse_expr(after),
)
3 changes: 2 additions & 1 deletion src/latexify/codegen/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Package latexify.codegen."""

from latexify.codegen import function_codegen
from latexify.codegen import expression_codegen, function_codegen

ExpressionCodegen = expression_codegen.ExpressionCodegen
FunctionCodegen = function_codegen.FunctionCodegen
28 changes: 28 additions & 0 deletions src/latexify/codegen/codegen_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any

from latexify import exceptions


def convert_constant(value: Any) -> str:
"""Helper to convert constant values to LaTeX.

Args:
value: A constant value.

Returns:
The LaTeX representation of `value`.
"""
if value is None or isinstance(value, bool):
return r"\mathrm{" + str(value) + "}"
if isinstance(value, (int, float, complex)):
# TODO(odashi): Support other symbols for the imaginary unit than j.
return str(value)
if isinstance(value, str):
return r'\textrm{"' + value + '"}'
if isinstance(value, bytes):
return r"\textrm{" + str(value) + "}"
if value is ...:
return r"\cdots"
raise exceptions.LatexifyNotSupportedError(
f"Unrecognized constant: {type(value).__name__}"
)
34 changes: 34 additions & 0 deletions src/latexify/codegen/codegen_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Tests for latexify.codegen.codegen_utils."""

from __future__ import annotations

from typing import Any

import pytest

from latexify import exceptions
from latexify.codegen.codegen_utils import convert_constant


@pytest.mark.parametrize(
"constant,latex",
[
(None, r"\mathrm{None}"),
(True, r"\mathrm{True}"),
(False, r"\mathrm{False}"),
(123, "123"),
(456.789, "456.789"),
(-3 + 4j, "(-3+4j)"),
("string", r'\textrm{"string"}'),
(..., r"\cdots"),
],
)
def test_convert_constant(constant: Any, latex: str) -> None:
assert convert_constant(constant) == latex


def test_convert_constant_unsupported_constant() -> None:
with pytest.raises(
exceptions.LatexifyNotSupportedError, match="^Unrecognized constant: "
):
convert_constant({})
Loading