Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable mypy #153

Merged
merged 8 commits into from
Dec 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,17 @@ jobs:
python -m pip install isort
- name: Check
run: python -m isort --check src
mypy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.x"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install '.[mypy]'
- name: Check
run: python -m mypy src
1 change: 1 addition & 0 deletions checks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ python -m pytest src -vv
python -m black --check src
python -m pflake8 src
python -m isort --check src
python -m mypy src
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,16 @@ dev = [
"build>=0.8",
"black>=22.10",
"flake8>=5.0",
"isort>=5.10",
"mypy>=0.991",
"notebook>=6.5.1",
"pyproject-flake8>=5.0",
"pytest>=7.1",
"twine>=4.0",
"isort>=5.10",
]
mypy = [
"mypy>=0.991",
"pytest>=7.1",
]

[project.urls]
Expand Down
3 changes: 0 additions & 3 deletions src/latexify/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,3 @@

function = frontend.function
expression = frontend.expression

# Deprecated
with_latex = frontend.with_latex
17 changes: 0 additions & 17 deletions src/latexify/ast_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,6 @@ def test_extract_int_or_none() -> None:


def test_extract_int_or_none_invalid() -> None:
# Not a subtree.
assert ast_utils.extract_int_or_none(123) is None

# Not a direct Constant node.
assert (
ast_utils.extract_int_or_none(ast.Expr(value=ast_utils.make_constant(123)))
is None
)

# Not a Constant node with int.
assert ast_utils.extract_int_or_none(ast_utils.make_constant(None)) is None
assert ast_utils.extract_int_or_none(ast_utils.make_constant(True)) is None
Expand All @@ -147,14 +138,6 @@ def test_extract_int() -> None:


def test_extract_int_invalid() -> None:
# Not a subtree.
with pytest.raises(ValueError, match=r"^Unsupported node to extract int"):
ast_utils.extract_int(123)

# Not a direct Constant node.
with pytest.raises(ValueError, match=r"^Unsupported node to extract int"):
ast_utils.extract_int(ast.Expr(value=ast_utils.make_constant(123)))

# Not a Constant node with int.
with pytest.raises(ValueError, match=r"^Unsupported node to extract int"):
ast_utils.extract_int(ast_utils.make_constant(None))
Expand Down
62 changes: 38 additions & 24 deletions src/latexify/codegen/function_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,11 @@ def visit_Assign(self, node: ast.Assign) -> str:
return " = ".join(operands)

def visit_Return(self, node: ast.Return) -> str:
return self.visit(node.value)
return (
self.visit(node.value)
if node.value is not None
else self._convert_constant(None)
)

def visit_Tuple(self, node: ast.Tuple) -> str:
elts = [self.visit(i) for i in node.elts]
Expand Down Expand Up @@ -401,15 +405,16 @@ def generate_matrix_from_array(data: list[list[str]]) -> str:

ncols = len(row0.elts)

if not all(
isinstance(row, ast.List) and len(row.elts) == ncols for row in arg.elts
):
# Length mismatch
return None
rows: list[list[str]] = []

return generate_matrix_from_array(
[[self.visit(x) for x in row.elts] for row in arg.elts]
)
for row in arg.elts:
if not isinstance(row, ast.List) or len(row.elts) != ncols:
# Length mismatch
return None

rows.append([self.visit(x) for x in row.elts])

return generate_matrix_from_array(rows)

def visit_Call(self, node: ast.Call) -> str:
"""Visit a call node."""
Expand All @@ -427,7 +432,8 @@ def visit_Call(self, node: ast.Call) -> str:
return special_latex

# Obtains the codegen rule.
rule = constants.BUILTIN_FUNCS.get(func_name)
rule = constants.BUILTIN_FUNCS.get(func_name) if func_name is not None else None

if rule is None:
rule = constants.FunctionRule(self.visit(node.func))

