Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add typed dict type checking #4340

Merged
merged 3 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,23 @@ def evaluate_style_namespaces(style: ComponentStyle) -> dict:
ComponentChild = Union[types.PrimitiveType, Var, BaseComponent]


def satisfies_type_hint(obj: Any, type_hint: Any) -> bool:
"""Check if an object satisfies a type hint.

Args:
obj: The object to check.
type_hint: The type hint to check against.

Returns:
Whether the object satisfies the type hint.
"""
if isinstance(obj, LiteralVar):
return types._isinstance(obj._var_value, type_hint)
if isinstance(obj, Var):
return types._issubclass(obj._var_type, type_hint)
return types._isinstance(obj, type_hint)


class Component(BaseComponent, ABC):
"""A component with style, event trigger and other props."""

Expand Down Expand Up @@ -460,8 +477,7 @@ def __init__(self, *args, **kwargs):
)
) or (
# Else just check if the passed var type is valid.
not passed_types
and not types._issubclass(passed_type, expected_type, value)
not passed_types and not satisfies_type_hint(value, expected_type)
):
value_name = value._js_expr if isinstance(value, Var) else value

Expand Down
49 changes: 48 additions & 1 deletion reflex/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
Callable,
ClassVar,
Dict,
FrozenSet,
Iterable,
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Expand All @@ -29,6 +31,7 @@
from typing import get_origin as get_origin_og

import sqlalchemy
from typing_extensions import is_typeddict

import reflex
from reflex.components.core.breakpoints import Breakpoints
Expand Down Expand Up @@ -494,6 +497,14 @@ def _issubclass(cls: GenericType, cls_check: GenericType, instance: Any = None)
if isinstance(instance, Breakpoints):
return _breakpoints_satisfies_typing(cls_check, instance)

if isinstance(cls_check_base, tuple):
cls_check_base = tuple(
cls_check_one if not is_typeddict(cls_check_one) else dict
for cls_check_one in cls_check_base
)
if is_typeddict(cls_check_base):
cls_check_base = dict

# Check if the types match.
try:
return cls_check_base == Any or issubclass(cls_base, cls_check_base)
Expand All @@ -503,6 +514,36 @@ def _issubclass(cls: GenericType, cls_check: GenericType, instance: Any = None)
raise TypeError(f"Invalid type for issubclass: {cls_base}") from te


def does_obj_satisfy_typed_dict(obj: Any, cls: GenericType) -> bool:
"""Check if an object satisfies a typed dict.

Args:
obj: The object to check.
cls: The typed dict to check against.

Returns:
Whether the object satisfies the typed dict.
"""
if not isinstance(obj, Mapping):
return False

key_names_to_values = get_type_hints(cls)
required_keys: FrozenSet[str] = getattr(cls, "__required_keys__", frozenset())

if not all(
isinstance(key, str)
and key in key_names_to_values
and _isinstance(value, key_names_to_values[key])
for key, value in obj.items()
):
return False

# TODO in 3.14: Implement https://peps.python.org/pep-0728/ if it's approved

# required keys are all present
return required_keys.issubset(required_keys)


def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool:
"""Check if an object is an instance of a class.

Expand All @@ -529,6 +570,12 @@ def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool:
origin = get_origin(cls)

if origin is None:
# cls is a typed dict
if is_typeddict(cls):
if nested:
return does_obj_satisfy_typed_dict(obj, cls)
return isinstance(obj, dict)

# cls is a simple class
return isinstance(obj, cls)

Expand All @@ -553,7 +600,7 @@ def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool:
and len(obj) == len(args)
and all(_isinstance(item, arg) for item, arg in zip(obj, args))
)
if origin is dict:
if origin in (dict, Breakpoints):
return isinstance(obj, dict) and all(
_isinstance(key, args[0]) and _isinstance(value, args[1])
for key, value in obj.items()
Expand Down
Loading