Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement summation with limits using comprehension #30 #32

Merged
merged 9 commits into from
Oct 4, 2022
30 changes: 30 additions & 0 deletions latexify/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,38 @@ def visit_Call(self, node): # pylint: disable=invalid-name
rstr = r'\right)'

arg_strs = [self.visit(arg) for arg in node.args]

if callee_str == 'sum' and isinstance(node.args[0], ast.GeneratorExp):
limit_str, formula_str = self.visit(node.args[0])
lstr = r'\sum' + limit_str + r' \left({'
return lstr + formula_str + rstr
if callee_str == 'range':
return arg_strs

return lstr + ', '.join(arg_strs) + rstr

def visit_GeneratorExp(self, node): # pylint: disable=invalid-name
limit_str = self.visit(node.generators[0])
formula_str = self.visit(node.elt)
return limit_str, formula_str

def visit_comprehension(self, node): # pylint: disable=invalid-name
"""Visit a comprehension node (for clause)."""
var = self.visit(node.target)
limits = self.visit(node.iter)
try:
if isinstance(limits, list):
if len(limits) == 1:
lower_limit, upper_limit = '0', limits[0]
elif len(limits) == 2:
lower_limit, upper_limit = limits
else:
raise ValueError
except ValueError as e:
print(e, 'Maybe function other than "range" is used is comprehension')
return fr'_{{{var}={lower_limit}}}^{{{upper_limit}}}'


def visit_Attribute(self, node): # pylint: disable=invalid-name
vstr = self.visit(node.value)
astr = str(node.attr)
Expand Down
21 changes: 21 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,32 @@ def xtimesbeta(x, beta):
xtimesbeta_latex_no_symbols = r'\mathrm{xtimesbeta}(x, beta) \triangleq xbeta'


def sum_with_limit(n):
return sum(i**2 for i in range(n))


sum_with_limit_latex = (
r'\mathrm{sum_with_limit}(n) \triangleq \sum_{i=0}^{n} \left({i^{2}}\right)'
)


def sum_with_limit_two_args(a, n):
return sum(i**2 for i in range(a, n))


sum_with_limit_two_args_latex = (
r'\mathrm{sum_with_limit_two_args}(a, n) '
r'\triangleq \sum_{i=a}^{n} \left({i^{2}}\right)'
)


func_and_latex_str_list = [
(solve, solve_latex, None),
(sinc, sinc_latex, None),
(xtimesbeta, xtimesbeta_latex, True),
(xtimesbeta, xtimesbeta_latex_no_symbols, False),
(sum_with_limit, sum_with_limit_latex, None),
(sum_with_limit_two_args, sum_with_limit_two_args_latex, None),
]


Expand Down