Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for ForwardRef types #15

Merged
merged 5 commits into from
Oct 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 38 additions & 18 deletions dataclass_type_validator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import dataclasses
import typing
import functools
import sys
from typing import Any
from typing import Optional
from typing import Dict

GlobalNS_T = Dict[str, Any]


class TypeValidationError(Exception):
Expand Down Expand Up @@ -40,43 +44,46 @@ def _validate_type(expected_type: type, value: Any) -> Optional[str]:
return f'must be an instance of {expected_type}, but received {type(value)}'


def _validate_iterable_items(expected_type: type, value: Any, strict: bool) -> Optional[str]:
def _validate_iterable_items(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]:
expected_item_type = expected_type.__args__[0]
errors = [_validate_types(expected_type=expected_item_type, value=v, strict=strict) for v in value]
errors = [_validate_types(expected_type=expected_item_type, value=v, strict=strict, globalns=globalns)
for v in value]
errors = [x for x in errors if x]
if len(errors) > 0:
return f'must be an instance of {expected_type}, but there are some errors: {errors}'


def _validate_typing_list(expected_type: type, value: Any, strict: bool) -> Optional[str]:
def _validate_typing_list(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]:
if not isinstance(value, list):
return f'must be an instance of list, but received {type(value)}'
return _validate_iterable_items(expected_type, value, strict)
return _validate_iterable_items(expected_type, value, strict, globalns)


def _validate_typing_tuple(expected_type: type, value: Any, strict: bool) -> Optional[str]:
def _validate_typing_tuple(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]:
if not isinstance(value, tuple):
return f'must be an instance of tuple, but received {type(value)}'
return _validate_iterable_items(expected_type, value, strict)
return _validate_iterable_items(expected_type, value, strict, globalns)


def _validate_typing_frozenset(expected_type: type, value: Any, strict: bool) -> Optional[str]:
def _validate_typing_frozenset(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]:
if not isinstance(value, frozenset):
return f'must be an instance of frozenset, but received {type(value)}'
return _validate_iterable_items(expected_type, value, strict)
return _validate_iterable_items(expected_type, value, strict, globalns)


def _validate_typing_dict(expected_type: type, value: Any, strict: bool) -> Optional[str]:
def _validate_typing_dict(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]:
if not isinstance(value, dict):
return f'must be an instance of dict, but received {type(value)}'

expected_key_type = expected_type.__args__[0]
expected_value_type = expected_type.__args__[1]

key_errors = [_validate_types(expected_type=expected_key_type, value=k, strict=strict) for k in value.keys()]
key_errors = [_validate_types(expected_type=expected_key_type, value=k, strict=strict, globalns=globalns)
for k in value.keys()]
key_errors = [k for k in key_errors if k]

val_errors = [_validate_types(expected_type=expected_value_type, value=v, strict=strict) for v in value.values()]
val_errors = [_validate_types(expected_type=expected_value_type, value=v, strict=strict, globalns=globalns)
for v in value.values()]
val_errors = [v for v in val_errors if v]

if len(key_errors) > 0 and len(val_errors) > 0:
Expand All @@ -88,7 +95,7 @@ def _validate_typing_dict(expected_type: type, value: Any, strict: bool) -> Opti
return f'must be an instance of {expected_type}, but there are some errors in values: {val_errors}'


def _validate_typing_callable(expected_type: type, value: Any, strict: bool) -> Optional[str]:
def _validate_typing_callable(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]:
_ = strict
if not isinstance(value, type(lambda a: a)):
return f'must be an instance of {expected_type._name}, but received {type(value)}'
Expand All @@ -109,16 +116,16 @@ def _validate_typing_literal(expected_type: type, value: Any, strict: bool) -> O
}


def _validate_sequential_types(expected_type: type, value: Any, strict: bool) -> Optional[str]:
def _validate_sequential_types(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]:
validate_func = _validate_typing_mappings.get(expected_type._name)
if validate_func is not None:
return validate_func(expected_type, value, strict)
return validate_func(expected_type, value, strict, globalns)

if str(expected_type).startswith('typing.Literal'):
return _validate_typing_literal(expected_type, value, strict)

if str(expected_type).startswith('typing.Union') or str(expected_type).startswith('typing.Optional'):
is_valid = any(_validate_types(expected_type=t, value=value, strict=strict) is None
is_valid = any(_validate_types(expected_type=t, value=value, strict=strict, globalns=globalns) is None
for t in expected_type.__args__)
if not is_valid:
return f'must be an instance of {expected_type}, but received {value}'
Expand All @@ -128,24 +135,37 @@ def _validate_sequential_types(expected_type: type, value: Any, strict: bool) ->
raise RuntimeError(f'Unknown type of {expected_type} (_name = {expected_type._name})')


def _validate_types(expected_type: type, value: Any, strict: bool) -> Optional[str]:
def _validate_types(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]:
if isinstance(expected_type, type):
return _validate_type(expected_type=expected_type, value=value)

if isinstance(expected_type, typing._GenericAlias):
return _validate_sequential_types(expected_type=expected_type, value=value, strict=strict)
return _validate_sequential_types(expected_type=expected_type, value=value,
strict=strict, globalns=globalns)

if isinstance(expected_type, typing.ForwardRef):
referenced_type = _evaluate_forward_reference(expected_type, globalns)
return _validate_type(expected_type=referenced_type, value=value)


def _evaluate_forward_reference(ref_type: typing.ForwardRef, globalns: GlobalNS_T):
""" Support evaluating ForwardRef types on both Python 3.8 and 3.9. """
if sys.version_info < (3, 9):
return ref_type._evaluate(globalns, None)
return ref_type._evaluate(globalns, None, set())


def dataclass_type_validator(target, strict: bool = False):
fields = dataclasses.fields(target)
globalns = sys.modules[target.__module__].__dict__.copy()

errors = {}
for field in fields:
field_name = field.name
expected_type = field.type
value = getattr(target, field_name)

err = _validate_types(expected_type=expected_type, value=value, strict=strict)
err = _validate_types(expected_type=expected_type, value=value, strict=strict, globalns=globalns)
if err is not None:
errors[field_name] = err

Expand Down
28 changes: 28 additions & 0 deletions tests/test_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,34 @@ def test_build_failure(self):
), DataclassTestCallable)


@dataclasses.dataclass(frozen=True)
class DataclassTestForwardRef:
number: 'int'
ref: typing.Optional['DataclassTestForwardRef'] = None

def __post_init__(self):
dataclass_type_validator(self)


class TestTypeValidationForwardRef:
def test_build_success(self):
assert isinstance(DataclassTestForwardRef(
number=1,
ref=None,
), DataclassTestForwardRef)
assert isinstance(DataclassTestForwardRef(
number=1,
ref=DataclassTestForwardRef(2, None)
), DataclassTestForwardRef)

def test_build_failure_on_number(self):
with pytest.raises(TypeValidationError):
assert isinstance(DataclassTestForwardRef(
number=1,
ref='string'
), DataclassTestForwardRef)


@dataclasses.dataclass(frozen=True)
class ChildValue:
child: str
Expand Down