Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IdentifierConverter #139

Merged
merged 4 commits into from
Nov 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ print(solve)
```

```
\\mathrm{f}(n) = \\frac{-b + \\sqrt{b^{{2}} - {4}ac}}{{2}a}
f(n) = \\frac{-b + \\sqrt{b^{{2}} - {4}ac}}{{2}a}
```

`latexify.expression` works similarly to `latexify.function`,
Expand Down Expand Up @@ -85,5 +85,5 @@ latexify.get_latex(solve)
```

```
\\mathrm{f}(n) = \\frac{-b + \\sqrt{b^{{2}} - {4}ac}}{{2}a}
f(n) = \\frac{-b + \\sqrt{b^{{2}} - {4}ac}}{{2}a}
```
22 changes: 3 additions & 19 deletions docs/parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def my_function(my_argument):
my_function
```

$$\mathrm{f}(x) = \mathrm{g}\left(x\right)$$
$$f(x) = \mathrm{g}\left(x\right)$$


## `reduce_assignments: bool`
Expand All @@ -42,7 +42,7 @@ def f(a, b, c):
f
```

$$\mathrm{f}(a, b, c) = \frac{-b + \sqrt{b^{{2}} - {4} a c}}{{2} a}$$
$$f(a, b, c) = \frac{-b + \sqrt{b^{{2}} - {4} a c}}{{2} a}$$


## `use_math_symbols: bool`
Expand All @@ -60,22 +60,6 @@ greek
$$\mathrm{greek}({\alpha}, {\beta}, {\gamma}, {\Omega}) = {\alpha} {\beta} + \Gamma\left({{\gamma}}\right) + {\Omega}$$


## `use_raw_function_name: bool`

Whether to use the original string as the function name or not.

```python
@latexify.function(use_raw_function_name=True)
def quadratic_solution(a, b, c):
return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a)

f
```

$$\mathrm{quadratic\_solution}(a, b, c) = \frac{-b + \sqrt{b^{{2}} - {4} a c}}{{2} a}$$

