diff --git a/bowler/query.py b/bowler/query.py index 9e7fac0..c520952 100644 --- a/bowler/query.py +++ b/bowler/query.py @@ -10,9 +10,8 @@ import pathlib import re from functools import wraps -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, cast +from typing import Callable, List, Optional, Type, TypeVar, Union, cast -from attr import Factory, dataclass from fissix.fixer_base import BaseFix from fissix.fixer_util import Attr, Comma, Dot, LParen, Name, Newline, RParen from fissix.pytree import Leaf, Node, type_repr @@ -89,12 +88,14 @@ def __init__( self, *paths: Union[str, List[str]], filename_matcher: Optional[FilenameMatcher] = None, + python_version: int = 3, ) -> None: self.paths: List[str] = [] self.transforms: List[Transform] = [] self.processors: List[Processor] = [] self.retcode: Optional[int] = None self.filename_matcher = filename_matcher + self.python_version = python_version self.exceptions: List[BowlerException] = [] for path in paths: @@ -995,6 +996,8 @@ def processor(filename: Filename, hunk: Hunk) -> bool: kwargs["hunk_processor"] = processor kwargs.setdefault("filename_matcher", self.filename_matcher) + if self.python_version == 3: + kwargs.setdefault("options", {})["print_function"] = True tool = BowlerTool(fixers, **kwargs) self.retcode = tool.run(self.paths) self.exceptions = tool.exceptions diff --git a/bowler/tests/lib.py b/bowler/tests/lib.py index 490bff7..9a99889 100644 --- a/bowler/tests/lib.py +++ b/bowler/tests/lib.py @@ -8,9 +8,7 @@ import functools import multiprocessing import sys -import tempfile import unittest -from contextlib import contextmanager from io import StringIO import click diff --git a/bowler/tests/query.py b/bowler/tests/query.py index a87dc30..8e20c55 100644 --- a/bowler/tests/query.py +++ b/bowler/tests/query.py @@ -8,7 +8,7 @@ from unittest import mock from ..query import SELECTORS, Query -from ..types import TOKEN, BowlerException, Leaf +from ..types import TOKEN, Leaf from .lib import BowlerTestCase @@ -48,6 +48,82 @@ def query_func(arg): query_func=query_func, ) + def test_parse_print_func_py3(self): + # Py 3 mode is the default + def select_print_func(arg): + return Query(arg).select_var("bar").rename("baz") + + template = """{} = 1; {}""" + self.run_bowler_modifiers( + [ + ( + # ParseError prevents rename succeeding + template.format("bar", 'print "hello world"'), + template.format("bar", 'print "hello world"'), + ), + ( + template.format("bar", 'print("hello world")'), + template.format("baz", 'print("hello world")'), + ), + ( + template.format("bar", 'print("hello world", end="")'), + template.format("baz", 'print("hello world", end="")'), + ), + ], + query_func=select_print_func, + ) + + def test_parse_print_func_py2(self): + def select_print_func(arg): + return Query(arg, python_version=2).select_var("bar").rename("baz") + + template = """{} = 1; {}""" + self.run_bowler_modifiers( + [ + ( + template.format("bar", 'print "hello world"'), + template.format("baz", 'print "hello world"'), + ), + ( + # not a print function call, just parenthesised statement + template.format("bar", 'print("hello world")'), + template.format("baz", 'print("hello world")'), + ), + ( + # ParseError prevents rename succeeding + template.format("bar", 'print("hello world", end="")'), + template.format("bar", 'print("hello world", end="")'), + ), + ], + query_func=select_print_func, + ) + + def test_parse_print_func_py2_future_print(self): + def select_print_func(arg): + return Query(arg, python_version=2).select_var("bar").rename("baz") + + template = """\ +from __future__ import print_function +{} = 1; {}""" + self.run_bowler_modifiers( + [ + ( + # ParseError prevents rename succeeding + template.format("bar", 'print "hello world"'), + template.format("bar", 'print "hello world"'), + ), + ( + template.format("bar", 'print("hello world")'), + template.format("baz", 'print("hello world")'), + ), + ( + template.format("bar", 'print("hello world", end="")'), + template.format("baz", 'print("hello world", end="")'), + ), + ], + query_func=select_print_func, + ) + def test_rename_class(self): self.run_bowler_modifiers( [("class Bar(Foo):\n pass", "class FooBar(Foo):\n pass")], diff --git a/bowler/tool.py b/bowler/tool.py index 021dd5d..52a5f01 100755 --- a/bowler/tool.py +++ b/bowler/tool.py @@ -12,11 +12,12 @@ import sys import time from queue import Empty -from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple +from typing import Any, Iterator, List, Optional, Sequence, Tuple import click +from fissix import pygram from fissix.pgen2.parse import ParseError -from fissix.refactor import RefactoringTool +from fissix.refactor import RefactoringTool, _detect_future_features from moreorless.patch import PatchException, apply_single_file @@ -29,7 +30,6 @@ FilenameMatcher, Fixers, Hunk, - Node, Processor, RetryFile, ) @@ -97,7 +97,6 @@ def __init__( **kwargs, ) -> None: options = kwargs.pop("options", {}) - options["print_function"] = True super().__init__(fixers, *args, options=options, **kwargs) self.queue_count = 0 self.queue = multiprocessing.JoinableQueue() # type: ignore @@ -148,6 +147,9 @@ def processed_file( if hunk: hunks.append([a, b, *hunk]) + original_grammar = self.driver.grammar + if "print_function" in _detect_future_features(new_text): + self.driver.grammar = pygram.python_grammar_no_print_statement try: new_tree = self.driver.parse_string(new_text) if new_tree is None: @@ -158,6 +160,8 @@ def processed_file( filename=filename, hunks=hunks, ) from e + finally: + self.driver.grammar = original_grammar return hunks diff --git a/docs/api-query.md b/docs/api-query.md index b36dae5..9a38ff1 100644 --- a/docs/api-query.md +++ b/docs/api-query.md @@ -45,7 +45,11 @@ clarity and brevity. Create a new query object to process the given set of files or directories. ```python -Query(*paths: Union[str, List[str]], filename_matcher: FilenameMatcher) +Query( + *paths: Union[str, List[str]], + python_version: int, + filename_matcher: FilenameMatcher, +) ``` * `*paths` - Accepts either individual file or directory paths (relative to the current @@ -56,6 +60,11 @@ Query(*paths: Union[str, List[str]], filename_matcher: FilenameMatcher) eligible for refactoring. Defaults to only matching files that end with `.py`. +* `python_version` - The 'major' python version of the files to be refactored, i.e. `2` + or `3`. This allows the parser to handle `print` statement vs function correctly. This + includes detecting use of `from __future__ import print_function` when + `python_version=2`. Default is `3`. + ### `.select()`