Skip to content

Commit 78a6d01

Browse files
author
Yusuke Oda
authored
set operations (#94)
1 parent 5dde4e4 commit 78a6d01

File tree

4 files changed

+75
-2
lines changed

4 files changed

+75
-2
lines changed

src/latexify/codegen/function_codegen.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,17 @@ class BinOpRule:
144144
ast.BitOr: BinOpRule("", r" \mathbin{|} ", ""),
145145
}
146146

147+
# Typeset for BinOp of sets.
148+
_SET_BIN_OP_RULES: dict[type[ast.operator], BinOpRule] = {
149+
**_BIN_OP_RULES,
150+
ast.Sub: BinOpRule(
151+
"", r" \setminus ", "", operand_right=BinOperandRule(force=True)
152+
),
153+
ast.BitAnd: BinOpRule("", r" \cap ", ""),
154+
ast.BitXor: BinOpRule("", r" \mathbin{\triangle} ", ""),
155+
ast.BitOr: BinOpRule("", r" \cup ", ""),
156+
}
157+
147158
_UNARY_OPS: dict[type[ast.unaryop], str] = {
148159
ast.Invert: r"\mathord{\sim} ",
149160
ast.UAdd: "+", # Explicitly adds the $+$ operator.
@@ -164,6 +175,15 @@ class BinOpRule:
164175
ast.NotIn: r"\notin",
165176
}
166177