(note that GitHub's LaTeX renderer does not process underscores correctly.)

## `use_set_symbols: bool`

Whether to use binary operators for set operations or not.
Expand All @@ -88,7 +72,7 @@ def f(x, y):
f
```

$$\mathrm{f}(x, y) = \left( x \cap y\space,\space x \cup y\space,\space x \setminus y\space,\space x \mathbin{\triangle} y\space,\space {x \subset y}\space,\space {x \subseteq y}\space,\space {x \supset y}\space,\space {x \supseteq y}\right)$$
$$f(x, y) = \left( x \cap y\space,\space x \cup y\space,\space x \setminus y\space,\space x \mathbin{\triangle} y\space,\space {x \subset y}\space,\space {x \subseteq y}\space,\space {x \supset y}\space,\space {x \supseteq y}\right)$$


## `use_signature: bool`
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,6 @@ path = "src/latexify/_version.py"
[tool.flake8]
max-line-length = 88
extend-ignore = "E203"

[tool.isort]
profile = "black"
45 changes: 14 additions & 31 deletions src/integration_tests/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ def test_x_times_beta() -> None:
def xtimesbeta(x, beta):
return x * beta

latex_without_symbols = r"\mathrm{xtimesbeta}(x, beta) = x beta"
latex_without_symbols = r"\mathrm{xtimesbeta}(x, \mathrm{beta}) = x \mathrm{beta}"
utils.check_function(xtimesbeta, latex_without_symbols)
utils.check_function(xtimesbeta, latex_without_symbols, use_math_symbols=False)

latex_with_symbols = r"\mathrm{xtimesbeta}(x, {\beta}) = x {\beta}"
latex_with_symbols = r"\mathrm{xtimesbeta}(x, \beta) = x \beta"
utils.check_function(xtimesbeta, latex_with_symbols, use_math_symbols=True)


Expand All @@ -49,7 +49,7 @@ def sum_with_limit(n):
return sum(i**2 for i in range(n))

latex = (
r"\mathrm{sum_with_limit}(n) = \sum_{i = {0}}^{{n - 1}}"
r"\mathrm{sum\_with\_limit}(n) = \sum_{i = {0}}^{{n - 1}}"
r" \mathopen{}\left({i^{{2}}}\mathclose{}\right)"
)
utils.check_function(sum_with_limit, latex)
Expand All @@ -60,7 +60,7 @@ def sum_with_limit(a, n):
return sum(i**2 for i in range(a, n))

latex = (
r"\mathrm{sum_with_limit}(a, n) = \sum_{i = a}^{{n - 1}} "
r"\mathrm{sum\_with\_limit}(a, n) = \sum_{i = a}^{{n - 1}} "
r"\mathopen{}\left({i^{{2}}}\mathclose{}\right)"
)
utils.check_function(sum_with_limit, latex)
Expand All @@ -71,7 +71,7 @@ def sum_with_limit(n):
return sum(i for i in range(n + 1))

latex = (
r"\mathrm{sum_with_limit}(n) = \sum_{i = {0}}^{{n}} "
r"\mathrm{sum\_with\_limit}(n) = \sum_{i = {0}}^{{n}} "
r"\mathopen{}\left({i}\mathclose{}\right)"
)
utils.check_function(sum_with_limit, latex)
Expand All @@ -82,7 +82,7 @@ def sum_with_limit(n):
return sum(i for i in range(n * 3))

latex = (
r"\mathrm{sum_with_limit}(n) = \sum_{i = {0}}^{{n {3} - 1}} "
r"\mathrm{sum\_with\_limit}(n) = \sum_{i = {0}}^{{n {3} - 1}} "
r"\mathopen{}\left({i}\mathclose{}\right)"
)
utils.check_function(sum_with_limit, latex)
Expand All @@ -93,7 +93,7 @@ def prod_with_limit(n):
return math.prod(i**2 for i in range(n))

latex = (
r"\mathrm{prod_with_limit}(n) = "
r"\mathrm{prod\_with\_limit}(n) = "
r"\prod_{i = {0}}^{{n - 1}} \mathopen{}\left({i^{{2}}}\mathclose{}\right)"
)
utils.check_function(prod_with_limit, latex)
Expand All @@ -104,7 +104,7 @@ def prod_with_limit(a, n):
return math.prod(i**2 for i in range(a, n))

latex = (
r"\mathrm{prod_with_limit}(a, n) = "
r"\mathrm{prod\_with\_limit}(a, n) = "
r"\prod_{i = a}^{{n - 1}} \mathopen{}\left({i^{{2}}}\mathclose{}\right)"
)
utils.check_function(prod_with_limit, latex)
Expand All @@ -115,7 +115,7 @@ def prod_with_limit(n):
return math.prod(i for i in range(n - 1))

latex = (
r"\mathrm{prod_with_limit}(n) = "
r"\mathrm{prod\_with\_limit}(n) = "
r"\prod_{i = {0}}^{{n - {2}}} \mathopen{}\left({i}\mathclose{}\right)"
)
utils.check_function(prod_with_limit, latex)
Expand All @@ -126,7 +126,7 @@ def prod_with_limit(n):
return math.prod(i for i in range(n * 3))

latex = (
r"\mathrm{prod_with_limit}(n) = "
r"\mathrm{prod\_with\_limit}(n) = "
r"\prod_{i = {0}}^{{n {3} - 1}} \mathopen{}\left({i}\mathclose{}\right)"
)
utils.check_function(prod_with_limit, latex)
Expand All @@ -149,35 +149,18 @@ def inner(y):
utils.check_function(nested(3), r"\mathrm{inner}(y) = x y")


def test_use_raw_function_name() -> None:
def foo_bar():
return 42

utils.check_function(foo_bar, r"\mathrm{foo_bar}() = {42}")
utils.check_function(
foo_bar,
r"\mathrm{foo_bar}() = {42}",
use_raw_function_name=False,
)
utils.check_function(
foo_bar,
r"\mathrm{foo\_bar}() = {42}",
use_raw_function_name=True,
)


def test_reduce_assignments() -> None:
def f(x):
a = x + x
return 3 * a

utils.check_function(
f,
r"\begin{array}{l} a = x + x \\ \mathrm{f}(x) = {3} a \end{array}",
r"\begin{array}{l} a = x + x \\ f(x) = {3} a \end{array}",
)
utils.check_function(
f,
r"\mathrm{f}(x) = {3} \mathopen{}\left( x + x \mathclose{}\right)",
r"f(x) = {3} \mathopen{}\left( x + x \mathclose{}\right)",
reduce_assignments=True,
)

Expand All @@ -192,15 +175,15 @@ def f(x):
r"\begin{array}{l} "
r"a = x^{{2}} \\ "
r"b = a + a \\ "
r"\mathrm{f}(x) = {3} b "
r"f(x) = {3} b "
r"\end{array}"
)

utils.check_function(f, latex_without_option)
utils.check_function(f, latex_without_option, reduce_assignments=False)
utils.check_function(
f,
r"\mathrm{f}(x) = {3} \mathopen{}\left( x^{{2}} + x^{{2}} \mathclose{}\right)",
r"f(x) = {3} \mathopen{}\left( x^{{2}} + x^{{2}} \mathclose{}\right)",
reduce_assignments=True,
)

Expand Down
48 changes: 22 additions & 26 deletions src/latexify/codegen/function_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import sys
from typing import Any

from latexify import analyzers, ast_utils, constants, exceptions, math_symbols
from latexify import analyzers, ast_utils, constants, exceptions
from latexify.codegen import identifier_converter

# Precedences of operators for BoolOp, BinOp, UnaryOp, and Compare nodes.
# Note that this value affects only the appearance of surrounding parentheses for each
Expand Down Expand Up @@ -194,8 +195,7 @@ class FunctionCodegen(ast.NodeVisitor):
LaTeX expression of the given function.
"""

_math_symbol_converter: math_symbols.MathSymbolConverter
_use_raw_function_name: bool
_identifier_converter: identifier_converter.IdentifierConverter
_use_signature: bool

_bin_op_rules: dict[type[ast.operator], BinOpRule]
Expand All @@ -205,7 +205,6 @@ def __init__(
self,
*,
use_math_symbols: bool = False,
use_raw_function_name: bool = False,
use_signature: bool = True,
use_set_symbols: bool = False,
) -> None:
Expand All @@ -214,16 +213,13 @@ def __init__(
Args:
use_math_symbols: Whether to convert identifiers with a math symbol surface
(e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha").
use_raw_function_name: Whether to keep underscores "_" in the function name,
or convert it to subscript.
use_signature: Whether to add the function signature before the expression
or not.
use_set_symbols: Whether to use set symbols or not.
"""
self._math_symbol_converter = math_symbols.MathSymbolConverter(
enabled=use_math_symbols
self._identifier_converter = identifier_converter.IdentifierConverter(
use_math_symbols=use_math_symbols
)
self._use_raw_function_name = use_raw_function_name
self._use_signature = use_signature

self._bin_op_rules = _SET_BIN_OP_RULES if use_set_symbols else _BIN_OP_RULES
Expand All @@ -239,14 +235,11 @@ def visit_Module(self, node: ast.Module) -> str:

def visit_FunctionDef(self, node: ast.FunctionDef) -> str:
# Function name
name_str = str(node.name)
if self._use_raw_function_name:
name_str = name_str.replace(r"_", r"\_")
name_str = r"\mathrm{" + name_str + "}"
name_str = self._identifier_converter.convert(node.name)[0]

# Arguments
arg_strs = [
self._math_symbol_converter.convert(str(arg.arg)) for arg in node.args.args
self._identifier_converter.convert(arg.arg)[0] for arg in node.args.args
]

body_strs: list[str] = []
Expand Down Expand Up @@ -343,33 +336,36 @@ def visit_comprehension(self, node: ast.comprehension) -> str:

def visit_Call(self, node: ast.Call) -> str:
"""Visit a call node."""
# Function signature (possibly an expression).
func_str = self.visit(node.func)

# Obtains wrapper syntax: sqrt -> "\sqrt{" and "}"
lstr, rstr = constants.BUILTIN_FUNCS.get(
func_str,
(r"\mathrm{" + func_str + r"}\mathopen{}\left(", r"\mathclose{}\right)"),
)
func_name = ast_utils.extract_function_name_or_none(node)

if func_str in ("sum", "prod") and isinstance(node.args[0], ast.GeneratorExp):
# Special processing for sum and prod.
if func_name in ("sum", "prod") and isinstance(node.args[0], ast.GeneratorExp):
elt, scripts = self._get_sum_prod_info(node.args[0])
scripts_str = [rf"\{func_str}_{{{lo}}}^{{{up}}}" for lo, up in scripts]
scripts_str = [rf"\{func_name}_{{{lo}}}^{{{up}}}" for lo, up in scripts]
return (
" ".join(scripts_str)
+ rf" \mathopen{{}}\left({{{elt}}}\mathclose{{}}\right)"
)

# Function signature (possibly an expression).
default_func_str = self.visit(node.func)

# Obtains wrapper syntax: sqrt -> "\sqrt{" and "}"
lstr, rstr = constants.BUILTIN_FUNCS.get(
func_name,
(default_func_str + r"\mathopen{}\left(", r"\mathclose{}\right)"),
)

arg_strs = [self.visit(arg) for arg in node.args]
return lstr + ", ".join(arg_strs) + rstr

def visit_Attribute(self, node: ast.Attribute) -> str:
vstr = self.visit(node.value)
astr = str(node.attr)
astr = self._identifier_converter.convert(node.attr)[0]
return vstr + "." + astr

def visit_Name(self, node: ast.Name) -> str:
return self._math_symbol_converter.convert(str(node.id))
return self._identifier_converter.convert(node.id)[0]

def _convert_constant(self, value: Any) -> str:
"""Helper to convert constant values to LaTeX.
Expand Down
Loading