Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Improve error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
David Robertson committed Aug 15, 2022
1 parent 1db56bc commit a056718
Showing 1 changed file with 65 additions and 30 deletions.
95 changes: 65 additions & 30 deletions scripts-dev/check_pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,34 @@
R = TypeVar("R")


class NonStrictTypeError(Exception):
class ModelCheckerException(Exception):
"""Dummy exception. Allows us to detect unwanted types during a module import."""


class MissingStrictInConstrainedTypeException(ModelCheckerException):
factory_name: str

def __init__(self, factory_name: str):
self.factory_name = factory_name


class FieldHasUnwantedTypeException(ModelCheckerException):
message: str

def __init__(self, message: str):
self.message = message


def make_wrapper(factory: Callable[P, R]) -> Callable[P, R]:
"""We patch `constr` and friends with wrappers that enforce strict=True. """
"""We patch `constr` and friends with wrappers that enforce strict=True."""

@functools.wraps(factory)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# type-ignore: should be redundant once we can use https://github.com/python/mypy/pull/12668
if "strict" not in kwargs: # type: ignore[attr-defined]
raise NonStrictTypeError()
raise MissingStrictInConstrainedTypeException(factory.__name__)
if not kwargs["strict"]: # type: ignore[index]
raise NonStrictTypeError()
raise MissingStrictInConstrainedTypeException(factory.__name__)
return factory(*args, **kwargs)

return wrapper
Expand Down Expand Up @@ -113,18 +127,25 @@ def __init_subclass__(cls: Type[PydanticBaseModel], **kwargs: object):
# Note that field.type_ and field.outer_type are computed based on the
# annotation type, see pydantic.fields.ModelField._type_analysis
if field_type_unwanted(field.outer_type_):
raise NonStrictTypeError()
# TODO: this only reports the first bad field. Can we find all bad ones
# and report them all?
raise FieldHasUnwantedTypeException(
f"{cls.__module__}.{cls.__qualname__} has field '{field.name}' "
f"with unwanted type `{field.outer_type_}`"
)


@contextmanager
def monkeypatch_pydantic() -> Generator[None, None, None]:
"""Patch pydantic with our snooping versions of BaseModel and the con* functions.
Most Synapse code ought to import the patched objects directly from `pydantic`.
But we include their containing models `pydantic.main` and `pydantic.types` for
completeness.
If the snooping functions see something they don't like, they'll raise a
ModelCheckingException instance.
"""
with contextlib.ExitStack() as patches:
# Most Synapse code ought to import the patched objects directly from
# `pydantic`. But we also patch their containing modules `pydantic.main` and
# `pydantic.types` for completeness.
patch_basemodel1 = unittest.mock.patch(
"pydantic.BaseModel", new=PatchedBaseModel
)
Expand All @@ -144,10 +165,20 @@ def monkeypatch_pydantic() -> Generator[None, None, None]:
yield


def format_error(e: NonStrictTypeError) -> str:
def format_model_checker_exception(e: ModelCheckerException) -> str:
"""Work out which line of code caused e. Format the line in a human-friendly way."""
frame_summary = traceback.extract_tb(e.__traceback__)[-2]
return traceback.format_list([frame_summary])[0].lstrip()
# TODO. FieldHasUnwantedTypeException gives better error messages. Can we ditch the
# patches of constr() etc, and instead inspect fields to look for ConstrainedStr
# with strict=False? There is some difficulty with the inheritance hierarchy
# because StrictStr < ConstrainedStr < str.
if isinstance(e, FieldHasUnwantedTypeException):
return e.message
elif isinstance(e, MissingStrictInConstrainedTypeException):
frame_summary = traceback.extract_tb(e.__traceback__)[-2]
return (
f"Missing `strict=True` from {e.factory_name}() call \n"
+ traceback.format_list([frame_summary])[0].lstrip()
)


def lint() -> int:
Expand All @@ -168,26 +199,30 @@ def do_lint() -> Set[str]:

with monkeypatch_pydantic():
try:
synapse = importlib.import_module("synapse")
except NonStrictTypeError as e:
# TODO: make "synapse" an argument so we can target this script at
# a subpackage
module = importlib.import_module("synapse")
except ModelCheckerException as e:
logger.warning("Bad annotation found when importing synapse")
failures.add(format_error(e))
failures.add(format_model_checker_exception(e))
return failures

try:
modules = list(pkgutil.walk_packages(synapse.__path__, "synapse."))
except NonStrictTypeError as e:
modules = list(
pkgutil.walk_packages(module.__path__, f"{module.__name__}.")
)
except ModelCheckerException as e:
logger.warning("Bad annotation found when looking for modules to import")
failures.add(format_error(e))
failures.add(format_model_checker_exception(e))
return failures

for module in modules:
logger.debug("Importing %s", module.name)
try:
importlib.import_module(module.name)
except NonStrictTypeError as e:
except ModelCheckerException as e:
logger.warning(f"Bad annotation found when importing {module.name}")
failures.add(format_error(e))
failures.add(format_model_checker_exception(e))

return failures

Expand All @@ -208,7 +243,7 @@ def run_test_snippet(source: str) -> None:

class TestConstrainedTypesPatch(unittest.TestCase):
def test_expression_without_strict_raises(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
from pydantic import constr
Expand All @@ -217,7 +252,7 @@ def test_expression_without_strict_raises(self) -> None:
)

def test_called_as_module_attribute_raises(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
import pydantic
Expand All @@ -226,7 +261,7 @@ def test_called_as_module_attribute_raises(self) -> None:
)

def test_wildcard_import_raises(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
from pydantic import *
Expand All @@ -235,7 +270,7 @@ def test_wildcard_import_raises(self) -> None:
)

def test_alternative_import_raises(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
from pydantic.types import constr
Expand All @@ -244,7 +279,7 @@ def test_alternative_import_raises(self) -> None:
)

def test_alternative_import_attribute_raises(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
import pydantic.types
Expand All @@ -253,7 +288,7 @@ def test_alternative_import_attribute_raises(self) -> None:
)

def test_kwarg_but_no_strict_raises(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
from pydantic import constr
Expand All @@ -262,7 +297,7 @@ def test_kwarg_but_no_strict_raises(self) -> None:
)

def test_kwarg_strict_False_raises(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
from pydantic import constr
Expand All @@ -280,7 +315,7 @@ def test_kwarg_strict_True_doesnt_raise(self) -> None:
)

def test_annotation_without_strict_raises(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
from pydantic import constr
Expand All @@ -289,7 +324,7 @@ def test_annotation_without_strict_raises(self) -> None:
)

def test_field_annotation_without_strict_raises(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
from pydantic import BaseModel, conint
Expand Down Expand Up @@ -317,7 +352,7 @@ class TestFieldTypeInspection(unittest.TestCase):
]
)
def test_field_holding_unwanted_type_raises(self, annotation: str) -> None:
with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
f"""
from typing import *
Expand Down Expand Up @@ -355,7 +390,7 @@ class C(BaseModel):
)

def test_field_holding_str_raises_with_alternative_import(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
from pydantic.main import BaseModel
Expand Down

0 comments on commit a056718

Please sign in to comment.