Skip to content

Commit 071a12d

Browse files
authored
fix: sub operate cause bracket error (#80)
1 parent 3a82518 commit 071a12d

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

src/integration_tests/regression_test.py

+8
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,11 @@ def sigmoid(x):
216216
),
217217
reduce_assignments=True,
218218
)
219+
220+
221+
def test_sub_bracket() -> None:
222+
def solve(a, b):
223+
return ((a + b) - b) / (a - b) - (a + b) - (a - b) - (a * b)
224+
225+
latex = r"\mathrm{solve}(a, b) = \frac{a + b - b}{a - b} - (a + b) - (a - b) - ab"
226+
_check_function(solve, latex)

src/latexify/codegen/function_codegen.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -229,20 +229,20 @@ def visit_BinOp(self, node: ast.BinOp) -> str:
229229
def _unwrap(child):
230230
return self.visit(child)
231231

232-
def _wrap(child):
232+
def _wrap(child, force_use_bracket=False):
233233
latex = _unwrap(child)
234234
if isinstance(child, ast.BinOp):
235235
cp = priority[type(child.op)] if type(child.op) in priority else 100
236236
pp = priority[type(node.op)] if type(node.op) in priority else 100
237-
if cp < pp:
237+
if cp < pp or (cp == pp and force_use_bracket):
238238
return "(" + latex + ")"
239239
return latex
240240

241241
lhs = node.left
242242
rhs = node.right
243243
reprs = {
244244
ast.Add: (lambda: _wrap(lhs) + " + " + _wrap(rhs)),
245-
ast.Sub: (lambda: _wrap(lhs) + " - " + _wrap(rhs)),
245+
ast.Sub: (lambda: _wrap(lhs) + " - " + _wrap(rhs, True)),
246246
ast.Mult: (lambda: _wrap(lhs) + _wrap(rhs)),
247247
ast.MatMult: (lambda: _wrap(lhs) + _wrap(rhs)),
248248
ast.Div: (lambda: r"\frac{" + _unwrap(lhs) + "}{" + _unwrap(rhs) + "}"),

0 commit comments

Comments
 (0)