Skip to content

Commit 4f0a208

Browse files
ZibingZhangZibing Zhang
and
Zibing Zhang
authored
Allow docstring in function definition (#126)
* Allow docstring * Suggestions * Remove obvious comment * use ast.parse * edited suggested format * Suggestions 2 * formatting Co-authored-by: Zibing Zhang <zizhang@hubspot.com>
1 parent 866cca5 commit 4f0a208

File tree

5 files changed

+118
-12
lines changed

5 files changed

+118
-12
lines changed

src/integration_tests/regression_test.py

+20
Original file line numberDiff line numberDiff line change
@@ -279,3 +279,23 @@ def solve(a, b):
279279
r"a - b \mathclose{}\right) - a b"
280280
)
281281
_check_function(solve, latex)
282+
283+
284+
def test_docstring_allowed() -> None:
285+
def solve(x):
286+
"""The identity function."""
287+
return x
288+
289+
latex = r"\mathrm{solve}(x) = x"
290+
_check_function(solve, latex)
291+
292+
293+
def test_multiple_constants_allowed() -> None:
294+
def solve(x):
295+
"""The identity function."""
296+
123
297+
True
298+
return x
299+
300+
latex = r"\mathrm{solve}(x) = x"
301+
_check_function(solve, latex)

src/latexify/ast_utils.py

+18
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,24 @@ def make_constant(value: Any) -> ast.expr:
4141
raise ValueError(f"Unsupported type to generate Constant: {type(value).__name__}")
4242

4343

44+
def is_constant(node: ast.AST) -> bool:
45+
"""Checks if the node is a constant.
46+
47+
Args:
48+
node: The node to examine.
49+
50+
Returns:
51+
True if the node is a constant, False otherwise.
52+
"""
53+
if sys.version_info.minor < 8:
54+
return isinstance(
55+
node,
56+
(ast.Bytes, ast.Constant, ast.Ellipsis, ast.NameConstant, ast.Num, ast.Str),
57+
)
58+
else:
59+
return isinstance(node, ast.Constant)
60+
61+
4462
def extract_int_or_none(node: ast.expr) -> int | None:
4563
"""Extracts int constant from the given Constant node.
4664

src/latexify/ast_utils_test.py

+31
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,37 @@ def test_make_constant_invalid() -> None:
5959
ast_utils.make_constant(object())
6060

6161

62+
@test_utils.require_at_most(7)
63+
@pytest.mark.parametrize(
64+
"value,expected",
65+
[
66+
(ast.Bytes(s=b"foo"), True),
67+
(ast.Constant("bar"), True),
68+
(ast.Ellipsis(), True),
69+
(ast.NameConstant(value=None), True),
70+
(ast.Num(n=123), True),
71+
(ast.Str(s="baz"), True),
72+
(ast.Expr(value=ast.Num(456)), False),
73+
(ast.Global("qux"), False),
74+
],
75+
)
76+
def test_is_constant_legacy(value: ast.AST, expected: bool) -> None:
77+
assert ast_utils.is_constant(value) is expected
78+
79+
80+
@test_utils.require_at_least(8)
81+
@pytest.mark.parametrize(
82+
"value,expected",
83+
[
84+
(ast.Constant("foo"), True),
85+
(ast.Expr(value=ast.Constant(123)), False),
86+
(ast.Global("bar"), False),
87+
],
88+
)
89+
def test_is_constant(value: ast.AST, expected: bool) -> None:
90+
assert ast_utils.is_constant(value) is expected
91+
92+
6293
def test_extract_int_or_none() -> None:
6394
assert ast_utils.extract_int_or_none(ast_utils.make_constant(-123)) == -123
6495
assert ast_utils.extract_int_or_none(ast_utils.make_constant(0)) == 0

src/latexify/codegen/function_codegen.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import sys
88
from typing import Any
99

10-
from latexify import analyzers, constants, exceptions, math_symbols
10+
from latexify import analyzers, ast_utils, constants, exceptions, math_symbols
1111

1212
# Precedences of operators for BoolOp, BinOp, UnaryOp, and Compare nodes.
1313
# Note that this value affects only the appearance of surrounding parentheses for each
@@ -253,6 +253,9 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> str:
253253

254254
# Assignment statements (if any): x = ...
255255
for child in node.body[:-1]:
256+
if isinstance(child, ast.Expr) and ast_utils.is_constant(child.value):
257+
continue
258+
256259
if not isinstance(child, ast.Assign):
257260
raise exceptions.LatexifyNotSupportedError(
258261
"Codegen supports only Assign nodes in multiline functions, "

src/latexify/codegen/function_codegen_test.py

+45-11
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import ast
6+
import textwrap
67

78
import pytest
89

@@ -22,24 +23,57 @@ class UnknownNode(ast.AST):
2223

2324

2425
def test_visit_functiondef_use_signature() -> None:
25-
tree = ast.FunctionDef(
26-
name="f",
27-
args=ast.arguments(
28-
args=[ast.arg(arg="x")],
29-
kwonlyargs=[],
30-
kw_defaults=[],
31-
defaults=[],
32-
),
33-
body=[ast.Return(value=ast.Name(id="x", ctx=ast.Load()))],
34-
decorator_list=[],
35-
)
26+
tree = ast.parse(
27+
textwrap.dedent(
28+
"""
29+
def f(x):
30+
return x
31+
"""
32+
)
33+
).body[0]
34+
assert isinstance(tree, ast.FunctionDef)
35+
3636
latex_without_flag = "x"
3737
latex_with_flag = r"\mathrm{f}(x) = x"
3838
assert FunctionCodegen().visit(tree) == latex_with_flag
3939
assert FunctionCodegen(use_signature=False).visit(tree) == latex_without_flag
4040
assert FunctionCodegen(use_signature=True).visit(tree) == latex_with_flag
4141

4242

43+
def test_visit_functiondef_ignore_docstring() -> None:
44+
tree = ast.parse(
45+
textwrap.dedent(
46+
"""
47+
def f(x):
48+
'''docstring'''
49+
return x
50+
"""
51+
)
52+
).body[0]
53+
assert isinstance(tree, ast.FunctionDef)
54+
55+
latex = r"\mathrm{f}(x) = x"
56+
assert FunctionCodegen().visit(tree) == latex
57+
58+
59+
def test_visit_functiondef_ignore_multiple_constants() -> None:
60+
tree = ast.parse(
61+
textwrap.dedent(
62+
"""
63+
def f(x):
64+
'''docstring'''
65+
3
66+
True
67+
return x
68+
"""
69+
)
70+
).body[0]
71+
assert isinstance(tree, ast.FunctionDef)
72+
73+
latex = r"\mathrm{f}(x) = x"
74+
assert FunctionCodegen().visit(tree) == latex
75+
76+
4377
@pytest.mark.parametrize(
4478
"code,latex",
4579
[

0 commit comments

Comments
 (0)