Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def read(fname):
],
python_requires='~=3.4',
install_requires=[
'filelock>=3.0',
'pytest>=2.8,<4.7; python_version<"3.5"',
'pytest>=2.8; python_version>="3.5"',
'mypy>=0.500,<0.700; python_version<"3.5"',
Expand Down
162 changes: 118 additions & 44 deletions src/pytest_mypy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Mypy static type checker plugin for Pytest"""

import json
import os
from tempfile import NamedTemporaryFile

from filelock import FileLock
import pytest
import mypy.api

Expand All @@ -20,11 +23,44 @@ def pytest_addoption(parser):
help="suppresses error messages about imports that cannot be resolved")


def _is_master(config):
"""
True if the code running the given pytest.config object is running in
an xdist master node or not running xdist at all.
"""
return not hasattr(config, 'slaveinput')


def pytest_configure(config):
"""
Register a custom marker for MypyItems,
Initialize the path used to cache mypy results,
register a custom marker for MypyItems,
and configure the plugin based on the CLI.
"""
if _is_master(config):

# Get the path to a temporary file and delete it.
# The first MypyItem to run will see the file does not exist,
# and it will run and parse mypy results to create it.
# Subsequent MypyItems will see the file exists,
# and they will read the parsed results.
with NamedTemporaryFile(delete=True) as tmp_f:
config._mypy_results_path = tmp_f.name

# If xdist is enabled, then the results path should be exposed to
# the slaves so that they know where to read parsed results from.
if config.pluginmanager.getplugin('xdist'):
class _MypyXdistPlugin:
def pytest_configure_node(self, node): # xdist hook
"""Pass config._mypy_results_path to workers."""
node.slaveinput['_mypy_results_path'] = \
node.config._mypy_results_path
config.pluginmanager.register(_MypyXdistPlugin())

# pytest_terminal_summary cannot accept config before pytest 4.2.
global _pytest_terminal_summary_config
_pytest_terminal_summary_config = config

config.addinivalue_line(
'markers',
'{marker}: mark tests to be checked by mypy.'.format(
Expand All @@ -45,46 +81,6 @@ def pytest_collect_file(path, parent):
return None


def pytest_runtestloop(session):
"""Run mypy on collected MypyItems, then sort the output."""
mypy_items = {
os.path.abspath(str(item.fspath)): item
for item in session.items
if isinstance(item, MypyItem)
}
if mypy_items:

terminal = session.config.pluginmanager.getplugin('terminalreporter')
terminal.write(
'\nRunning {command} on {file_count} files... '.format(
command=' '.join(['mypy'] + mypy_argv),
file_count=len(mypy_items),
),
)
stdout, stderr, status = mypy.api.run(
mypy_argv + [str(item.fspath) for item in mypy_items.values()],
)
terminal.write('done with status {status}\n'.format(status=status))

unmatched_lines = []
for line in stdout.split('\n'):
if not line:
continue
mypy_path, _, error = line.partition(':')
try:
item = mypy_items[os.path.abspath(mypy_path)]
except KeyError:
unmatched_lines.append(line)
else:
item.mypy_errors.append(error)
if any(unmatched_lines):
color = {"red": True} if status != 0 else {"green": True}
terminal.write_line('\n'.join(unmatched_lines), **color)

if stderr:
terminal.write_line(stderr, red=True)


class MypyItem(pytest.Item, pytest.File):

"""A File that Mypy Runs On."""
Expand All @@ -94,12 +90,28 @@ class MypyItem(pytest.Item, pytest.File):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.add_marker(self.MARKER)
self.mypy_errors = []

def runtest(self):
"""Raise an exception if mypy found errors for this item."""
if self.mypy_errors:
raise MypyError('\n'.join(self.mypy_errors))
results = _cached_json_results(
results_path=(
self.config._mypy_results_path
if _is_master(self.config) else
self.config.slaveinput['_mypy_results_path']
),
results_factory=lambda:
_mypy_results_factory(
abspaths=[
os.path.abspath(str(item.fspath))
for item in self.session.items
if isinstance(item, MypyItem)
],
)
)
abspath = os.path.abspath(str(self.fspath))
errors = results['abspath_errors'].get(abspath)
if errors:
raise MypyError('\n'.join(errors))

def reportinfo(self):
"""Produce a heading for the test report."""
Expand All @@ -119,8 +131,70 @@ def repr_failure(self, excinfo):
return super().repr_failure(excinfo)


def _cached_json_results(results_path, results_factory=None):
"""
Read results from results_path if it exists;
otherwise, produce them with results_factory,
and write them to results_path.
"""
with FileLock(results_path + '.lock'):
try:
with open(results_path, mode='r') as results_f:
results = json.load(results_f)
except FileNotFoundError:
if not results_factory:
raise
results = results_factory()
with open(results_path, mode='w') as results_f:
json.dump(results, results_f)
return results


def _mypy_results_factory(abspaths):
"""Run mypy on abspaths and return the results as a JSON-able dict."""

stdout, stderr, status = mypy.api.run(mypy_argv + abspaths)

abspath_errors, unmatched_lines = {}, []
for line in stdout.split('\n'):
if not line:
continue
path, _, error = line.partition(':')
abspath = os.path.abspath(path)
if abspath in abspaths:
abspath_errors[abspath] = abspath_errors.get(abspath, []) + [error]
else:
unmatched_lines.append(line)

return {
'stdout': stdout,
'stderr': stderr,
'status': status,
'abspath_errors': abspath_errors,
'unmatched_stdout': '\n'.join(unmatched_lines),
}


class MypyError(Exception):
"""
An error caught by mypy, e.g a type checker violation
or a syntax error.
"""


def pytest_terminal_summary(terminalreporter):
"""Report stderr and unrecognized lines from stdout."""
config = _pytest_terminal_summary_config
try:
results = _cached_json_results(config._mypy_results_path)
except FileNotFoundError:
# No MypyItems executed.
return
if results['unmatched_stdout'] or results['stderr']:
terminalreporter.section('mypy')
if results['unmatched_stdout']:
color = {'red': True} if results['status'] else {'green': True}
terminalreporter.write_line(results['unmatched_stdout'], **color)
if results['stderr']:
terminalreporter.write_line(results['stderr'], yellow=True)
os.remove(config._mypy_results_path)
Loading