178+
# Typeset for Compare of sets.
179+
_SET_COMPARE_OPS: dict[type[ast.cmpop], str] = {
180+
**_COMPARE_OPS,
181+
ast.Gt: r"\supset",
182+
ast.GtE: r"\supseteq",
183+
ast.Lt: r"\subset",
184+
ast.LtE: r"\subseteq",
185+
}
186+
167187
_BOOL_OPS: dict[type[ast.boolop], str] = {
168188
ast.And: r"\land",
169189
ast.Or: r"\lor",
@@ -181,12 +201,16 @@ class FunctionCodegen(ast.NodeVisitor):
181201
_use_raw_function_name: bool
182202
_use_signature: bool
183203

204+
_bin_op_rules: dict[type[ast.operator], BinOpRule]
205+
_compare_ops: dict[type[ast.cmpop], str]
206+
184207
def __init__(
185208
self,
186209
*,
187210
use_math_symbols: bool = False,
188211
use_raw_function_name: bool = False,
189212
use_signature: bool = True,
213+
use_set_symbols: bool = False,
190214
) -> None:
191215
"""Initializer.
192216
@@ -197,13 +221,17 @@ def __init__(
197221
or convert it to subscript.
198222
use_signature: Whether to add the function signature before the expression
199223
or not.
224+
use_set_symbols: Whether to use set symbols or not.
200225
"""
201226
self._math_symbol_converter = math_symbols.MathSymbolConverter(
202227
enabled=use_math_symbols
203228
)
204229
self._use_raw_function_name = use_raw_function_name
205230
self._use_signature = use_signature
206231

232+
self._bin_op_rules = _SET_BIN_OP_RULES if use_set_symbols else _BIN_OP_RULES
233+
self._compare_ops = _SET_COMPARE_OPS if use_set_symbols else _COMPARE_OPS
234+
207235
def generic_visit(self, node: ast.AST) -> str:
208236
raise exceptions.LatexifyNotSupportedError(
209237
f"Unsupported AST: {type(node).__name__}"
@@ -445,7 +473,7 @@ def _wrap_binop_operand(
445473
def visit_BinOp(self, node: ast.BinOp) -> str:
446474
"""Visit a BinOp node."""
447475
prec = _get_precedence(node)
448-
rule = _BIN_OP_RULES[type(node.op)]
476+
rule = self._bin_op_rules[type(node.op)]
449477
lhs = self._wrap_binop_operand(node.left, prec, rule.operand_left)
450478
rhs = self._wrap_binop_operand(node.right, prec, rule.operand_right)
451479
return f"{rule.latex_left}{lhs}{rule.latex_middle}{rhs}{rule.latex_right}"
@@ -459,7 +487,7 @@ def visit_Compare(self, node: ast.Compare) -> str:
459487
"""Visit a compare node."""
460488
parent_prec = _get_precedence(node)
461489
lhs = self._wrap_operand(node.left, parent_prec)
462-
ops = [_COMPARE_OPS[type(x)] for x in node.ops]
490+
ops = [self._compare_ops[type(x)] for x in node.ops]
463491
rhs = [self._wrap_operand(x, parent_prec) for x in node.comparators]
464492
ops_rhs = [f" {o} {r}" for o, r in zip(ops, rhs)]
465493
return "{" + lhs + "".join(ops_rhs) + "}"

src/latexify/codegen/function_codegen_test.py

+30
Original file line numberDiff line numberDiff line change
@@ -563,3 +563,33 @@ def test_visit_subscript(code: str, latex: str) -> None:
563563
tree = ast.parse(code).body[0].value
564564
assert isinstance(tree, ast.Subscript)
565565
assert function_codegen.FunctionCodegen().visit(tree) == latex
566+
567+
568+
@pytest.mark.parametrize(
569+
"code,latex",
570+
[
571+
("a - b", r"a \setminus b"),
572+
("a & b", r"a \cap b"),
573+
("a ^ b", r"a \mathbin{\triangle} b"),
574+
("a | b", r"a \cup b"),
575+
],
576+
)
577+
def test_use_set_symbols_binop(code: str, latex: str) -> None:
578+
tree = ast.parse(code).body[0].value
579+
assert isinstance(tree, ast.BinOp)
580+
assert function_codegen.FunctionCodegen(use_set_symbols=True).visit(tree) == latex
581+
582+
583+
@pytest.mark.parametrize(
584+
"code,latex",
585+
[
586+
("a < b", r"{a \subset b}"),
587+
("a <= b", r"{a \subseteq b}"),
588+
("a > b", r"{a \supset b}"),
589+
("a >= b", r"{a \supseteq b}"),
590+
],
591+
)
592+
def test_use_set_symbols_compare(code: str, latex: str) -> None:
593+
tree = ast.parse(code).body[0].value
594+
assert isinstance(tree, ast.Compare)
595+
assert function_codegen.FunctionCodegen(use_set_symbols=True).visit(tree) == latex

src/latexify/frontend.py

+3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def get_latex(
2020
use_math_symbols: bool = False,
2121
use_raw_function_name: bool = False,
2222
use_signature: bool = True,
23+
use_set_symbols: bool = False,
2324
) -> str:
2425
"""Obtains LaTeX description from the function's source.
2526
@@ -38,6 +39,7 @@ def get_latex(
3839
or convert it to subscript.
3940
use_signature: Whether to add the function signature before the expression or
4041
not.
42+
use_set_symbols: Whether to use set symbols or not.
4143
4244
Returns:
4345
Generatee LaTeX description.
@@ -59,6 +61,7 @@ def get_latex(
5961
use_math_symbols=use_math_symbols,
6062
use_raw_function_name=use_raw_function_name,
6163
use_signature=use_signature,
64+
use_set_symbols=use_set_symbols,
6265
).visit(tree)
6366

6467

src/latexify/frontend_test.py

+12
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,18 @@ def f(x):
6969
assert frontend.get_latex(f, use_signature=True) == latex_with_flag
7070

7171

72+
def test_get_latex_use_set_symbols() -> None:
73+
def f(x, y):
74+
return x & y
75+
76+
latex_without_flag = r"\mathrm{f}(x, y) = x \mathbin{\&} y"
77+
latex_with_flag = r"\mathrm{f}(x, y) = x \cap y"
78+
79+
assert frontend.get_latex(f) == latex_without_flag
80+
assert frontend.get_latex(f, use_set_symbols=False) == latex_without_flag
81+
assert frontend.get_latex(f, use_set_symbols=True) == latex_with_flag
82+
83+
7284
def test_function() -> None:
7385
def f(x):
7486
return x

0 commit comments

Comments
 (0)