Skip to content

Commit

Permalink
Add plum.overload (#93)
Browse files Browse the repository at this point in the history
* Add `plum.overload` with tests and docs

* Produce nice error if `typing_extensions` is not available

* Incoporate simplifications suggested by @wch

* Update plum/overload.py

Co-authored-by: Gabriel de Marmiesse <gabrieldemarmiesse@gmail.com>

* Fix type stub

* Update plum/overload.py

Co-authored-by: Winston Chang <winston@stdout.org>

* Import `Any`

* Add ability to assert `mypy` errors

* Delete unnecessary `overload.pyi`

* Add tests for both mypy and pyright

* Fix CI

* Fix sentence

* Print output if things are OK

* Implement UX suggestion by @gabrieldemarmiesse

* Fix typo

* Fix typo

* Remove empty API section

---------

Co-authored-by: Gabriel de Marmiesse <gabrieldemarmiesse@gmail.com>
Co-authored-by: Winston Chang <winston@stdout.org>
  • Loading branch information
3 people authored Aug 19, 2023
1 parent 9273285 commit a0f59d6
Show file tree
Hide file tree
Showing 15 changed files with 368 additions and 18 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install --upgrade --no-cache-dir -e '.[dev]'
- name: Test linter assertions
run: |
python check_linter_assertions.py tests/typechecked
- name: Test
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
Expand Down
9 changes: 5 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
.PHONY: docmake docopen docinit docremove docupdate install test clean
.PHONY: install test

PACKAGE := plum

install:
pip install -e '.[dev]'

test:
pre-commit run --all-files && sleep 0.2 && \
PRAGMA_VERSION=`python -c "import sys; print('.'.join(map(str, sys.version_info[:2])))"` \
pytest tests -v --cov=$(PACKAGE) --cov-report html:cover --cov-report term-missing
python check_linter_assertions.py tests/typechecked
pre-commit run --all-files
PRAGMA_VERSION=`python -c "import sys; print('.'.join(map(str, sys.version_info[:2])))"` \
pytest tests -v --cov=$(PACKAGE) --cov-report html:cover --cov-report term-missing
256 changes: 256 additions & 0 deletions check_linter_assertions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
import re
import subprocess
import sys
from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, Tuple

FileLineInfo = Dict[Path, Dict[int, List[str]]]
"""type: Type of a nested dictionary that gives for a collection of files line-wise
information, where the information is of the form `list[str]`."""


def next_noncomment_line(index: int, lines: List[str], path: Path) -> int:
"""Starting at `index`, find the next line with code.
Args:
index (int): Index to start at.
lines (list[str]): Source code lines.
path (:class:`pathlib.Path`): Path where the source code is from.
Returns:
int: Index of the next line with code.
"""
i = index + 1 # Start at the next line.
while i < len(lines):
line_content = lines[i].strip()
if line_content and not line_content.startswith("#"):
return i
i += 1
raise RuntimeError(f"{path}:{index}: Cannot match error assertion to code line.")


def parse_assertions(source_dir: Path, linter: str) -> FileLineInfo:
"""Parse all assertions in all Python files in `source_dir` for linter `linter`.
Args:
source_dir (:class:`pathlib.Path`): Source directory.
linter (str): Linter.
Returns:
:obj:`FileLineInfo`: Assertions.
"""
asserted_errors: FileLineInfo = defaultdict(lambda: defaultdict(list))

for path in source_dir.resolve().rglob("*.py"): # Important to `resolve` here!
with open(path, "r") as f:
lines = f.read().splitlines()
for i, line in enumerate(lines):
# Check if the line has an error assertion.
try:
code, comment = line.rsplit("# E:", 1)
except ValueError:
continue

# Find error assertions.
assertions = re.findall(linter + r"\(([^\)]*)\)", comment)

# There is nothing to do if there are no assertions.
if not assertions:
continue

# Find line number of the code that the assertions pertains to. If there is
# no code on the line, find the next non-comment line.
if not code.strip():
i = next_noncomment_line(i, lines, path)

line_number = i + 1 # Line numbers start at one.
asserted_errors[path][line_number].extend(assertions)

return asserted_errors


def parse_mypy_line(line: str) -> Tuple[Path, int, str, str]:
"""Parse a line of the output of `mypy`.
Args:
line (str): Line.
Raises:
ValueError: If the line cannot be parsed.
Returns:
:class:`pathlib.Path`: Path of file.
int: Line number.
str: Kind of message.
str: Message.
"""
path, line_number, status, message = line.split(":", 3)
# Path must be `resolve`d!
return Path(path).resolve(), int(line_number), status, message


def parse_pyright_line(line: str) -> Tuple[Path, int, str, str]:
"""Parse a line of the output of `pyright`.
Args:
line (str): Line.
Raises:
ValueError: If the line cannot be parsed.
Returns:
:class:`pathlib.Path`: Path of file.
int: Line number.
str: Kind of message.
str: Message.
"""
specification, status_message = line.split(" - ", 1)
path, line_number, _ = specification.split(":", 2)
status, message = status_message.split(":", 1)
# Path must be `resolve`d!
return Path(path.strip()).resolve(), int(line_number), status, message


parse_line: Dict[str, Callable[[str], Tuple[Path, int, str, str]]] = {
"mypy": parse_mypy_line,
"pyright": parse_pyright_line,
}
"""dict[str, Callable[[str], tuple[:class:`pathlib.Path`, int, str, str]]]: Map a
linter to a function that parses a line of the output of the linter."""


def parse_output(stdout: str, linter: str) -> FileLineInfo:
"""Parse the whole output of a linter.
Args:
stdout (str): `stdout` of the linter.
linter (str): Name of the linter.
Returns:
:obj:`FileLineInfo`: Linter errors.
"""
errors: FileLineInfo = defaultdict(lambda: defaultdict(list))

for line in stdout.splitlines():
# Parse line in the output of `mypy`. If it cannot be parsed, just skip it.
try:
path, line_number, status, message = parse_line[linter](line)
except ValueError:
continue

# We only need to validate errors.
if status.lower().strip() != "error":
continue

errors[Path(path)][line_number].append(message)

return errors


def run_linter(linter: str) -> str:
"""Run a linter and get the `stdout`.
Args:
linter (str): Name of the linter.
Returns:
str: `stdout`.
"""
p = subprocess.Popen(
[linter, source_dir],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
stdout, stderr = p.communicate()
assert stderr == b"", "`stderr` must be empty."
return stdout.decode()


def get_missed(
errors: FileLineInfo,
assertions: FileLineInfo,
match: Callable[[str, str], bool],
) -> FileLineInfo:
"""Find unasserted errors.
Args:
error (:obj:`FileLineInfo`): Errors.
assertions (:obj:`FileLineInfo`): Assertions.
match (Callable[[str, str], bool]): Function that takes in an error and an
assertion and checks whether the assertion asserts the error.
Returns:
:obj:`FileLineInfo`: Unasserted errors.
"""
missed_errors: FileLineInfo = defaultdict(lambda: defaultdict(list))
for path, path_errors in errors.items():
# If there are no assertions for `path`, report all errors as missing.
if path not in assertions:
for line_number in errors[path]:
missed_errors[path][line_number].extend(errors[path][line_number])
continue
for line_number, line_errors in path_errors.items():
# If there are no assertions for `line_number`, report all errors as
# missing.
if line_number not in assertions[path]:
missed_errors[path][line_number].extend(errors[path][line_number])
continue
# Check every error for the line.
for e in line_errors:
if not any(match(e, a) for a in assertions[path][line_number]):
missed_errors[path][line_number].append(e)
return missed_errors


def check_linter(source_dir: Path, linter: str) -> bool:
"""Run a linter and check if all errors were asserted and all assertions yielded
errors. If not, print an overview of what was missed.
Args:
source_dir (:class:`pathlib.Path`): Source directory.
linter (str): Name of the linter.
Returns:
bool: `True` if nothing was missed, else `False`.
"""
stdout = run_linter(linter)

errors = parse_output(stdout, linter)
assertions = parse_assertions(source_dir, linter)

missed_errors = get_missed(
errors,
assertions,
lambda e, a: a.strip().lower() in e.strip().lower(),
)
missed_assertions = get_missed(
assertions,
errors,
lambda a, e: a.strip().lower() in e.strip().lower(),
)

for path, path_errors in missed_errors.items():
for line_number, line_errors in path_errors.items():
for error in line_errors:
print(f"{linter}:{path}:{line_number}: Error: {error.strip()}")

for path, path_assertions in missed_assertions.items():
for line_number, line_assertions in path_assertions.items():
for assertion in line_assertions:
print(
f"{linter}:{path}:{line_number}: "
f"Missed assertion: {assertion.strip()}"
)

return not (missed_errors or missed_assertions)


if __name__ == "__main__":
source_dir = Path(sys.argv[1]) # Files that must be validated
status = True
status |= check_linter(source_dir, "mypy")
status |= check_linter(source_dir, "pyright")
if status:
print("All OK!")
exit(0 if status else 1)
1 change: 1 addition & 0 deletions docs/_toc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ chapters:
- file: classes
- file: keyword_arguments
- file: comparison
- file: integration
- file: advanced_usage
sections:
- file: conversion_promotion
Expand Down
34 changes: 34 additions & 0 deletions docs/integration.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Integration with Linters and `mypy`

Plum's integration with linters and `mypy` is unfortunately limited.
Properly supporting multiple dispatch in these tools is challenging for a [variety of reasons](https://github.com/python/mypy/issues/11727).
In this section, we collect various patterns in which Plum plays nicely with type checking.

## Overload Support

At the moment, the only know pattern in which Plum produces `mypy`-compliant code uses `typing.overload`.

An example is as follows:

```python
from plum import dispatch, overload


@overload
def f(x: int) -> int:
return x


@overload
def f(x: str) -> str:
return x


@dispatch
def f(x):
pass
```

In the above, for Python versions prior to 3.11, `plum.overload` is `typing_extensions.overload`.
For this pattern to work in Python versions prior to 3.11, you must use `typing_extensions.overload`, not `typing.overload`.
By importing `overload` from `plum`, you will always use the correct `overload`.
1 change: 1 addition & 0 deletions plum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .autoreload import * # noqa: F401, F403
from .dispatcher import * # noqa: F401, F403
from .function import * # noqa: F401, F403
from .overload import overload # noqa: F401
from .parametric import * # noqa: F401, F403
from .promotion import * # noqa: F401, F403
from .resolver import * # noqa: F401, F403
Expand Down
23 changes: 16 additions & 7 deletions plum/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union

from .function import Function
from .overload import get_overloads
from .signature import Signature
from .util import TypeHint, get_class, is_in_class
from .util import Callable, TypeHint, get_class, is_in_class

__all__ = ["Dispatcher", "dispatch", "clear_all_cache"]

T = TypeVar("T", bound=Callable[..., Any])


class Dispatcher:
"""A namespace for functions.
Expand All @@ -20,11 +23,7 @@ def __init__(self):
self.functions: Dict[str, Function] = {}
self.classes: Dict[str, Dict[str, Function]] = {}

def __call__(
self,
method: Optional[Callable] = None,
precedence: int = 0,
) -> Callable:
def __call__(self, method: Optional[T] = None, precedence: int = 0) -> T:
"""Decorator to register for a particular signature.
Args:
Expand All @@ -36,6 +35,16 @@ def __call__(
if method is None:
return lambda m: self(m, precedence=precedence)

# If `method` has overloads, assume that those overloads need to be registered
# and that `method` is not an implementation.
overloads = get_overloads(method)
if overloads:
for overload_method in overloads:
# All `f` returned by `self._add_method` are the same.
f = self._add_method(overload_method, None, precedence=precedence)
# We do not need to register `method`, because it is not an implementation.
return f

# The signature will be automatically derived from `method`, so we can safely
# set the signature argument to `None`.
return self._add_method(method, None, precedence=precedence)
Expand Down
2 changes: 1 addition & 1 deletion plum/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __doc__(self) -> Optional[str]:
"""
try:
self._resolve_pending_registrations()
except NameError:
except NameError: # pragma: specific no cover 3.7 3.8 3.9
# When `staticmethod` is combined with
# `from __future__ import annotations`, in Python 3.10 and higher
# `staticmethod` will attempt to inherit `__doc__` (see
Expand Down
Loading

0 comments on commit a0f59d6

Please sign in to comment.