diff --git a/docs/known-limitations.rst b/docs/known-limitations.rst index 7ca8e219..f2e1dfe7 100644 --- a/docs/known-limitations.rst +++ b/docs/known-limitations.rst @@ -8,37 +8,6 @@ Automagic `Automagic `_ ("Make magic functions callable without having to type the initial %") will not work well with most code quality tools, as it will not parse as valid Python syntax. -Black ------ - -Comment after trailing semicolon -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Putting a comment after a trailing semicolon will make ``black`` move the comment to the -next line, and the semicolon will be lost. - -Example: - -.. code:: python - - plt.plot(); # some comment - - -Will be transformed to: - -.. code:: python - - plt.plot() - # some comment - -You can overcome this limitation by moving the comment to the previous line - like this, -the trailing semicolon will be preserved: - -.. code:: python - - # some comment - plt.plot(); - Linters (flake8, mypy, pylint, ...) ----------------------------------- diff --git a/nbqa/replace_source.py b/nbqa/replace_source.py index 28d3da83..defd9490 100644 --- a/nbqa/replace_source.py +++ b/nbqa/replace_source.py @@ -19,6 +19,8 @@ Set, ) +import tokenize_rt + from nbqa.handle_magics import MagicHandler from nbqa.notebook_info import NotebookInfo from nbqa.save_source import CODE_SEPARATOR @@ -52,11 +54,24 @@ def _restore_semicolon( ------- str New source with removed semicolon restored. - """ - rstripped_source = source.rstrip() - if cell_number in trailing_semicolons and not rstripped_source.endswith(";"): - source = rstripped_source + ";" + Raises + ------ + AssertionError + If code thought to be unreachable is reached. + """ + if cell_number in trailing_semicolons: + tokens = tokenize_rt.src_to_tokens(source) + for idx, token in tokenize_rt.reversed_enumerate(tokens): + if not token.src.strip(" \n") or token.name == "COMMENT": + continue + tokens[idx] = token._replace(src=token.src + ";") + break + else: # pragma: nocover + raise AssertionError( + "Unreachable code, please report bug at https://github.com/nbQA-dev/nbQA/issues" + ) + source = tokenize_rt.tokens_to_src(tokens) return source diff --git a/nbqa/save_source.py b/nbqa/save_source.py index 86166b8d..f9ca613f 100644 --- a/nbqa/save_source.py +++ b/nbqa/save_source.py @@ -17,10 +17,13 @@ List, Mapping, MutableMapping, + NamedTuple, Sequence, Tuple, ) +import tokenize_rt + from nbqa.handle_magics import INPUT_SPLITTER, IPythonMagicType, MagicHandler from nbqa.notebook_info import NotebookInfo @@ -34,6 +37,13 @@ NEWLINES["isort"] = NEWLINE * 2 +class Index(NamedTuple): + """Keep track of line and cell number while iterating over cells.""" + + line_number: int + cell_number: int + + def _is_src_code_indentation_valid(source: str) -> bool: """ Return True is the indentation of the input source code is valid. @@ -303,6 +313,34 @@ def _should_ignore_code_cell( return first_line.split()[0] not in {f"%%{magic}" for magic in process} +def _has_trailing_semicolon(src: str) -> Tuple[str, bool]: + """ + Check if cell has trailing semicolon. + + Parameters + ---------- + src + Notebook cell source. + + Returns + ------- + bool + Whether notebook has trailing semicolon. + """ + tokens = tokenize_rt.src_to_tokens(src) + trailing_semicolon = False + for idx, token in tokenize_rt.reversed_enumerate(tokens): + if not token.src.strip(" \n") or token.name == "COMMENT": + continue + if token.name == "OP" and token.src == ";": + tokens[idx] = token._replace(src="") + trailing_semicolon = True + break + if not trailing_semicolon: + return src, False + return tokenize_rt.tokens_to_src(tokens), True + + def main( notebook: "Path", temp_python_file: "Path", @@ -332,36 +370,40 @@ def main( result = [] cell_mapping = {0: "cell_0:0"} - line_number = 0 - cell_number = 0 + index = Index(line_number=0, cell_number=0) trailing_semicolons = set() temporary_lines: DefaultDict[int, Sequence[MagicHandler]] = defaultdict(list) code_cells_to_ignore = set() for cell in cells: if cell["cell_type"] == "code": - cell_number += 1 + index = index._replace(cell_number=index.cell_number + 1) if _should_ignore_code_cell(cell["source"], process_cells): - code_cells_to_ignore.add(cell_number) + code_cells_to_ignore.add(index.cell_number) continue parsed_cell = _parse_cell( - cell["source"], cell_number, temporary_lines, command + cell["source"], index.cell_number, temporary_lines, command ) cell_mapping.update( { - py_line + line_number + 1: f"cell_{cell_number}:{cell_line}" + py_line + + index.line_number + + 1: f"cell_{index.cell_number}:{cell_line}" for py_line, cell_line in _get_line_numbers_for_mapping( - parsed_cell, temporary_lines[cell_number] + parsed_cell, temporary_lines[index.cell_number] ).items() } ) - if parsed_cell.rstrip().endswith(";"): - trailing_semicolons.add(cell_number) - result.append(re.sub(r";(\s*)$", "\\1", parsed_cell)) - line_number += len(parsed_cell.splitlines()) + parsed_cell, trailing_semicolon = _has_trailing_semicolon(parsed_cell) + if trailing_semicolon: + trailing_semicolons.add(index.cell_number) + result.append(parsed_cell) + index = index._replace( + line_number=index.line_number + len(parsed_cell.splitlines()) + ) result_txt = "".join(result).rstrip(NEWLINE) + NEWLINE if result else "" temp_python_file.write_text(result_txt, encoding="utf-8") diff --git a/setup.cfg b/setup.cfg index dccb3d98..1f6dc456 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,6 +34,7 @@ packages = find: py_modules = nbqa install_requires = ipython>=7.8.0 + tokenize-rt>=3.2.0 toml importlib_metadata;python_version < '3.8' python_requires = >=3.6.1 diff --git a/tests/data/comment_after_trailing_semicolon.ipynb b/tests/data/comment_after_trailing_semicolon.ipynb new file mode 100644 index 00000000..2861dc91 --- /dev/null +++ b/tests/data/comment_after_trailing_semicolon.ipynb @@ -0,0 +1,49 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import glob;\n", + "\n", + "import nbqa;\n", + "# this is a comment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def func(a, b):\n", + " pass;\n", + " " + ] + } + ], + "metadata": { + "anaconda-cloud": {}, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tests/tools/test_black.py b/tests/tools/test_black.py index 3ae6cd2c..8f6fe9dd 100644 --- a/tests/tools/test_black.py +++ b/tests/tools/test_black.py @@ -500,3 +500,47 @@ def test_invalid_syntax_with_nbqa_diff(capsys: "CaptureFixture") -> None: assert expected_out == out assert expected_err == err + + +def test_comment_after_trailing_comma(capsys: "CaptureFixture") -> None: + """ + Check trailing semicolon is still preserved if comment is after it. + + Parameters + ---------- + capsys + Pytest fixture to capture stdout and stderr. + """ + path = os.path.abspath( + os.path.join("tests", "data", "comment_after_trailing_semicolon.ipynb") + ) + + with pytest.raises(SystemExit): + main(["black", path, "--nbqa-diff"]) + + out, _ = capsys.readouterr() + expected_out = ( + "\x1b[1mCell 1\x1b[0m\n" + "------\n" + f"--- {path}\n" + f"+++ {path}\n" + "@@ -1,4 +1,5 @@\n" + "\x1b[31m-import glob;\n" + "\x1b[0m\x1b[32m+import glob\n" + "\x1b[0m \n" + " import nbqa;\n" + "\x1b[32m+\n" + "\x1b[0m # this is a comment\n" + "\n" + "\x1b[1mCell 2\x1b[0m\n" + "------\n" + f"--- {path}\n" + f"+++ {path}\n" + "@@ -1,3 +1,2 @@\n" + " def func(a, b):\n" + " pass;\n" + "\x1b[31m- \n" + "\x1b[0m\n" + "To apply these changes use `--nbqa-mutate` instead of `--nbqa-diff`\n" + ) + assert out == expected_out diff --git a/tests/tools/test_isort_works.py b/tests/tools/test_isort_works.py index 85fe4f7a..c81b2240 100644 --- a/tests/tools/test_isort_works.py +++ b/tests/tools/test_isort_works.py @@ -36,7 +36,8 @@ def test_isort_works(tmp_notebook_for_testing: Path, capsys: "CaptureFixture") - with open(tmp_notebook_for_testing) as handle: after = handle.readlines() diff = difflib.unified_diff(before, after) - result = "".join([i for i in diff if any([i.startswith("+ "), i.startswith("- ")])]) + result = "".join(i for i in diff if any([i.startswith("+ "), i.startswith("- ")])) + expected = dedent( """\ + "import glob\\n", @@ -77,7 +78,8 @@ def test_isort_initial_md( with open(tmp_notebook_starting_with_md) as handle: after = handle.readlines() diff = difflib.unified_diff(before, after) - result = "".join([i for i in diff if any([i.startswith("+ "), i.startswith("- ")])]) + result = "".join(i for i in diff if any([i.startswith("+ "), i.startswith("- ")])) + expected = dedent( """\ + "import glob\\n", @@ -159,16 +161,9 @@ def test_isort_trailing_semicolon(tmp_notebook_with_trailing_semicolon: Path) -> with open(tmp_notebook_with_trailing_semicolon) as handle: after = handle.readlines() diff = difflib.unified_diff(before, after) - result = "".join([i for i in diff if any([i.startswith("+ "), i.startswith("- ")])]) - expected = dedent( - """\ - - "import glob;\\n", - + "import glob\\n", - - " pass;\\n", - - " " - + " pass;" - """ - ) + result = "".join(i for i in diff if any([i.startswith("+ "), i.startswith("- ")])) + + expected = '- "import glob;\\n",\n+ "import glob\\n",\n' assert result == expected @@ -212,3 +207,33 @@ def test_old_isort(monkeypatch: "MonkeyPatch") -> None: msg = "\x1b[1mnbqa only works with isort >= 5.3.0, while you have 4.3.21 installed.\x1b[0m" assert msg == str(excinfo.value) + + +def test_comment_after_trailing_semicolons(capsys: "CaptureFixture") -> None: + """Check isort works normally when there's a comment after trailing semicolon.""" + # check diff + path = os.path.abspath( + os.path.join("tests", "data", "comment_after_trailing_semicolon.ipynb") + ) + + with pytest.raises(SystemExit): + main(["isort", path, "--nbqa-diff"]) + + out, _ = capsys.readouterr() + expected_out = ( + "\x1b[1mCell 1\x1b[0m\n" + "------\n" + f"--- {path}\n" + f"+++ {path}\n" + "@@ -1,4 +1,5 @@\n" + "\x1b[31m-import glob;\n" + "\x1b[0m\x1b[32m+import glob\n" + "\x1b[0m \n" + " import nbqa;\n" + "\x1b[32m+\n" + "\x1b[0m # this is a comment\n" + "\n" + f"Fixing {path}\n" + "To apply these changes use `--nbqa-mutate` instead of `--nbqa-diff`\n" + ) + assert out == expected_out diff --git a/tests/tools/test_mypy_works.py b/tests/tools/test_mypy_works.py index 869f5799..c2791cbf 100644 --- a/tests/tools/test_mypy_works.py +++ b/tests/tools/test_mypy_works.py @@ -45,7 +45,7 @@ def test_mypy_works(capsys: "CaptureFixture") -> None: {path_0}:cell_2:19: error: Argument 1 to "hello" has incompatible type "int"; expected "str" {path_3}:cell_8:3: error: Name 'flake8_version' is not defined {path_3}:cell_8:4: error: Name 'flake8_version' is not defined - Found 5 errors in 4 files (checked 23 source files) + Found 5 errors in 4 files (checked 24 source files) """ # noqa ) expected_err = ""