Skip to content

Commit

Permalink
Typing for ./pylint: typing for ./pylint/testutils
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielNoord committed Sep 2, 2021
1 parent 62db444 commit 4bcff3e
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 50 deletions.
4 changes: 2 additions & 2 deletions pylint/reporters/base_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import os
import sys
from typing import List
from typing import List, Union

from pylint.message import Message

Expand All @@ -22,7 +22,7 @@ def __init__(self, output=None):
self.out = None
self.out_encoding = None
self.set_output(output)
self.messages: List[Message] = []
self.messages: List[Union[str, Message]] = []
# Build the path prefix to strip to get relative paths
self.path_strip_prefix = os.getcwd() + os.sep

Expand Down
15 changes: 9 additions & 6 deletions pylint/testutils/checker_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,37 @@
# For details: https://github.com/PyCQA/pylint/blob/main/LICENSE

import contextlib
from typing import Dict, Optional, Type
from typing import Dict, Iterator, Type

from astroid.nodes.scoped_nodes import Module

from pylint.testutils.global_test_linter import linter
from pylint.testutils.output_line import Message
from pylint.testutils.unittest_linter import UnittestLinter
from pylint.utils import ASTWalker


class CheckerTestCase:
"""A base testcase class for unit testing individual checker classes."""

CHECKER_CLASS: Optional[Type] = None
CHECKER_CLASS: Type
CONFIG: Dict = {}

def setup_method(self):
def setup_method(self) -> None:
self.linter = UnittestLinter()
self.checker = self.CHECKER_CLASS(self.linter) # pylint: disable=not-callable
for key, value in self.CONFIG.items():
setattr(self.checker.config, key, value)
self.checker.open()

@contextlib.contextmanager
def assertNoMessages(self):
def assertNoMessages(self) -> Iterator:
"""Assert that no messages are added by the given method."""
with self.assertAddsMessages():
yield

