Skip to content

Commit

Permalink
support eventspec/eventchain in var operations (#4038)
Browse files Browse the repository at this point in the history
  • Loading branch information
adhami3310 authored Oct 3, 2024
1 parent ad0827c commit 73e8a4e
Show file tree
Hide file tree
Showing 9 changed files with 310 additions and 90 deletions.
22 changes: 17 additions & 5 deletions reflex/.templates/web/utils/state.js
Original file line number Diff line number Diff line change
Expand Up @@ -544,13 +544,19 @@ export const uploadFiles = async (

/**
* Create an event object.
* @param name The name of the event.
* @param payload The payload of the event.
* @param handler The client handler to process event.
* @param {string} name The name of the event.
* @param {Object.<string, Any>} payload The payload of the event.
* @param {Object.<string, (number|boolean)>} event_actions The actions to take on the event.
* @param {string} handler The client handler to process event.
* @returns The event object.
*/
export const Event = (name, payload = {}, handler = null) => {
return { name, payload, handler };
export const Event = (
name,
payload = {},
event_actions = {},
handler = null
) => {
return { name, payload, handler, event_actions };
};

/**
Expand Down Expand Up @@ -676,6 +682,12 @@ export const useEventLoop = (
if (!(args instanceof Array)) {
args = [args];
}

event_actions = events.reduce(
(acc, e) => ({ ...acc, ...e.event_actions }),
event_actions ?? {}
);

const _e = args.filter((o) => o?.preventDefault !== undefined)[0];

if (event_actions?.preventDefault && _e?.preventDefault) {
Expand Down
4 changes: 3 additions & 1 deletion reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1536,7 +1536,9 @@ async def on_event(self, sid, data):
"""
fields = json.loads(data)
# Get the event.
event = Event(**{k: v for k, v in fields.items() if k != "handler"})
event = Event(
**{k: v for k, v in fields.items() if k not in ("handler", "event_actions")}
)

self.token_to_sid[event.token] = sid
self.sid_to_token[sid] = event.token
Expand Down
50 changes: 32 additions & 18 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@
)
from reflex.event import (
EventChain,
EventChainVar,
EventHandler,
EventSpec,
EventVar,
call_event_fn,
call_event_handler,
get_handler_args,
Expand Down Expand Up @@ -514,7 +516,7 @@ def _create_event_chain(
Var,
EventHandler,
EventSpec,
List[Union[EventHandler, EventSpec]],
List[Union[EventHandler, EventSpec, EventVar]],
Callable,
],
) -> Union[EventChain, Var]:
Expand All @@ -532,11 +534,16 @@ def _create_event_chain(
"""
# If it's an event chain var, return it.
if isinstance(value, Var):
if value._var_type is not EventChain:
if isinstance(value, EventChainVar):
return value
elif isinstance(value, EventVar):
value = [value]
elif issubclass(value._var_type, (EventChain, EventSpec)):
return self._create_event_chain(args_spec, value.guess_type())
else:
raise ValueError(
f"Invalid event chain: {repr(value)} of type {type(value)}"
f"Invalid event chain: {str(value)} of type {value._var_type}"
)
return value
elif isinstance(value, EventChain):
# Trust that the caller knows what they're doing passing an EventChain directly
return value
Expand All @@ -547,7 +554,7 @@ def _create_event_chain(

# If the input is a list of event handlers, create an event chain.
if isinstance(value, List):
events: list[EventSpec] = []
events: List[Union[EventSpec, EventVar]] = []
for v in value:
if isinstance(v, (EventHandler, EventSpec)):
# Call the event handler to get the event.
Expand All @@ -561,6 +568,8 @@ def _create_event_chain(
"lambda inside an EventChain list."
)
events.extend(result)
elif isinstance(v, EventVar):
events.append(v)
else:
raise ValueError(f"Invalid event: {v}")

Expand All @@ -570,32 +579,30 @@ def _create_event_chain(
if isinstance(result, Var):
# Recursively call this function if the lambda returned an EventChain Var.
return self._create_event_chain(args_spec, result)
events = result
events = [*result]

# Otherwise, raise an error.
else:
raise ValueError(f"Invalid event chain: {value}")

# Add args to the event specs if necessary.
events = [e.with_args(get_handler_args(e)) for e in events]

# Collect event_actions from each spec
event_actions = {}
for e in events:
event_actions.update(e.event_actions)
events = [
(e.with_args(get_handler_args(e)) if isinstance(e, EventSpec) else e)
for e in events
]

# Return the event chain.
if isinstance(args_spec, Var):
return EventChain(
events=events,
args_spec=None,
event_actions=event_actions,
event_actions={},
)
else:
return EventChain(
events=events,
args_spec=args_spec,
event_actions=event_actions,
event_actions={},
)

def get_event_triggers(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -1030,8 +1037,11 @@ def _get_vars_from_event_triggers(
elif isinstance(event, EventChain):
event_args = []
for spec in event.events:
for args in spec.args:
event_args.extend(args)
if isinstance(spec, EventSpec):
for args in spec.args:
event_args.extend(args)
else:
event_args.append(spec)
yield event_trigger, event_args

def _get_vars(self, include_children: bool = False) -> list[Var]:
Expand Down Expand Up @@ -1105,8 +1115,12 @@ def _event_trigger_values_use_state(self) -> bool:
for trigger in self.event_triggers.values():
if isinstance(trigger, EventChain):
for event in trigger.events:
if event.handler.state_full_name:
return True
if isinstance(event, EventSpec):
if event.handler.state_full_name:
return True
else:
if event._var_state:
return True
elif isinstance(trigger, Var) and trigger._var_state:
return True
return False
Expand Down
Loading

0 comments on commit 73e8a4e

Please sign in to comment.