Skip to content

Commit 693d825

Browse files
author
Yusuke Oda
authoredDec 7, 2022
Enable mypy (#153)
* add mypy * revert BuiltinFnName * fix * fix * fix * fix
1 parent 02116c1 commit 693d825

File tree

11 files changed

+123
-93
lines changed

11 files changed

+123
-93
lines changed
 

‎.github/workflows/ci.yml

+14
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,17 @@ jobs:
6969
python -m pip install isort
7070
- name: Check
7171
run: python -m isort --check src
72+
mypy:
73+
runs-on: ubuntu-latest
74+
steps:
75+
- uses: actions/checkout@v3
76+
- name: Set up Python
77+
uses: actions/setup-python@v4
78+
with:
79+
python-version: "3.x"
80+
- name: Install dependencies
81+
run: |
82+
python -m pip install --upgrade pip
83+
python -m pip install '.[mypy]'
84+
- name: Check
85+
run: python -m mypy src

‎checks.sh

+1
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ python -m pytest src -vv
55
python -m black --check src
66
python -m pflake8 src
77
python -m isort --check src
8+
python -m mypy src

‎pyproject.toml

+6-1
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,16 @@ dev = [
4545
"build>=0.8",
4646
"black>=22.10",
4747
"flake8>=5.0",
48+
"isort>=5.10",
49+
"mypy>=0.991",
4850
"notebook>=6.5.1",
4951
"pyproject-flake8>=5.0",
5052
"pytest>=7.1",
5153
"twine>=4.0",
52-
"isort>=5.10",
54+
]
55+
mypy = [
56+
"mypy>=0.991",
57+
"pytest>=7.1",
5358
]
5459

5560
[project.urls]

‎src/latexify/__init__.py

-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,3 @@
1313

1414
function = frontend.function
1515
expression = frontend.expression
16-
17-
# Deprecated
18-
with_latex = frontend.with_latex

‎src/latexify/ast_utils_test.py

-17
Original file line numberDiff line numberDiff line change
@@ -121,15 +121,6 @@ def test_extract_int_or_none() -> None:
121121

122122

123123
def test_extract_int_or_none_invalid() -> None:
124-
# Not a subtree.
125-
assert ast_utils.extract_int_or_none(123) is None
126-
127-
# Not a direct Constant node.
128-
assert (
129-
ast_utils.extract_int_or_none(ast.Expr(value=ast_utils.make_constant(123)))
130-
is None
131-
)
132-
133124
# Not a Constant node with int.
134125
assert ast_utils.extract_int_or_none(ast_utils.make_constant(None)) is None
135126
assert ast_utils.extract_int_or_none(ast_utils.make_constant(True)) is None
@@ -147,14 +138,6 @@ def test_extract_int() -> None:
147138

148139

149140
def test_extract_int_invalid() -> None:
150-
# Not a subtree.
151-
with pytest.raises(ValueError, match=r"^Unsupported node to extract int"):
152-
ast_utils.extract_int(123)
153-
154-
# Not a direct Constant node.
155-
with pytest.raises(ValueError, match=r"^Unsupported node to extract int"):
156-
ast_utils.extract_int(ast.Expr(value=ast_utils.make_constant(123)))
157-
158141
# Not a Constant node with int.
159142
with pytest.raises(ValueError, match=r"^Unsupported node to extract int"):
160143
ast_utils.extract_int(ast_utils.make_constant(None))

‎src/latexify/codegen/function_codegen.py

+38-24
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,11 @@ def visit_Assign(self, node: ast.Assign) -> str:
296296
return " = ".join(operands)
297297

298298
def visit_Return(self, node: ast.Return) -> str:
299-
return self.visit(node.value)
299+
return (
300+
self.visit(node.value)
301+
if node.value is not None
302+
else self._convert_constant(None)
303+
)
300304

301305
def visit_Tuple(self, node: ast.Tuple) -> str:
302306
elts = [self.visit(i) for i in node.elts]
@@ -401,15 +405,16 @@ def generate_matrix_from_array(data: list[list[str]]) -> str:
401405

402406
ncols = len(row0.elts)
403407

404-
if not all(
405-
isinstance(row, ast.List) and len(row.elts) == ncols for row in arg.elts
406-
):
407-
# Length mismatch
408-
return None
408+
rows: list[list[str]] = []
409409

410-
return generate_matrix_from_array(
411-
[[self.visit(x) for x in row.elts] for row in arg.elts]
412-
)
410+
for row in arg.elts:
411+
if not isinstance(row, ast.List) or len(row.elts) != ncols:
412+
# Length mismatch
413+
return None
414+
415+
rows.append([self.visit(x) for x in row.elts])
416+
417+
return generate_matrix_from_array(rows)
413418

414419
def visit_Call(self, node: ast.Call) -> str:
415420
"""Visit a call node."""
@@ -427,7 +432,8 @@ def visit_Call(self, node: ast.Call) -> str:
427432
return special_latex
428433

429434
# Obtains the codegen rule.
430-
rule = constants.BUILTIN_FUNCS.get(func_name)
435+
rule = constants.BUILTIN_FUNCS.get(func_name) if func_name is not None else None
436+
431437
if rule is None:
432438
rule = constants.FunctionRule(self.visit(node.func))
433439

@@ -556,8 +562,11 @@ def _wrap_binop_operand(
556562
return self.visit(child)
557563

558564
if isinstance(child, ast.Call):
559-
rule = constants.BUILTIN_FUNCS.get(
560-
ast_utils.extract_function_name_or_none(child)
565+
child_fn_name = ast_utils.extract_function_name_or_none(child)
566+
rule = (
567+
constants.BUILTIN_FUNCS.get(child_fn_name)
568+
if child_fn_name is not None
569+
else None
561570
)
562571
if rule is not None and rule.is_wrapped:
563572
return self.visit(child)
@@ -612,30 +621,35 @@ def visit_If(self, node: ast.If) -> str:
612621
"""Visit an if node."""
613622
latex = r"\left\{ \begin{array}{ll} "
614623

615-
while isinstance(node, ast.If):
616-
if len(node.body) != 1 or len(node.orelse) != 1:
624+
current_stmt: ast.stmt = node
625+
626+
while isinstance(current_stmt, ast.If):
627+
if len(current_stmt.body) != 1 or len(current_stmt.orelse) != 1:
617628
raise exceptions.LatexifySyntaxError(
618629
"Multiple statements are not supported in If nodes."
619630
)
620631

621-
cond_latex = self.visit(node.test)
622-
true_latex = self.visit(node.body[0])
632+
cond_latex = self.visit(current_stmt.test)
633+
true_latex = self.visit(current_stmt.body[0])
623634
latex += true_latex + r", & \mathrm{if} \ " + cond_latex + r" \\ "
624-
node = node.orelse[0]
635+
current_stmt = current_stmt.orelse[0]
625636

626-
latex += self.visit(node)
637+
latex += self.visit(current_stmt)
627638
return latex + r", & \mathrm{otherwise} \end{array} \right."
628639

629640
def visit_IfExp(self, node: ast.IfExp) -> str:
630641
"""Visit an ifexp node"""
631642
latex = r"\left\{ \begin{array}{ll} "
632-
while isinstance(node, ast.IfExp):
633-
cond_latex = self.visit(node.test)
634-
true_latex = self.visit(node.body)
643+
644+
current_expr: ast.expr = node
645+
646+
while isinstance(current_expr, ast.IfExp):
647+
cond_latex = self.visit(current_expr.test)
648+
true_latex = self.visit(current_expr.body)
635649
latex += true_latex + r", & \mathrm{if} \ " + cond_latex + r" \\ "
636-
node = node.orelse
650+
current_expr = current_expr.orelse
637651

638-
latex += self.visit(node)
652+
latex += self.visit(current_expr)
639653
return latex + r", & \mathrm{otherwise} \end{array} \right."
640654

641655
def _reduce_stop_parameter(self, node: ast.expr) -> ast.expr:
@@ -768,7 +782,7 @@ def _get_sum_prod_info(
768782
# Until 3.8
769783
def visit_Index(self, node: ast.Index) -> str:
770784
"""Visitor for the Index nodes."""
771-
return self.visit(node.value)
785+
return self.visit(node.value) # type: ignore[attr-defined]
772786

773787
def _convert_nested_subscripts(self, node: ast.Subscript) -> tuple[str, list[str]]:
774788
"""Helper function to convert nested subscription.

‎src/latexify/codegen/latex.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __init__(self, raw: str) -> None:
2121
"""
2222
self._raw = raw
2323

24-
def __eq__(self, other: object) -> None:
24+
def __eq__(self, other: object) -> bool:
2525
"""Checks equality.
2626
2727
Args:

‎src/latexify/frontend.py

+41-24
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22

33
from __future__ import annotations
44

5-
import warnings
65
from collections.abc import Callable
7-
from typing import Any
6+
from typing import Any, overload
87

98
from latexify import codegen
109
from latexify import config as cfg
@@ -119,45 +118,63 @@ def _repr_latex_(self):
119118
)
120119

121120

122-
def function(*args, **kwargs) -> Callable[[Callable[..., Any]], LatexifiedFunction]:
123-
"""Translate a function into a corresponding LaTeX representation.
121+
@overload
122+
def function(fn: Callable[..., Any], **kwargs: Any) -> LatexifiedFunction:
123+
...
124+
125+
126+
@overload
127+
def function(**kwargs: Any) -> Callable[[Callable[..., Any]], LatexifiedFunction]:
128+
...
129+
130+
131+
def function(
132+
fn: Callable[..., Any] | None = None, **kwargs: Any
133+
) -> LatexifiedFunction | Callable[[Callable[..., Any]], LatexifiedFunction]:
134+
"""Attach LaTeX pretty-printing to the given function.
124135
125136
This function works with or without specifying the target function as the positional
126137
argument. The following two syntaxes works similarly.
127-
- with_latex(fn, **kwargs)
128-
- with_latex(**kwargs)(fn)
138+
- latexify.function(fn, **kwargs)
139+
- latexify.function(**kwargs)(fn)
129140
130141
Args:
131-
*args: No argument, or a callable.
142+
fn: Callable to be wrapped.
132143
**kwargs: Arguments to control behavior. See also get_latex().
133144
134145
Returns:
135-
- If the target function is passed directly, returns the wrapped function.
146+
- If `fn` is passed, returns the wrapped function.
136147
- Otherwise, returns the wrapper function with given settings.
137148
"""
138-
if len(args) == 1 and isinstance(args[0], Callable):
139-
return LatexifiedFunction(args[0], **kwargs)
140-
141-
def wrapper(fn):
149+
if fn is not None:
142150
return LatexifiedFunction(fn, **kwargs)
143151

152+
def wrapper(f):
153+
return LatexifiedFunction(f, **kwargs)
154+
144155
return wrapper
145156

146157

147-
def expression(*args, **kwargs) -> Callable[[Callable[..., Any]], LatexifiedFunction]:
148-
"""Translate a function into a LaTeX representation without the signature.
158+
@overload
159+
def expression(fn: Callable[..., Any], **kwargs: Any) -> LatexifiedFunction:
160+
...
161+
162+
163+
@overload
164+
def expression(**kwargs: Any) -> Callable[[Callable[..., Any]], LatexifiedFunction]:
165+
...
166+
167+
168+
def expression(
169+
fn: Callable[..., Any] | None = None, **kwargs: Any
170+
) -> LatexifiedFunction | Callable[[Callable[..., Any]], LatexifiedFunction]:
171+
"""Attach LaTeX pretty-printing to the given function.
149172
150173
This function is a shortcut for `latexify.function` with the default parameter
151174
`use_signature=False`.
152175
"""
153176
kwargs["use_signature"] = kwargs.get("use_signature", False)
154-
return function(*args, **kwargs)
155-
156-
157-
def with_latex(*args, **kwargs) -> Callable[[Callable[..., Any]], LatexifiedFunction]:
158-
"""Deprecated. use `latexify.function` instead."""
159-
warnings.warn(
160-
"`latexify.with_latex` is deprecated. Use `latexify.function` instead.",
161-
DeprecationWarning,
162-
)
163-
return function(*args, **kwargs)
177+
if fn is not None:
178+
return function(fn, **kwargs)
179+
else:
180+
return function(**kwargs)

‎src/latexify/parser.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
from collections.abc import Callable
99
from typing import Any
1010

11-
import dill
11+
import dill # type: ignore[import]
1212

1313
from latexify import exceptions
1414

1515

16-
def parse_function(fn: Callable[..., Any]) -> ast.FunctionDef:
16+
def parse_function(fn: Callable[..., Any]) -> ast.Module:
1717
"""Parses given function.
1818
1919
Args:

‎src/latexify/transformers/__init__.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from latexify.transformers.prefix_trimmer import PrefixTrimmer
77

88
__all__ = [
9-
AssignmentReducer,
10-
FunctionExpander,
11-
IdentifierReplacer,
12-
PrefixTrimmer,
9+
"AssignmentReducer",
10+
"FunctionExpander",
11+
"IdentifierReplacer",
12+
"PrefixTrimmer",
1313
]

‎src/latexify/transformers/identifier_replacer.py

+16-17
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import ast
66
import re
77
import sys
8-
from typing import ClassVar
8+
from typing import ClassVar, cast
99

1010

1111
class IdentifierReplacer(ast.NodeTransformer):
@@ -47,33 +47,32 @@ def _replace_args(self, args: list[ast.arg]) -> list[ast.arg]:
4747
"""Helper function to replace arg names."""
4848
return [ast.arg(arg=self._mapping.get(a.arg, a.arg)) for a in args]
4949

50-
def _visit_children(self, children: list[ast.AST]) -> list[ast.AST]:
51-
"""Helper function to visit all children."""
52-
return [self.visit(child) for child in children]
53-
5450
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
5551
"""Visitor of FunctionDef."""
52+
53+
visited = cast(ast.FunctionDef, super().generic_visit(node))
54+
5655
if sys.version_info.minor < 8:
5756
args = ast.arguments(
58-
args=self._replace_args(node.args.args),
59-
kwonlyargs=self._replace_args(node.args.kwonlyargs),
60-
kw_defaults=self._visit_children(node.args.kw_defaults),
61-
defaults=self._visit_children(node.args.defaults),
57+
args=self._replace_args(visited.args.args),
58+
kwonlyargs=self._replace_args(visited.args.kwonlyargs),
59+
kw_defaults=visited.args.kw_defaults,
60+
defaults=visited.args.defaults,
6261
)
6362
else:
6463
args = ast.arguments(
65-
posonlyargs=self._replace_args(node.args.posonlyargs), # from 3.8
66-
args=self._replace_args(node.args.args),
67-
kwonlyargs=self._replace_args(node.args.kwonlyargs),
68-
kw_defaults=self._visit_children(node.args.kw_defaults),
69-
defaults=self._visit_children(node.args.defaults),
64+
posonlyargs=self._replace_args(visited.args.posonlyargs), # from 3.8
65+
args=self._replace_args(visited.args.args),
66+
kwonlyargs=self._replace_args(visited.args.kwonlyargs),
67+
kw_defaults=visited.args.kw_defaults,
68+
defaults=visited.args.defaults,
7069
)
7170

7271
return ast.FunctionDef(
73-
name=self._mapping.get(node.name, node.name),
72+
name=self._mapping.get(visited.name, visited.name),
7473
args=args,
75-
body=self._visit_children(node.body),
76-
decorator_list=self._visit_children(node.decorator_list),
74+
body=visited.body,
75+
decorator_list=visited.decorator_list,
7776
)
7877

7978
def visit_Name(self, node: ast.Name) -> ast.Name:

0 commit comments

Comments
 (0)
Please sign in to comment.