@contextlib.contextmanager
def assertAddsMessages(self, *messages):
def assertAddsMessages(self, *messages: Message) -> Iterator:
"""Assert that exactly the given method adds the given messages.
The list of messages must exactly match *all* the messages added by the
Expand All @@ -47,7 +50,7 @@ def assertAddsMessages(self, *messages):
)
assert got == list(messages), msg

def walk(self, node):
def walk(self, node: Module) -> None:
"""recursive walk on the given node"""
walker = ASTWalker(linter)
walker.add_checker(self.checker)
Expand Down
3 changes: 2 additions & 1 deletion pylint/testutils/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# For details: https://github.com/PyCQA/pylint/blob/main/LICENSE

import functools
from typing import Any

from pylint.testutils.checker_test_case import CheckerTestCase

Expand All @@ -15,7 +16,7 @@ def set_config(**kwargs):

def _wrapper(fun):
@functools.wraps(fun)
def _forward(self, *args, **test_function_kwargs):
def _forward(self: Any, *args: Any, **test_function_kwargs: Any) -> None:
for key, value in kwargs.items():
setattr(self.checker.config, key, value)
if isinstance(self, CheckerTestCase):
Expand Down
41 changes: 23 additions & 18 deletions pylint/testutils/lint_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import platform
import sys
from collections import Counter
from io import StringIO
from typing import Dict, List, Optional, Tuple
from io import StringIO, TextIOWrapper
from typing import Dict, List, Optional, TextIO, Tuple, Union

import pytest
from _pytest.config import Config
Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(self, test_file: FunctionalTestFile, config: Optional[Config] = Non
self._test_file = test_file
self._config = config

def setUp(self):
def setUp(self) -> None:
if self._should_be_skipped_due_to_version():
pytest.skip(
f"Test cannot run with Python {sys.version.split(' ', maxsplit=1)[0]}."
Expand All @@ -75,10 +75,10 @@ def setUp(self):
if sys.platform.lower() in platforms:
pytest.skip(f"Test cannot run on platform {sys.platform!r}")

def runTest(self):
def runTest(self) -> None:
self._runTest()

def _should_be_skipped_due_to_version(self):
def _should_be_skipped_due_to_version(self) -> bool:
return (
sys.version_info < self._test_file.options["min_pyver"]
or sys.version_info > self._test_file.options["max_pyver"]
Expand All @@ -88,26 +88,26 @@ def __str__(self):
return f"{self._test_file.base} ({self.__class__.__module__}.{self.__class__.__name__})"

@staticmethod
def get_expected_messages(stream):
def get_expected_messages(stream: TextIO) -> Counter:
"""Parses a file and get expected messages.
:param stream: File-like input stream.
:type stream: enumerable
:returns: A dict mapping line,msg-symbol tuples to the count on this line.
:rtype: dict
"""
messages = Counter()
messages: Counter = Counter()
for i, line in enumerate(stream):
match = _EXPECTED_RE.search(line)
if match is None:
continue
line = match.group("line")
if line is None:
line = i + 1
line_no = i + 1
elif line.startswith("+") or line.startswith("-"):
line = i + 1 + int(line)
line_no = i + 1 + int(line)
else:
line = int(line)
line_no = int(line)

version = match.group("version")
op = match.group("op")
Expand All @@ -117,7 +117,7 @@ def get_expected_messages(stream):
continue

for msg_id in match.group("msgs").split(","):
messages[line, msg_id.strip()] += 1
messages[line_no, msg_id.strip()] += 1
return messages

@staticmethod
Expand All @@ -138,21 +138,21 @@ def multiset_difference(
return missing, unexpected

# pylint: disable=consider-using-with
def _open_expected_file(self):
def _open_expected_file(self) -> Union[StringIO, TextIOWrapper, TextIO]:
try:
return open(self._test_file.expected_output, encoding="utf-8")
except FileNotFoundError:
return StringIO("")

# pylint: disable=consider-using-with
def _open_source_file(self):
def _open_source_file(self) -> TextIO:
if self._test_file.base == "invalid_encoded_data":
return open(self._test_file.source, encoding="utf-8")
if "latin1" in self._test_file.base:
return open(self._test_file.source, encoding="latin1")
return open(self._test_file.source, encoding="utf8")

def _get_expected(self):
def _get_expected(self) -> Tuple[Counter, List[OutputLine]]:
with self._open_source_file() as f:
expected_msgs = self.get_expected_messages(f)
if not expected_msgs:
Expand All @@ -163,10 +163,10 @@ def _get_expected(self):
]
return expected_msgs, expected_output_lines

def _get_actual(self):
def _get_actual(self) -> Tuple[Counter, List[OutputLine]]:
messages = self._linter.reporter.messages
messages.sort(key=lambda m: (m.line, m.symbol, m.msg))
received_msgs = Counter()
received_msgs: Counter = Counter()
received_output_lines = []
for msg in messages:
assert (
Expand All @@ -176,7 +176,7 @@ def _get_actual(self):
received_output_lines.append(OutputLine.from_msg(msg))
return received_msgs, received_output_lines

def _runTest(self):
def _runTest(self) -> None:
__tracebackhide__ = True # pylint: disable=unused-variable
modules_to_check = [self._test_file.source]
self._linter.check(modules_to_check)
Expand Down Expand Up @@ -231,7 +231,12 @@ def error_msg_for_unequal_output(self, expected_lines, received_lines) -> str:
error_msg += f"{line}\n"
return error_msg

def _check_output_text(self, _, expected_output, actual_output):
def _check_output_text(
self,
_: Counter,
expected_output: List[OutputLine],
actual_output: List[OutputLine],
) -> None:
"""This is a function because we want to be able to update the text in LintModuleOutputUpdate"""
assert expected_output == actual_output, self.error_msg_for_unequal_output(
expected_output, actual_output
Expand Down
38 changes: 25 additions & 13 deletions pylint/testutils/output_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
# For details: https://github.com/PyCQA/pylint/blob/main/LICENSE

import collections
from typing import Any, NamedTuple
from typing import List, NamedTuple, Tuple, Union

from pylint import interfaces
from pylint.constants import PY38_PLUS
from pylint.message.message import Message as pylint_Message
from pylint.testutils.constants import UPDATE_OPTION


Expand All @@ -15,16 +16,20 @@ class Message(
def __new__(cls, msg_id, line=None, node=None, args=None, confidence=None):
return tuple.__new__(cls, (msg_id, line, node, args, confidence))

def __eq__(self, other):
def __eq__(self, other: Union[object, "Message"]) -> bool:
if isinstance(other, Message):
if self.confidence and other.confidence:
return super().__eq__(other)
return self[:-1] == other[:-1]
return self[:-1] == other[:-1] # type: ignore
return NotImplemented # pragma: no cover


class MalformedOutputLineException(Exception):
def __init__(self, row, exception):
def __init__(
self,
row: Union[List[str], Tuple[str, str, str, str, str], str],
exception: ValueError,
) -> None:
example = "msg-symbolic-name:42:27:MyClass.my_function:The message"
other_example = "msg-symbolic-name:7:42::The message"
expected = [
Expand Down Expand Up @@ -58,13 +63,13 @@ def __init__(self, row, exception):
class OutputLine(NamedTuple):
symbol: str
lineno: int
column: int
object: Any
column: str
object: str
msg: str
confidence: str

@classmethod
def from_msg(cls, msg):
def from_msg(cls, msg: pylint_Message) -> "OutputLine":
column = cls.get_column(msg.column)
return cls(
msg.symbol,
Expand All @@ -78,19 +83,26 @@ def from_msg(cls, msg):
)

@classmethod
def get_column(cls, column):
def get_column(cls, column: Union[int, str]) -> str:
if not PY38_PLUS:
return "" # pragma: no cover
return str(column)

@classmethod
def from_csv(cls, row):
def from_csv(
cls, row: Union[List[str], Tuple[str, str, str, str, str], str]
) -> "OutputLine":
try:
confidence = row[5] if len(row) == 6 else interfaces.HIGH.name
if len(row) == 6:
confidence = row[5] # type: ignore # mypy does not recognize that this cannot fail
else:
confidence = interfaces.HIGH.name
column = cls.get_column(row[2])
return cls(row[0], int(row[1]), column, row[3], row[4], confidence)
except Exception as e:
raise MalformedOutputLineException(row, e) from e
if isinstance(e, ValueError):
raise MalformedOutputLineException(row, e) from e
raise e

def to_csv(self):
return tuple(self)
def to_csv(self) -> "OutputLine":
return self
19 changes: 10 additions & 9 deletions pylint/testutils/reporter_for_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,27 @@

from io import StringIO
from os import getcwd, linesep, sep
from typing import Dict, List
from typing import Dict, List, Optional, Union

from pylint import interfaces
from pylint.message import Message
from pylint.reporters import BaseReporter
from pylint.reporters.ureports.nodes import EvaluationSection, Section


class GenericTestReporter(BaseReporter):
"""reporter storing plain text messages"""

__implements__ = interfaces.IReporter

def __init__(self): # pylint: disable=super-init-not-called
def __init__(self) -> None: # pylint: disable=super-init-not-called
self.reset()

def reset(self):
def reset(self) -> None:
self.message_ids: Dict = {}
self.out = StringIO()
self.path_strip_prefix: str = getcwd() + sep
self.messages: List[str] = []
self.messages: List[Union[str, Message]] = []

def handle_message(self, msg: Message) -> None:
"""manage message of different type and in the context of path"""
Expand All @@ -40,7 +41,7 @@ def handle_message(self, msg: Message) -> None:
str_message = str_message.replace("\r\n", "\n")
self.messages.append(f"{sigle}:{line:>3}{obj}: {str_message}")

def finalize(self):
def finalize(self) -> str:
self.messages.sort()
for msg in self.messages:
print(msg, file=self.out)
Expand All @@ -49,26 +50,26 @@ def finalize(self):
return result

# pylint: disable=unused-argument
def on_set_current_module(self, module, filepath):
def on_set_current_module(self, module: str, filepath: Optional[str]) -> None:
pass

# pylint: enable=unused-argument

def display_reports(self, layout):
def display_reports(self, layout: Union[EvaluationSection, Section]) -> None:
"""ignore layouts"""

_display = None


class MinimalTestReporter(BaseReporter):
def on_set_current_module(self, module, filepath):
def on_set_current_module(self, module: str, filepath: str) -> None:
self.messages = []

_display = None


class FunctionalTestReporter(BaseReporter):
def on_set_current_module(self, module, filepath):
def on_set_current_module(self, module: str, filepath: str) -> None:
self.messages = []

def display_reports(self, layout):
Expand Down
4 changes: 3 additions & 1 deletion pylint/testutils/tokenize_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

import tokenize
from io import StringIO
from tokenize import TokenInfo
from typing import List


def _tokenize_str(code):
def _tokenize_str(code: str) -> List[TokenInfo]:
return list(tokenize.generate_tokens(StringIO(code).readline))

0 comments on commit 4bcff3e

Please sign in to comment.