Skip to content

Commit

Permalink
implement _evaluate in state (#4018)
Browse files Browse the repository at this point in the history
* implement _evaluate in state

* add warning

* use typing_extension

* add integration test
  • Loading branch information
adhami3310 authored Sep 27, 2024
1 parent 1b3422d commit 62021b0
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 2 deletions.
32 changes: 32 additions & 0 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

import dill
from sqlalchemy.orm import DeclarativeBase
from typing_extensions import Self

from reflex.config import get_config
from reflex.vars.base import (
Expand All @@ -43,6 +44,7 @@
Var,
computed_var,
dispatch,
get_unique_variable_name,
is_computed_var,
)

Expand Down Expand Up @@ -695,6 +697,36 @@ def _item_is_event_handler(name: str, value: Any) -> bool:
and hasattr(value, "__code__")
)

@classmethod
def _evaluate(cls, f: Callable[[Self], Any]) -> Var:
"""Evaluate a function to a ComputedVar. Experimental.
Args:
f: The function to evaluate.
Returns:
The ComputedVar.
"""
console.warn(
"The _evaluate method is experimental and may be removed in future versions."
)
from reflex.components.base.fragment import fragment
from reflex.components.component import Component

unique_var_name = get_unique_variable_name()

@computed_var(_js_expr=unique_var_name, return_type=Component)
def computed_var_func(state: Self):
return fragment(f(state))

setattr(cls, unique_var_name, computed_var_func)
cls.computed_vars[unique_var_name] = computed_var_func
cls.vars[unique_var_name] = computed_var_func
cls._update_substate_inherited_vars({unique_var_name: computed_var_func})
cls._always_dirty_computed_vars.add(unique_var_name)

return getattr(cls, unique_var_name)

@classmethod
def _mixins(cls) -> List[Type]:
"""Get the mixin classes of the state.
Expand Down
5 changes: 3 additions & 2 deletions reflex/vars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1559,8 +1559,9 @@ def __init__(
Raises:
TypeError: If the computed var dependencies are not Var instances or var names.
"""
hints = get_type_hints(fget)
hint = hints.get("return", Any)
hint = kwargs.pop("return_type", None) or get_type_hints(fget).get(
"return", Any
)

kwargs["_js_expr"] = kwargs.pop("_js_expr", fget.__name__)
kwargs["_var_type"] = kwargs.pop("_var_type", hint)
Expand Down
15 changes: 15 additions & 0 deletions tests/integration/test_dynamic_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def DynamicComponents():
import reflex as rx

class DynamicComponentsState(rx.State):
value: int = 10

button: rx.Component = rx.button(
"Click me",
custom_attrs={
Expand Down Expand Up @@ -52,11 +54,20 @@ def client_token_component(self) -> rx.Component:

app = rx.App()

def factorial(n: int) -> int:
if n == 0:
return 1
return n * factorial(n - 1)

@app.add_page
def index():
return rx.vstack(
DynamicComponentsState.client_token_component,
DynamicComponentsState.button,
rx.text(
DynamicComponentsState._evaluate(lambda state: factorial(state.value)),
id="factorial",
),
)


Expand Down Expand Up @@ -150,3 +161,7 @@ def test_dynamic_components(driver, dynamic_components: AppHarness):
dynamic_components.poll_for_content(button, exp_not_equal="Click me")
== "Clicked"
)

factorial = poll_for_result(lambda: driver.find_element(By.ID, "factorial"))
assert factorial
assert factorial.text == "3628800"

0 comments on commit 62021b0

Please sign in to comment.