diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index dd405af9..59df9640 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,6 +22,9 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install --upgrade --no-cache-dir -e '.[dev]' + - name: Test linter assertions + run: | + python check_linter_assertions.py tests/typechecked - name: Test env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/Makefile b/Makefile index 0fcaee1c..a4aa001b 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: docmake docopen docinit docremove docupdate install test clean +.PHONY: install test PACKAGE := plum @@ -6,6 +6,7 @@ install: pip install -e '.[dev]' test: - pre-commit run --all-files && sleep 0.2 && \ - PRAGMA_VERSION=`python -c "import sys; print('.'.join(map(str, sys.version_info[:2])))"` \ - pytest tests -v --cov=$(PACKAGE) --cov-report html:cover --cov-report term-missing + python check_linter_assertions.py tests/typechecked + pre-commit run --all-files + PRAGMA_VERSION=`python -c "import sys; print('.'.join(map(str, sys.version_info[:2])))"` \ + pytest tests -v --cov=$(PACKAGE) --cov-report html:cover --cov-report term-missing diff --git a/check_linter_assertions.py b/check_linter_assertions.py new file mode 100644 index 00000000..f563553c --- /dev/null +++ b/check_linter_assertions.py @@ -0,0 +1,256 @@ +import re +import subprocess +import sys +from collections import defaultdict +from pathlib import Path +from typing import Callable, Dict, List, Tuple + +FileLineInfo = Dict[Path, Dict[int, List[str]]] +"""type: Type of a nested dictionary that gives for a collection of files line-wise +information, where the information is of the form `list[str]`.""" + + +def next_noncomment_line(index: int, lines: List[str], path: Path) -> int: + """Starting at `index`, find the next line with code. + + Args: + index (int): Index to start at. + lines (list[str]): Source code lines. + path (:class:`pathlib.Path`): Path where the source code is from. + + Returns: + int: Index of the next line with code. + """ + i = index + 1 # Start at the next line. + while i < len(lines): + line_content = lines[i].strip() + if line_content and not line_content.startswith("#"): + return i + i += 1 + raise RuntimeError(f"{path}:{index}: Cannot match error assertion to code line.") + + +def parse_assertions(source_dir: Path, linter: str) -> FileLineInfo: + """Parse all assertions in all Python files in `source_dir` for linter `linter`. + + Args: + source_dir (:class:`pathlib.Path`): Source directory. + linter (str): Linter. + + Returns: + :obj:`FileLineInfo`: Assertions. + """ + asserted_errors: FileLineInfo = defaultdict(lambda: defaultdict(list)) + + for path in source_dir.resolve().rglob("*.py"): # Important to `resolve` here! + with open(path, "r") as f: + lines = f.read().splitlines() + for i, line in enumerate(lines): + # Check if the line has an error assertion. + try: + code, comment = line.rsplit("# E:", 1) + except ValueError: + continue + + # Find error assertions. + assertions = re.findall(linter + r"\(([^\)]*)\)", comment) + + # There is nothing to do if there are no assertions. + if not assertions: + continue + + # Find line number of the code that the assertions pertains to. If there is + # no code on the line, find the next non-comment line. + if not code.strip(): + i = next_noncomment_line(i, lines, path) + + line_number = i + 1 # Line numbers start at one. + asserted_errors[path][line_number].extend(assertions) + + return asserted_errors + + +def parse_mypy_line(line: str) -> Tuple[Path, int, str, str]: + """Parse a line of the output of `mypy`. + + Args: + line (str): Line. + + Raises: + ValueError: If the line cannot be parsed. + + Returns: + :class:`pathlib.Path`: Path of file. + int: Line number. + str: Kind of message. + str: Message. + """ + path, line_number, status, message = line.split(":", 3) + # Path must be `resolve`d! + return Path(path).resolve(), int(line_number), status, message + + +def parse_pyright_line(line: str) -> Tuple[Path, int, str, str]: + """Parse a line of the output of `pyright`. + + Args: + line (str): Line. + + Raises: + ValueError: If the line cannot be parsed. + + Returns: + :class:`pathlib.Path`: Path of file. + int: Line number. + str: Kind of message. + str: Message. + """ + specification, status_message = line.split(" - ", 1) + path, line_number, _ = specification.split(":", 2) + status, message = status_message.split(":", 1) + # Path must be `resolve`d! + return Path(path.strip()).resolve(), int(line_number), status, message + + +parse_line: Dict[str, Callable[[str], Tuple[Path, int, str, str]]] = { + "mypy": parse_mypy_line, + "pyright": parse_pyright_line, +} +"""dict[str, Callable[[str], tuple[:class:`pathlib.Path`, int, str, str]]]: Map a +linter to a function that parses a line of the output of the linter.""" + + +def parse_output(stdout: str, linter: str) -> FileLineInfo: + """Parse the whole output of a linter. + + Args: + stdout (str): `stdout` of the linter. + linter (str): Name of the linter. + + Returns: + :obj:`FileLineInfo`: Linter errors. + """ + errors: FileLineInfo = defaultdict(lambda: defaultdict(list)) + + for line in stdout.splitlines(): + # Parse line in the output of `mypy`. If it cannot be parsed, just skip it. + try: + path, line_number, status, message = parse_line[linter](line) + except ValueError: + continue + + # We only need to validate errors. + if status.lower().strip() != "error": + continue + + errors[Path(path)][line_number].append(message) + + return errors + + +def run_linter(linter: str) -> str: + """Run a linter and get the `stdout`. + + Args: + linter (str): Name of the linter. + + Returns: + str: `stdout`. + """ + p = subprocess.Popen( + [linter, source_dir], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + stdout, stderr = p.communicate() + assert stderr == b"", "`stderr` must be empty." + return stdout.decode() + + +def get_missed( + errors: FileLineInfo, + assertions: FileLineInfo, + match: Callable[[str, str], bool], +) -> FileLineInfo: + """Find unasserted errors. + + Args: + error (:obj:`FileLineInfo`): Errors. + assertions (:obj:`FileLineInfo`): Assertions. + match (Callable[[str, str], bool]): Function that takes in an error and an + assertion and checks whether the assertion asserts the error. + + Returns: + :obj:`FileLineInfo`: Unasserted errors. + """ + missed_errors: FileLineInfo = defaultdict(lambda: defaultdict(list)) + for path, path_errors in errors.items(): + # If there are no assertions for `path`, report all errors as missing. + if path not in assertions: + for line_number in errors[path]: + missed_errors[path][line_number].extend(errors[path][line_number]) + continue + for line_number, line_errors in path_errors.items(): + # If there are no assertions for `line_number`, report all errors as + # missing. + if line_number not in assertions[path]: + missed_errors[path][line_number].extend(errors[path][line_number]) + continue + # Check every error for the line. + for e in line_errors: + if not any(match(e, a) for a in assertions[path][line_number]): + missed_errors[path][line_number].append(e) + return missed_errors + + +def check_linter(source_dir: Path, linter: str) -> bool: + """Run a linter and check if all errors were asserted and all assertions yielded + errors. If not, print an overview of what was missed. + + Args: + source_dir (:class:`pathlib.Path`): Source directory. + linter (str): Name of the linter. + + Returns: + bool: `True` if nothing was missed, else `False`. + """ + stdout = run_linter(linter) + + errors = parse_output(stdout, linter) + assertions = parse_assertions(source_dir, linter) + + missed_errors = get_missed( + errors, + assertions, + lambda e, a: a.strip().lower() in e.strip().lower(), + ) + missed_assertions = get_missed( + assertions, + errors, + lambda a, e: a.strip().lower() in e.strip().lower(), + ) + + for path, path_errors in missed_errors.items(): + for line_number, line_errors in path_errors.items(): + for error in line_errors: + print(f"{linter}:{path}:{line_number}: Error: {error.strip()}") + + for path, path_assertions in missed_assertions.items(): + for line_number, line_assertions in path_assertions.items(): + for assertion in line_assertions: + print( + f"{linter}:{path}:{line_number}: " + f"Missed assertion: {assertion.strip()}" + ) + + return not (missed_errors or missed_assertions) + + +if __name__ == "__main__": + source_dir = Path(sys.argv[1]) # Files that must be validated + status = True + status |= check_linter(source_dir, "mypy") + status |= check_linter(source_dir, "pyright") + if status: + print("All OK!") + exit(0 if status else 1) diff --git a/docs/_toc.yml b/docs/_toc.yml index d11bb6a8..e0345f47 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -12,6 +12,7 @@ chapters: - file: classes - file: keyword_arguments - file: comparison + - file: integration - file: advanced_usage sections: - file: conversion_promotion diff --git a/docs/integration.md b/docs/integration.md new file mode 100644 index 00000000..93508941 --- /dev/null +++ b/docs/integration.md @@ -0,0 +1,34 @@ +# Integration with Linters and `mypy` + +Plum's integration with linters and `mypy` is unfortunately limited. +Properly supporting multiple dispatch in these tools is challenging for a [variety of reasons](https://github.com/python/mypy/issues/11727). +In this section, we collect various patterns in which Plum plays nicely with type checking. + +## Overload Support + +At the moment, the only know pattern in which Plum produces `mypy`-compliant code uses `typing.overload`. + +An example is as follows: + +```python +from plum import dispatch, overload + + +@overload +def f(x: int) -> int: + return x + + +@overload +def f(x: str) -> str: + return x + + +@dispatch +def f(x): + pass +``` + +In the above, for Python versions prior to 3.11, `plum.overload` is `typing_extensions.overload`. +For this pattern to work in Python versions prior to 3.11, you must use `typing_extensions.overload`, not `typing.overload`. +By importing `overload` from `plum`, you will always use the correct `overload`. diff --git a/plum/__init__.py b/plum/__init__.py index 08b68b14..cf2515ec 100644 --- a/plum/__init__.py +++ b/plum/__init__.py @@ -14,6 +14,7 @@ from .autoreload import * # noqa: F401, F403 from .dispatcher import * # noqa: F401, F403 from .function import * # noqa: F401, F403 +from .overload import overload # noqa: F401 from .parametric import * # noqa: F401, F403 from .promotion import * # noqa: F401, F403 from .resolver import * # noqa: F401, F403 diff --git a/plum/dispatcher.py b/plum/dispatcher.py index ad931c26..bf018c39 100644 --- a/plum/dispatcher.py +++ b/plum/dispatcher.py @@ -1,11 +1,14 @@ -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union from .function import Function +from .overload import get_overloads from .signature import Signature -from .util import TypeHint, get_class, is_in_class +from .util import Callable, TypeHint, get_class, is_in_class __all__ = ["Dispatcher", "dispatch", "clear_all_cache"] +T = TypeVar("T", bound=Callable[..., Any]) + class Dispatcher: """A namespace for functions. @@ -20,11 +23,7 @@ def __init__(self): self.functions: Dict[str, Function] = {} self.classes: Dict[str, Dict[str, Function]] = {} - def __call__( - self, - method: Optional[Callable] = None, - precedence: int = 0, - ) -> Callable: + def __call__(self, method: Optional[T] = None, precedence: int = 0) -> T: """Decorator to register for a particular signature. Args: @@ -36,6 +35,16 @@ def __call__( if method is None: return lambda m: self(m, precedence=precedence) + # If `method` has overloads, assume that those overloads need to be registered + # and that `method` is not an implementation. + overloads = get_overloads(method) + if overloads: + for overload_method in overloads: + # All `f` returned by `self._add_method` are the same. + f = self._add_method(overload_method, None, precedence=precedence) + # We do not need to register `method`, because it is not an implementation. + return f + # The signature will be automatically derived from `method`, so we can safely # set the signature argument to `None`. return self._add_method(method, None, precedence=precedence) diff --git a/plum/function.py b/plum/function.py index 20e6b36d..7b97c006 100644 --- a/plum/function.py +++ b/plum/function.py @@ -127,7 +127,7 @@ def __doc__(self) -> Optional[str]: """ try: self._resolve_pending_registrations() - except NameError: + except NameError: # pragma: specific no cover 3.7 3.8 3.9 # When `staticmethod` is combined with # `from __future__ import annotations`, in Python 3.10 and higher # `staticmethod` will attempt to inherit `__doc__` (see diff --git a/plum/overload.py b/plum/overload.py new file mode 100644 index 00000000..ba267b46 --- /dev/null +++ b/plum/overload.py @@ -0,0 +1,9 @@ +import sys + +if sys.version_info >= (3, 11): # pragma: specific no cover 3.7 3.8 3.9 3.10 + from typing import get_overloads, overload +else: # pragma: specific no cover 3.11 + from typing_extensions import get_overloads, overload + + +__all__ = ["overload", "get_overloads"] diff --git a/plum/py.typed b/plum/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/plum/util.py b/plum/util.py index 969a7447..fbce29cf 100644 --- a/plum/util.py +++ b/plum/util.py @@ -3,7 +3,13 @@ import typing from typing import Callable, List +if sys.version_info.minor <= 8: # pragma: specific no cover 3.9 3.10 3.11 + from typing import Callable +else: # pragma: specific no cover 3.8 + from collections.abc import Callable + __all__ = [ + "Callable", "TypeHint", "repr_short", "Missing", diff --git a/pyproject.toml b/pyproject.toml index f24bc57d..bd4eec59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,10 @@ classifiers = [ dynamic = ["version"] requires-python = ">=3.8" -dependencies = ["beartype"] +dependencies = [ + "beartype", + "typing-extensions; python_version<='3.10'", +] [project.optional-dependencies] dev = [ @@ -34,6 +37,8 @@ dev = [ "build", "tox", "jupyter-book", + "mypy", + "pyright", ] [project.urls] diff --git a/tests/test_type.py b/tests/test_type.py index a4d7ba18..1396eb77 100644 --- a/tests/test_type.py +++ b/tests/test_type.py @@ -2,11 +2,6 @@ import sys import typing -if sys.version_info.minor <= 8: - from typing import Callable -else: - from collections.abc import Callable - try: from typing import Literal except ImportError: @@ -32,6 +27,7 @@ class Literal(metaclass=LiteralMeta): resolve_type_hint, type_mapping, ) +from plum.util import Callable def test_resolvabletype(): diff --git a/tests/typechecked/__init__.py b/tests/typechecked/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/typechecked/test_overload.py b/tests/typechecked/test_overload.py new file mode 100644 index 00000000..47f4795b --- /dev/null +++ b/tests/typechecked/test_overload.py @@ -0,0 +1,29 @@ +import pytest + +from plum import Dispatcher, NotFoundLookupError, overload + +dispatch = Dispatcher() + + +@overload +def f(x: int) -> int: # E: pyright(marked as overload) + return x + + +@overload +def f(x: str) -> str: # E: pyright(marked as overload) + return x + + +@dispatch +def f(x): # E: pyright(overloaded implementation is not consistent) + pass + + +def test_overload() -> None: + assert f(1) == 1 + assert f("1") == "1" + with pytest.raises(NotFoundLookupError): + # E: pyright(argument of type "float" cannot be assigned to parameter "x") + # E: mypy(no overload variant of "f" matches argument type "float") + f(1.0)