diff --git a/macchiato.py b/macchiato.py index a631638..54ab076 100644 --- a/macchiato.py +++ b/macchiato.py @@ -1,6 +1,7 @@ import itertools import sys import tempfile +import tokenize import black @@ -33,6 +34,24 @@ def macchiato(in_fp, out_fp, args=None): prefix = 4 * i * " " lines.insert(i, f"{prefix}if True:\n") + # Handle else/elif/except/finally + try: + first_token = next( + tokenize.generate_tokens(iter([first_line.lstrip()]).__next__) + ) + except tokenize.TokenError: + first_token = None + if first_token and first_token.type == tokenize.NAME: + name = first_token.string + if name in {"else", "elif"}: + lines.insert(n_fake_before, f"{indent * ' '}if True:\n") + lines.insert(n_fake_before + 1, f"{indent * ' '} pass\n") + n_fake_before += 2 + elif name in {"except", "finally"}: + lines.insert(n_fake_before, f"{indent * ' '}try:\n") + lines.insert(n_fake_before + 1, f"{indent * ' '} pass\n") + n_fake_before += 2 + # Detect an unclosed block at the end. Add ‘pass’ at the end of the line if # needed for valid syntax. last_line = lines[-1] @@ -62,9 +81,12 @@ def macchiato(in_fp, out_fp, args=None): # Write output. fp.seek(0) formatted_lines = fp.readlines() - out_fp.write("\n" * n_blank_before) until = len(formatted_lines) - n_fake_after - for line in formatted_lines[n_fake_before:until]: + formatted_lines = formatted_lines[n_fake_before:until] + fmt_n_blank_before, _ = count_surrounding_blank_lines(formatted_lines) + formatted_lines = formatted_lines[fmt_n_blank_before:] + out_fp.write("\n" * n_blank_before) + for line in formatted_lines: out_fp.write(line) out_fp.write("\n" * n_blank_after) diff --git a/test_macchiato.py b/test_macchiato.py index 9797e23..6a74d94 100644 --- a/test_macchiato.py +++ b/test_macchiato.py @@ -27,7 +27,11 @@ def test_count_surrounding_blank_lines(lines, before, after): ("foo\n", "foo\n"), (" foo\n", " foo\n"), (" if True:\n", " if True:\n"), - ("\n\n x=3\n\n", "\n\n x = 3\n\n") + ("\n\n x=3\n\n", "\n\n x = 3\n\n"), + ("elif x==5:\n", "elif x == 5:\n"), + ("'''\n'''\n", '"""\n"""\n'), # tokenize error handling + (" finally :\n", " finally:\n"), + (" def f():\n pass\n", " def f():\n pass\n"), ], ) def test_macchiato(input, expected):