Expand Down Expand Up @@ -556,8 +562,11 @@ def _wrap_binop_operand(
return self.visit(child)

if isinstance(child, ast.Call):
rule = constants.BUILTIN_FUNCS.get(
ast_utils.extract_function_name_or_none(child)
child_fn_name = ast_utils.extract_function_name_or_none(child)
rule = (
constants.BUILTIN_FUNCS.get(child_fn_name)
if child_fn_name is not None
else None
)
if rule is not None and rule.is_wrapped:
return self.visit(child)
Expand Down Expand Up @@ -612,30 +621,35 @@ def visit_If(self, node: ast.If) -> str:
"""Visit an if node."""
latex = r"\left\{ \begin{array}{ll} "

while isinstance(node, ast.If):
if len(node.body) != 1 or len(node.orelse) != 1:
current_stmt: ast.stmt = node

while isinstance(current_stmt, ast.If):
if len(current_stmt.body) != 1 or len(current_stmt.orelse) != 1:
raise exceptions.LatexifySyntaxError(
"Multiple statements are not supported in If nodes."
)

cond_latex = self.visit(node.test)
true_latex = self.visit(node.body[0])
cond_latex = self.visit(current_stmt.test)
true_latex = self.visit(current_stmt.body[0])
latex += true_latex + r", & \mathrm{if} \ " + cond_latex + r" \\ "
node = node.orelse[0]
current_stmt = current_stmt.orelse[0]

latex += self.visit(node)
latex += self.visit(current_stmt)
return latex + r", & \mathrm{otherwise} \end{array} \right."

def visit_IfExp(self, node: ast.IfExp) -> str:
"""Visit an ifexp node"""
latex = r"\left\{ \begin{array}{ll} "
while isinstance(node, ast.IfExp):
cond_latex = self.visit(node.test)
true_latex = self.visit(node.body)

current_expr: ast.expr = node

while isinstance(current_expr, ast.IfExp):
cond_latex = self.visit(current_expr.test)
true_latex = self.visit(current_expr.body)
latex += true_latex + r", & \mathrm{if} \ " + cond_latex + r" \\ "
node = node.orelse
current_expr = current_expr.orelse

latex += self.visit(node)
latex += self.visit(current_expr)
return latex + r", & \mathrm{otherwise} \end{array} \right."

def _reduce_stop_parameter(self, node: ast.expr) -> ast.expr:
Expand Down Expand Up @@ -768,7 +782,7 @@ def _get_sum_prod_info(
# Until 3.8
def visit_Index(self, node: ast.Index) -> str:
"""Visitor for the Index nodes."""
return self.visit(node.value)
return self.visit(node.value) # type: ignore[attr-defined]

def _convert_nested_subscripts(self, node: ast.Subscript) -> tuple[str, list[str]]:
"""Helper function to convert nested subscription.
Expand Down
2 changes: 1 addition & 1 deletion src/latexify/codegen/latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, raw: str) -> None:
"""
self._raw = raw

def __eq__(self, other: object) -> None:
def __eq__(self, other: object) -> bool:
"""Checks equality.

Args:
Expand Down
108 changes: 33 additions & 75 deletions src/latexify/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,6 @@
from __future__ import annotations

import dataclasses
import enum


class BuiltinFnName(str, enum.Enum):
"""Built-in function name."""

ABS = "abs"
ACOS = "acos"
ACOSH = "acosh"
ARCCOS = "arccos"
ARCCOSH = "arcosh"
ARCSIN = "arcsin"
ARCSINH = "arcsihn"
ARCTAN = "arctan"
ARCTANH = "arctanh"
ASIN = "asin"
ASINH = "asinh"
ATAN = "atan"
ATANH = "atanh"
CEIL = "ceil"
COS = "cos"
COSH = "cosh"
EXP = "exp"
FABS = "fabs"
FACTORIAL = "factorial"
FLOOR = "floor"
FSUM = "fsum"
GAMMA = "gamma"
LOG = "log"
LOG10 = "log10"
LOG2 = "log2"
PROD = "prod"
SIN = "sin"
SINH = "sinh"
SQRT = "sqrt"
TAN = "tan"
TANH = "tanh"
SUM = "sum"


@dataclasses.dataclass(frozen=True)
Expand All @@ -61,45 +23,41 @@ class FunctionRule:


# name => left_syntax, right_syntax, is_wrapped
BUILTIN_FUNCS: dict[BuiltinFnName, FunctionRule] = {
BuiltinFnName.ABS: FunctionRule(
r"\mathropen{}\left|", r"\mathclose{}\right|", is_wrapped=True
),
BuiltinFnName.ACOS: FunctionRule(r"\arccos", is_unary=True),
BuiltinFnName.ACOSH: FunctionRule(r"\mathrm{arccosh}", is_unary=True),
BuiltinFnName.ARCCOS: FunctionRule(r"\arccos", is_unary=True),
BuiltinFnName.ARCCOSH: FunctionRule(r"\mathrm{arccosh}", is_unary=True),
BuiltinFnName.ARCSIN: FunctionRule(r"\arcsin", is_unary=True),
BuiltinFnName.ARCSINH: FunctionRule(r"\mathrm{arcsinh}", is_unary=True),
BuiltinFnName.ARCTAN: FunctionRule(r"\arctan", is_unary=True),
BuiltinFnName.ARCTANH: FunctionRule(r"\mathrm{arctanh}", is_unary=True),
BuiltinFnName.ASIN: FunctionRule(r"\arcsin", is_unary=True),
BuiltinFnName.ASINH: FunctionRule(r"\mathrm{arcsinh}", is_unary=True),
BuiltinFnName.ATAN: FunctionRule(r"\arctan", is_unary=True),
BuiltinFnName.ATANH: FunctionRule(r"\mathrm{arctanh}", is_unary=True),
BuiltinFnName.CEIL: FunctionRule(
BUILTIN_FUNCS: dict[str, FunctionRule] = {
"abs": FunctionRule(r"\mathropen{}\left|", r"\mathclose{}\right|", is_wrapped=True),
"acos": FunctionRule(r"\arccos", is_unary=True),
"acosh": FunctionRule(r"\mathrm{arccosh}", is_unary=True),
"arccos": FunctionRule(r"\arccos", is_unary=True),
"arccosh": FunctionRule(r"\mathrm{arccosh}", is_unary=True),
"arcsin": FunctionRule(r"\arcsin", is_unary=True),
"arcsinh": FunctionRule(r"\mathrm{arcsinh}", is_unary=True),
"arctan": FunctionRule(r"\arctan", is_unary=True),
"arctanh": FunctionRule(r"\mathrm{arctanh}", is_unary=True),
"asin": FunctionRule(r"\arcsin", is_unary=True),
"asinh": FunctionRule(r"\mathrm{arcsinh}", is_unary=True),
"atan": FunctionRule(r"\arctan", is_unary=True),
"atanh": FunctionRule(r"\mathrm{arctanh}", is_unary=True),
"ceil": FunctionRule(
r"\mathopen{}\left\lceil", r"\mathclose{}\right\rceil", is_wrapped=True
),
BuiltinFnName.COS: FunctionRule(r"\cos", is_unary=True),
BuiltinFnName.COSH: FunctionRule(r"\cosh", is_unary=True),
BuiltinFnName.EXP: FunctionRule(r"\exp", is_unary=True),
BuiltinFnName.FABS: FunctionRule(
r"\mathopen{}\left|", r"\mathclose{}\right|", is_wrapped=True
),
BuiltinFnName.FACTORIAL: FunctionRule("", "!", is_unary=True),
BuiltinFnName.FLOOR: FunctionRule(
"cos": FunctionRule(r"\cos", is_unary=True),
"cosh": FunctionRule(r"\cosh", is_unary=True),
"exp": FunctionRule(r"\exp", is_unary=True),
"fabs": FunctionRule(r"\mathopen{}\left|", r"\mathclose{}\right|", is_wrapped=True),
"factorial": FunctionRule("", "!", is_unary=True),
"floor": FunctionRule(
r"\mathopen{}\left\lfloor", r"\mathclose{}\right\rfloor", is_wrapped=True
),
BuiltinFnName.FSUM: FunctionRule(r"\sum", is_unary=True),
BuiltinFnName.GAMMA: FunctionRule(r"\Gamma"),
BuiltinFnName.LOG: FunctionRule(r"\log", is_unary=True),
BuiltinFnName.LOG10: FunctionRule(r"\log_10", is_unary=True),
BuiltinFnName.LOG2: FunctionRule(r"\log_2", is_unary=True),
BuiltinFnName.PROD: FunctionRule(r"\prod", is_unary=True),
BuiltinFnName.SIN: FunctionRule(r"\sin", is_unary=True),
BuiltinFnName.SINH: FunctionRule(r"\sinh", is_unary=True),
BuiltinFnName.SQRT: FunctionRule(r"\sqrt{", "}", is_wrapped=True),
BuiltinFnName.SUM: FunctionRule(r"\sum", is_unary=True),
BuiltinFnName.TAN: FunctionRule(r"\tan", is_unary=True),
BuiltinFnName.TANH: FunctionRule(r"\tanh", is_unary=True),
"fsum": FunctionRule(r"\sum", is_unary=True),
"gamma": FunctionRule(r"\Gamma"),
"log": FunctionRule(r"\log", is_unary=True),
"log10": FunctionRule(r"\log_10", is_unary=True),
"log2": FunctionRule(r"\log_2", is_unary=True),
"prod": FunctionRule(r"\prod", is_unary=True),
"sin": FunctionRule(r"\sin", is_unary=True),
"sinh": FunctionRule(r"\sinh", is_unary=True),
"sqrt": FunctionRule(r"\sqrt{", "}", is_wrapped=True),
"sum": FunctionRule(r"\sum", is_unary=True),
"tan": FunctionRule(r"\tan", is_unary=True),
"tanh": FunctionRule(r"\tanh", is_unary=True),
}
Loading