diff --git a/src/latexify/codegen/algorithmic_codegen.py b/src/latexify/codegen/algorithmic_codegen.py index 685663c..e53fbea 100644 --- a/src/latexify/codegen/algorithmic_codegen.py +++ b/src/latexify/codegen/algorithmic_codegen.py @@ -36,7 +36,8 @@ def __init__( use_math_symbols=use_math_symbols, use_set_symbols=use_set_symbols ) self._identifier_converter = identifier_converter.IdentifierConverter( - use_math_symbols=use_math_symbols + use_math_symbols=use_math_symbols, + use_mathrm=False, ) self._indent_level = 0 @@ -63,6 +64,8 @@ def visit_Expr(self, node: ast.Expr) -> str: # TODO(ZibingZhang): support nested functions def visit_FunctionDef(self, node: ast.FunctionDef) -> str: """Visit a FunctionDef node.""" + name_latex = self._identifier_converter.convert(node.name)[0] + # Arguments arg_strs = [ self._identifier_converter.convert(arg.arg)[0] for arg in node.args.args @@ -71,7 +74,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> str: latex = self._add_indent("\\begin{algorithmic}\n") with self._increment_level(): latex += self._add_indent( - f"\\Function{{{node.name}}}{{${', '.join(arg_strs)}$}}\n" + f"\\Function{{{name_latex}}}{{${', '.join(arg_strs)}$}}\n" ) with self._increment_level(): @@ -197,6 +200,8 @@ def visit_Expr(self, node: ast.Expr) -> str: # TODO(ZibingZhang): support nested functions def visit_FunctionDef(self, node: ast.FunctionDef) -> str: """Visit a FunctionDef node.""" + name_latex = self._identifier_converter.convert(node.name)[0] + # Arguments arg_strs = [ self._identifier_converter.convert(arg.arg)[0] for arg in node.args.args @@ -209,7 +214,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> str: return ( r"\begin{array}{l} " + self._add_indent(r"\mathbf{function}") - + rf" \ \mathrm{{{node.name}}}({', '.join(arg_strs)})" + + rf" \ {name_latex}({', '.join(arg_strs)})" + f"{self._LINE_BREAK}{body}{self._LINE_BREAK}" + self._add_indent(r"\mathbf{end \ function}") + r" \end{array}" diff --git a/src/latexify/codegen/algorithmic_codegen_test.py b/src/latexify/codegen/algorithmic_codegen_test.py index a0781c7..e7b3236 100644 --- a/src/latexify/codegen/algorithmic_codegen_test.py +++ b/src/latexify/codegen/algorithmic_codegen_test.py @@ -174,7 +174,7 @@ def test_visit_while_with_else() -> None: ("a = b = 0", r"a \gets b \gets 0"), ], ) -def test_visit_assign_jupyter(code: str, latex: str) -> None: +def test_visit_assign_ipython(code: str, latex: str) -> None: node = ast.parse(textwrap.dedent(code)).body[0] assert isinstance(node, ast.Assign) assert algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == latex @@ -188,7 +188,7 @@ def test_visit_assign_jupyter(code: str, latex: str) -> None: ( r"\begin{array}{l}" r" \mathbf{function}" - r" \ \mathrm{f}(x) \\" + r" \ f(x) \\" r" \hspace{1em} \mathbf{return} \ x \\" r" \mathbf{end \ function}" r" \end{array}" @@ -199,7 +199,7 @@ def test_visit_assign_jupyter(code: str, latex: str) -> None: ( r"\begin{array}{l}" r" \mathbf{function}" - r" \ \mathrm{f}(a, b, c) \\" + r" \ f(a, b, c) \\" r" \hspace{1em} \mathbf{return} \ 3 \\" r" \mathbf{end \ function}" r" \end{array}" diff --git a/src/latexify/codegen/identifier_converter.py b/src/latexify/codegen/identifier_converter.py index 5290867..19263f8 100644 --- a/src/latexify/codegen/identifier_converter.py +++ b/src/latexify/codegen/identifier_converter.py @@ -16,15 +16,19 @@ class IdentifierConverter: """ _use_math_symbols: bool + _use_mathrm: bool - def __init__(self, *, use_math_symbols: bool) -> None: - """Initializer. + def __init__(self, *, use_math_symbols: bool, use_mathrm: bool = True) -> None: + r"""Initializer. Args: use_math_symbols: Whether to convert identifiers with math symbol names to appropriate LaTeX command. + use_mathrm: Whether to wrap the resulting expression by \mathrm, if + applicable. """ self._use_math_symbols = use_math_symbols + self._use_mathrm = use_mathrm def convert(self, name: str) -> tuple[str, bool]: """Converts Python identifier to LaTeX expression. @@ -44,4 +48,7 @@ def convert(self, name: str) -> tuple[str, bool]: if len(name) == 1 and name != "_": return name, True - return r"\mathrm{" + name.replace("_", r"\_") + "}", False + escaped = name.replace("_", r"\_") + wrapped = rf"\mathrm{{{escaped}}}" if self._use_mathrm else escaped + + return wrapped, False diff --git a/src/latexify/codegen/identifier_converter_test.py b/src/latexify/codegen/identifier_converter_test.py index d507a9b..b46982d 100644 --- a/src/latexify/codegen/identifier_converter_test.py +++ b/src/latexify/codegen/identifier_converter_test.py @@ -8,31 +8,32 @@ @pytest.mark.parametrize( - "name,use_math_symbols,expected", + "name,use_math_symbols,use_mathrm,expected", [ - ("a", False, ("a", True)), - ("_", False, (r"\mathrm{\_}", False)), - ("aa", False, (r"\mathrm{aa}", False)), - ("a1", False, (r"\mathrm{a1}", False)), - ("a_", False, (r"\mathrm{a\_}", False)), - ("_a", False, (r"\mathrm{\_a}", False)), - ("_1", False, (r"\mathrm{\_1}", False)), - ("__", False, (r"\mathrm{\_\_}", False)), - ("a_a", False, (r"\mathrm{a\_a}", False)), - ("a__", False, (r"\mathrm{a\_\_}", False)), - ("a_1", False, (r"\mathrm{a\_1}", False)), - ("alpha", False, (r"\mathrm{alpha}", False)), - ("alpha", True, (r"\alpha", True)), - ("foo", False, (r"\mathrm{foo}", False)), - ("foo", True, (r"\mathrm{foo}", False)), + ("a", False, True, ("a", True)), + ("_", False, True, (r"\mathrm{\_}", False)), + ("aa", False, True, (r"\mathrm{aa}", False)), + ("a1", False, True, (r"\mathrm{a1}", False)), + ("a_", False, True, (r"\mathrm{a\_}", False)), + ("_a", False, True, (r"\mathrm{\_a}", False)), + ("_1", False, True, (r"\mathrm{\_1}", False)), + ("__", False, True, (r"\mathrm{\_\_}", False)), + ("a_a", False, True, (r"\mathrm{a\_a}", False)), + ("a__", False, True, (r"\mathrm{a\_\_}", False)), + ("a_1", False, True, (r"\mathrm{a\_1}", False)), + ("alpha", False, True, (r"\mathrm{alpha}", False)), + ("alpha", True, True, (r"\alpha", True)), + ("foo", False, True, (r"\mathrm{foo}", False)), + ("foo", True, True, (r"\mathrm{foo}", False)), + ("foo", True, False, (r"foo", False)), ], ) def test_identifier_converter( - name: str, use_math_symbols: bool, expected: tuple[str, bool] + name: str, use_math_symbols: bool, use_mathrm: bool, expected: tuple[str, bool] ) -> None: assert ( identifier_converter.IdentifierConverter( - use_math_symbols=use_math_symbols + use_math_symbols=use_math_symbols, use_mathrm=use_mathrm ).convert(name) == expected )