-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve Var type handling for better rx.Model attribute access #2010
Merged
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
cd2f7a8
[REF-904] model: dict() traverses into relationship attributes
masenf cc8f9df
types: Support Var with Union type (including Optional)
masenf fb2d912
Var: allow access to annotated attributes of rx.Model
masenf c546525
Improve Var default inference
masenf 47e9284
state: set BaseVar _var_type optional if default is None
masenf 530d997
test_state_union_optional: define Union and Optional vars
masenf 120e5d8
types: Only check for UnionType if it is present
masenf 1c825c9
Merge remote-tracking branch 'origin/main' into masenf/var-type-fixes
masenf fcc12c9
CR feedback: Add/fix comments, rename can_access_attribute
masenf File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,8 +3,22 @@ | |
from __future__ import annotations | ||
|
||
import contextlib | ||
import typing | ||
from typing import Any, Callable, Literal, Type, Union, _GenericAlias # type: ignore | ||
import types | ||
from typing import ( | ||
Any, | ||
Callable, | ||
Iterable, | ||
Literal, | ||
Optional, | ||
Type, | ||
Union, | ||
_GenericAlias, # type: ignore | ||
get_args, | ||
get_origin, | ||
get_type_hints, | ||
) | ||
|
||
from pydantic.fields import ModelField | ||
|
||
from reflex.base import Base | ||
from reflex.utils import serializers | ||
|
@@ -21,18 +35,6 @@ | |
ArgsSpec = Callable | ||
|
||
|
||
def get_args(alias: _GenericAlias) -> tuple[Type, ...]: | ||
"""Get the arguments of a type alias. | ||
|
||
Args: | ||
alias: The type alias. | ||
|
||
Returns: | ||
The arguments of the type alias. | ||
""" | ||
return alias.__args__ | ||
|
||
|
||
def is_generic_alias(cls: GenericType) -> bool: | ||
"""Check whether the class is a generic alias. | ||
|
||
|
@@ -69,11 +71,11 @@ def is_union(cls: GenericType) -> bool: | |
Returns: | ||
Whether the class is a Union. | ||
""" | ||
with contextlib.suppress(ImportError): | ||
from typing import _UnionGenericAlias # type: ignore | ||
# UnionType added in py3.10 | ||
if not hasattr(types, "UnionType"): | ||
return get_origin(cls) is Union | ||
|
||
return isinstance(cls, _UnionGenericAlias) | ||
return cls.__origin__ == Union if is_generic_alias(cls) else False | ||
return get_origin(cls) in [Union, types.UnionType] | ||
|
||
|
||
def is_literal(cls: GenericType) -> bool: | ||
|
@@ -85,7 +87,61 @@ def is_literal(cls: GenericType) -> bool: | |
Returns: | ||
Whether the class is a literal. | ||
""" | ||
return hasattr(cls, "__origin__") and cls.__origin__ is Literal | ||
return get_origin(cls) is Literal | ||
|
||
|
||
def is_optional(cls: GenericType) -> bool: | ||
"""Check if a class is an Optional. | ||
|
||
Args: | ||
cls: The class to check. | ||
|
||
Returns: | ||
Whether the class is an Optional. | ||
""" | ||
return is_union(cls) and type(None) in get_args(cls) | ||
|
||
|
||
def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None: | ||
"""Check if an attribute can be accessed on the cls and return its type. | ||
|
||
Supports pydantic models, unions, and annotated attributes on rx.Model. | ||
|
||
Args: | ||
cls: The class to check. | ||
name: The name of the attribute to check. | ||
|
||
Returns: | ||
The type of the attribute, if accessible, or None | ||
""" | ||
from reflex.model import Model | ||
|
||
if hasattr(cls, "__fields__") and name in cls.__fields__: | ||
# pydantic models | ||
field = cls.__fields__[name] | ||
type_ = field.outer_type_ | ||
if isinstance(type_, ModelField): | ||
type_ = type_.type_ | ||
if not field.required and field.default is None: | ||
# Ensure frontend uses null coalescing when accessing. | ||
type_ = Optional[type_] | ||
return type_ | ||
elif isinstance(cls, type) and issubclass(cls, Model): | ||
# Check in the annotations directly (for sqlmodel.Relationship) | ||
hints = get_type_hints(cls) | ||
if name in hints: | ||
type_ = hints[name] | ||
if isinstance(type_, ModelField): | ||
return type_.type_ | ||
return type_ | ||
elif is_union(cls): | ||
# Check in each arg of the annotation. | ||
for arg in get_args(cls): | ||
type_ = get_attribute_access_type(arg, name) | ||
if type_ is not None: | ||
# Return the first attribute type that is accessible. | ||
return type_ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add |
||
return None # Attribute is not accessible. | ||
|
||
|
||
def get_base_class(cls: GenericType) -> Type: | ||
|
@@ -171,7 +227,7 @@ def is_dataframe(value: Type) -> bool: | |
Returns: | ||
Whether the value is a dataframe. | ||
""" | ||
if is_generic_alias(value) or value == typing.Any: | ||
if is_generic_alias(value) or value == Any: | ||
return False | ||
return value.__name__ == "DataFrame" | ||
|
||
|
@@ -185,6 +241,8 @@ def is_valid_var_type(type_: Type) -> bool: | |
Returns: | ||
Whether the type is a valid prop type. | ||
""" | ||
if is_union(type_): | ||
return all((is_valid_var_type(arg) for arg in get_args(type_))) | ||
return _issubclass(type_, StateVar) or serializers.has_serializer(type_) | ||
|
||
|
||
|
@@ -200,9 +258,7 @@ def is_backend_variable(name: str) -> bool: | |
return name.startswith("_") and not name.startswith("__") | ||
|
||
|
||
def check_type_in_allowed_types( | ||
value_type: Type, allowed_types: typing.Iterable | ||
) -> bool: | ||
def check_type_in_allowed_types(value_type: Type, allowed_types: Iterable) -> bool: | ||
"""Check that a value type is found in a list of allowed types. | ||
|
||
Args: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
import json | ||
import os | ||
import sys | ||
from typing import Dict, Generator, List | ||
from typing import Dict, Generator, List, Optional, Union | ||
from unittest.mock import AsyncMock, Mock | ||
|
||
import pytest | ||
|
@@ -30,7 +30,7 @@ | |
StateProxy, | ||
StateUpdate, | ||
) | ||
from reflex.utils import prerequisites | ||
from reflex.utils import prerequisites, types | ||
from reflex.utils.format import json_dumps | ||
from reflex.vars import BaseVar, ComputedVar | ||
|
||
|
@@ -2239,3 +2239,52 @@ class MutableResetState(State): | |
instance.items.append([3, 3]) | ||
assert instance.items != default | ||
assert instance.items != copied_default | ||
|
||
|
||
class Custom1(Base): | ||
"""A custom class with a str field.""" | ||
|
||
foo: str | ||
|
||
|
||
class Custom2(Base): | ||
"""A custom class with a Custom1 field.""" | ||
|
||
c1: Optional[Custom1] = None | ||
c1r: Custom1 | ||
|
||
|
||
class Custom3(Base): | ||
"""A custom class with a Custom2 field.""" | ||
|
||
c2: Optional[Custom2] = None | ||
c2r: Custom2 | ||
|
||
|
||
def test_state_union_optional(): | ||
"""Test that state can be defined with Union and Optional vars.""" | ||
|
||
class UnionState(State): | ||
int_float: Union[int, float] = 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think other Var operations are prepared to handle bare Union types... |
||
opt_int: Optional[int] | ||
c3: Optional[Custom3] | ||
c3i: Custom3 # implicitly required | ||
c3r: Custom3 = Custom3(c2r=Custom2(c1r=Custom1(foo=""))) | ||
custom_union: Union[Custom1, Custom2, Custom3] = Custom1(foo="") | ||
|
||
assert UnionState.c3.c2._var_name == "c3?.c2" # type: ignore | ||
assert UnionState.c3.c2.c1._var_name == "c3?.c2?.c1" # type: ignore | ||
assert UnionState.c3.c2.c1.foo._var_name == "c3?.c2?.c1?.foo" # type: ignore | ||
assert UnionState.c3.c2.c1r.foo._var_name == "c3?.c2?.c1r.foo" # type: ignore | ||
assert UnionState.c3.c2r.c1._var_name == "c3?.c2r.c1" # type: ignore | ||
assert UnionState.c3.c2r.c1.foo._var_name == "c3?.c2r.c1?.foo" # type: ignore | ||
assert UnionState.c3.c2r.c1r.foo._var_name == "c3?.c2r.c1r.foo" # type: ignore | ||
assert UnionState.c3i.c2._var_name == "c3i.c2" # type: ignore | ||
assert UnionState.c3r.c2._var_name == "c3r.c2" # type: ignore | ||
assert UnionState.custom_union.foo is not None # type: ignore | ||
assert UnionState.custom_union.c1 is not None # type: ignore | ||
assert UnionState.custom_union.c1r is not None # type: ignore | ||
assert UnionState.custom_union.c2 is not None # type: ignore | ||
assert UnionState.custom_union.c2r is not None # type: ignore | ||
assert types.is_optional(UnionState.opt_int._var_type) # type: ignore | ||
assert types.is_union(UnionState.int_float._var_type) # type: ignore |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add some comments explaining these lines