@@ -144,6 +144,17 @@ class BinOpRule:
144
144
ast .BitOr : BinOpRule ("" , r" \mathbin{|} " , "" ),
145
145
}
146
146
147
+ # Typeset for BinOp of sets.
148
+ _SET_BIN_OP_RULES : dict [type [ast .operator ], BinOpRule ] = {
149
+ ** _BIN_OP_RULES ,
150
+ ast .Sub : BinOpRule (
151
+ "" , r" \setminus " , "" , operand_right = BinOperandRule (force = True )
152
+ ),
153
+ ast .BitAnd : BinOpRule ("" , r" \cap " , "" ),
154
+ ast .BitXor : BinOpRule ("" , r" \mathbin{\triangle} " , "" ),
155
+ ast .BitOr : BinOpRule ("" , r" \cup " , "" ),
156
+ }
157
+
147
158
_UNARY_OPS : dict [type [ast .unaryop ], str ] = {
148
159
ast .Invert : r"\mathord{\sim} " ,
149
160
ast .UAdd : "+" , # Explicitly adds the $+$ operator.
@@ -164,6 +175,15 @@ class BinOpRule:
164
175
ast .NotIn : r"\notin" ,
165
176
}
166
177
178
+ # Typeset for Compare of sets.
179
+ _SET_COMPARE_OPS : dict [type [ast .cmpop ], str ] = {
180
+ ** _COMPARE_OPS ,
181
+ ast .Gt : r"\supset" ,
182
+ ast .GtE : r"\supseteq" ,
183
+ ast .Lt : r"\subset" ,
184
+ ast .LtE : r"\subseteq" ,
185
+ }
186
+
167
187
_BOOL_OPS : dict [type [ast .boolop ], str ] = {
168
188
ast .And : r"\land" ,
169
189
ast .Or : r"\lor" ,
@@ -181,12 +201,16 @@ class FunctionCodegen(ast.NodeVisitor):
181
201
_use_raw_function_name : bool
182
202
_use_signature : bool
183
203
204
+ _bin_op_rules : dict [type [ast .operator ], BinOpRule ]
205
+ _compare_ops : dict [type [ast .cmpop ], str ]
206
+
184
207
def __init__ (
185
208
self ,
186
209
* ,
187
210
use_math_symbols : bool = False ,
188
211
use_raw_function_name : bool = False ,
189
212
use_signature : bool = True ,
213
+ use_set_symbols : bool = False ,
190
214
) -> None :
191
215
"""Initializer.
192
216
@@ -197,13 +221,17 @@ def __init__(
197
221
or convert it to subscript.
198
222
use_signature: Whether to add the function signature before the expression
199
223
or not.
224
+ use_set_symbols: Whether to use set symbols or not.
200
225
"""
201
226
self ._math_symbol_converter = math_symbols .MathSymbolConverter (
202
227
enabled = use_math_symbols
203
228
)
204
229
self ._use_raw_function_name = use_raw_function_name
205
230
self ._use_signature = use_signature
206
231
232
+ self ._bin_op_rules = _SET_BIN_OP_RULES if use_set_symbols else _BIN_OP_RULES
233
+ self ._compare_ops = _SET_COMPARE_OPS if use_set_symbols else _COMPARE_OPS
234
+
207
235
def generic_visit (self , node : ast .AST ) -> str :
208
236
raise exceptions .LatexifyNotSupportedError (
209
237
f"Unsupported AST: { type (node ).__name__ } "
@@ -445,7 +473,7 @@ def _wrap_binop_operand(
445
473
def visit_BinOp (self , node : ast .BinOp ) -> str :
446
474
"""Visit a BinOp node."""
447
475
prec = _get_precedence (node )
448
- rule = _BIN_OP_RULES [type (node .op )]
476
+ rule = self . _bin_op_rules [type (node .op )]
449
477
lhs = self ._wrap_binop_operand (node .left , prec , rule .operand_left )
450
478
rhs = self ._wrap_binop_operand (node .right , prec , rule .operand_right )
451
479
return f"{ rule .latex_left } { lhs } { rule .latex_middle } { rhs } { rule .latex_right } "
@@ -459,7 +487,7 @@ def visit_Compare(self, node: ast.Compare) -> str:
459
487
"""Visit a compare node."""
460
488
parent_prec = _get_precedence (node )
461
489
lhs = self ._wrap_operand (node .left , parent_prec )
462
- ops = [_COMPARE_OPS [type (x )] for x in node .ops ]
490
+ ops = [self . _compare_ops [type (x )] for x in node .ops ]
463
491
rhs = [self ._wrap_operand (x , parent_prec ) for x in node .comparators ]
464
492
ops_rhs = [f" { o } { r } " for o , r in zip (ops , rhs )]
465
493
return "{" + lhs + "" .join (ops_rhs ) + "}"
0 commit comments