Skip to content

Commit ab2274c

Browse files
joaomdmouradevin-ai-integration[bot]Joe Moura
authored
Stateful flows (#1931)
* fix: ensure persisted state overrides class defaults - Remove early return in Flow.__init__ to allow proper state initialization - Add test_flow_default_override.py to verify state override behavior - Fix issue where default values weren't being overridden by persisted state Fixes the issue where persisted state values weren't properly overriding class defaults when restarting a flow with a previously saved state ID. Co-Authored-By: Joe Moura <joao@crewai.com> * test: improve state restoration verification with has_set_count flag Co-Authored-By: Joe Moura <joao@crewai.com> * test: add has_set_count field to PoemState Co-Authored-By: Joe Moura <joao@crewai.com> * refactoring test * fix: ensure persisted state overrides class defaults - Remove early return in Flow.__init__ to allow proper state initialization - Add test_flow_default_override.py to verify state override behavior - Fix issue where default values weren't being overridden by persisted state Fixes the issue where persisted state values weren't properly overriding class defaults when restarting a flow with a previously saved state ID. Co-Authored-By: Joe Moura <joao@crewai.com> * test: improve state restoration verification with has_set_count flag Co-Authored-By: Joe Moura <joao@crewai.com> * test: add has_set_count field to PoemState Co-Authored-By: Joe Moura <joao@crewai.com> * refactoring test * Fixing flow state * fixing peristed stateful flows * linter * type fix --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: Joe Moura <joao@crewai.com>
1 parent 3e4f112 commit ab2274c

9 files changed

+337
-220
lines changed

src/crewai/flow/flow.py

+48-72
Original file line numberDiff line numberDiff line change
@@ -447,14 +447,12 @@ class _FlowGeneric(cls): # type: ignore
447447
def __init__(
448448
self,
449449
persistence: Optional[FlowPersistence] = None,
450-
restore_uuid: Optional[str] = None,
451450
**kwargs: Any,
452451
) -> None:
453452
"""Initialize a new Flow instance.
454453
455454
Args:
456455
persistence: Optional persistence backend for storing flow states
457-
restore_uuid: Optional UUID to restore state from persistence
458456
**kwargs: Additional state values to initialize or override
459457
"""
460458
# Initialize basic instance attributes
@@ -464,64 +462,12 @@ def __init__(
464462
self._method_outputs: List[Any] = [] # List to store all method outputs
465463
self._persistence: Optional[FlowPersistence] = persistence
466464

467-
# Validate state model before initialization
468-
if isinstance(self.initial_state, type):
469-
if issubclass(self.initial_state, BaseModel) and not issubclass(
470-
self.initial_state, FlowState
471-
):
472-
# Check if model has id field
473-
model_fields = getattr(self.initial_state, "model_fields", None)
474-
if not model_fields or "id" not in model_fields:
475-
raise ValueError("Flow state model must have an 'id' field")
476-
477-
# Handle persistence and potential ID conflicts
478-
stored_state = None
479-
if self._persistence is not None:
480-
if (
481-
restore_uuid
482-
and kwargs
483-
and "id" in kwargs
484-
and restore_uuid != kwargs["id"]
485-
):
486-
raise ValueError(
487-
f"Conflicting IDs provided: restore_uuid='{restore_uuid}' "
488-
f"vs kwargs['id']='{kwargs['id']}'. Use only one ID for restoration."
489-
)
465+
# Initialize state with initial values
466+
self._state = self._create_initial_state()
490467

491-
# Attempt to load state, prioritizing restore_uuid
492-
if restore_uuid:
493-
self._log_flow_event(f"Loading flow state from memory for UUID: {restore_uuid}", color="bold_yellow")
494-
stored_state = self._persistence.load_state(restore_uuid)
495-
if not stored_state:
496-
raise ValueError(
497-
f"No state found for restore_uuid='{restore_uuid}'"
498-
)
499-
elif kwargs and "id" in kwargs:
500-
self._log_flow_event(f"Loading flow state from memory for ID: {kwargs['id']}", color="bold_yellow")
501-
stored_state = self._persistence.load_state(kwargs["id"])
502-
if not stored_state:
503-
# For kwargs["id"], we allow creating new state if not found
504-
self._state = self._create_initial_state()
505-
if kwargs:
506-
self._initialize_state(kwargs)
507-
return
508-
509-
# Initialize state based on persistence and kwargs
510-
if stored_state:
511-
# Create initial state and restore from persistence
512-
self._state = self._create_initial_state()
513-
self._restore_state(stored_state)
514-
# Apply any additional kwargs to override specific fields
515-
if kwargs:
516-
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "id"}
517-
if filtered_kwargs:
518-
self._initialize_state(filtered_kwargs)
519-
else:
520-
# No stored state, create new state with initial values
521-
self._state = self._create_initial_state()
522-
# Apply any additional kwargs
523-
if kwargs:
524-
self._initialize_state(kwargs)
468+
# Apply any additional kwargs
469+
if kwargs:
470+
self._initialize_state(kwargs)
525471

526472
self._telemetry.flow_creation_span(self.__class__.__name__)
527473

@@ -635,18 +581,18 @@ def method_outputs(self) -> List[Any]:
635581
@property
636582
def flow_id(self) -> str:
637583
"""Returns the unique identifier of this flow instance.
638-
584+
639585
This property provides a consistent way to access the flow's unique identifier
640586
regardless of the underlying state implementation (dict or BaseModel).
641-
587+
642588
Returns:
643589
str: The flow's unique identifier, or an empty string if not found
644-
590+
645591
Note:
646592
This property safely handles both dictionary and BaseModel state types,
647593
returning an empty string if the ID cannot be retrieved rather than raising
648594
an exception.
649-
595+
650596
Example:
651597
```python
652598
flow = MyFlow()
@@ -656,7 +602,7 @@ def flow_id(self) -> str:
656602
try:
657603
if not hasattr(self, '_state'):
658604
return ""
659-
605+
660606
if isinstance(self._state, dict):
661607
return str(self._state.get("id", ""))
662608
elif isinstance(self._state, BaseModel):
@@ -731,7 +677,6 @@ def _restore_state(self, stored_state: Dict[str, Any]) -> None:
731677
"""
732678
# When restoring from persistence, use the stored ID
733679
stored_id = stored_state.get("id")
734-
self._log_flow_event(f"Restoring flow state from memory for ID: {stored_id}", color="bold_yellow")
735680
if not stored_id:
736681
raise ValueError("Stored state must have an 'id' field")
737682

@@ -755,17 +700,48 @@ def _restore_state(self, stored_state: Dict[str, Any]) -> None:
755700
raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")
756701

757702
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
703+
"""Start the flow execution.
704+
705+
Args:
706+
inputs: Optional dictionary containing input values and potentially a state ID to restore
707+
"""
708+
# Handle state restoration if ID is provided in inputs
709+
if inputs and 'id' in inputs and self._persistence is not None:
710+
restore_uuid = inputs['id']
711+
stored_state = self._persistence.load_state(restore_uuid)
712+
713+
# Override the id in the state if it exists in inputs
714+
if 'id' in inputs:
715+
if isinstance(self._state, dict):
716+
self._state['id'] = inputs['id']
717+
elif isinstance(self._state, BaseModel):
718+
setattr(self._state, 'id', inputs['id'])
719+
720+
if stored_state:
721+
self._log_flow_event(f"Loading flow state from memory for UUID: {restore_uuid}", color="yellow")
722+
# Restore the state
723+
self._restore_state(stored_state)
724+
else:
725+
self._log_flow_event(f"No flow state found for UUID: {restore_uuid}", color="red")
726+
727+
# Apply any additional inputs after restoration
728+
filtered_inputs = {k: v for k, v in inputs.items() if k != 'id'}
729+
if filtered_inputs:
730+
self._initialize_state(filtered_inputs)
731+
732+
# Start flow execution
758733
self.event_emitter.send(
759734
self,
760735
event=FlowStartedEvent(
761736
type="flow_started",
762737
flow_name=self.__class__.__name__,
763738
),
764739
)
765-
self._log_flow_event(f"Flow started with ID: {self.flow_id}", color="yellow")
740+
self._log_flow_event(f"Flow started with ID: {self.flow_id}", color="bold_magenta")
766741

767-
if inputs is not None:
742+
if inputs is not None and 'id' not in inputs:
768743
self._initialize_state(inputs)
744+
769745
return asyncio.run(self.kickoff_async())
770746

771747
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
@@ -1010,18 +986,18 @@ async def _execute_single_listener(self, listener_name: str, result: Any) -> Non
1010986

1011987
def _log_flow_event(self, message: str, color: str = "yellow", level: str = "info") -> None:
1012988
"""Centralized logging method for flow events.
1013-
989+
1014990
This method provides a consistent interface for logging flow-related events,
1015991
combining both console output with colors and proper logging levels.
1016-
992+
1017993
Args:
1018994
message: The message to log
1019995
color: Color to use for console output (default: yellow)
1020996
Available colors: purple, red, bold_green, bold_purple,
1021-
bold_blue, yellow, bold_yellow
997+
bold_blue, yellow, yellow
1022998
level: Log level to use (default: info)
1023999
Supported levels: info, warning
1024-
1000+
10251001
Note:
10261002
This method uses the Printer utility for colored console output
10271003
and the standard logging module for log level support.
@@ -1031,7 +1007,7 @@ def _log_flow_event(self, message: str, color: str = "yellow", level: str = "inf
10311007
logger.info(message)
10321008
elif level == "warning":
10331009
logger.warning(message)
1034-
1010+
10351011
def plot(self, filename: str = "crewai_flow") -> None:
10361012
self._telemetry.flow_plotting_span(
10371013
self.__class__.__name__, list(self._methods.keys())

src/crewai/flow/persistence/decorators.py

+72-48
Original file line numberDiff line numberDiff line change
@@ -54,57 +54,44 @@ async def async_method(self):
5454

5555
class PersistenceDecorator:
5656
"""Class to handle flow state persistence with consistent logging."""
57-
57+
5858
_printer = Printer() # Class-level printer instance
59-
59+
6060
@classmethod
6161
def persist_state(cls, flow_instance: Any, method_name: str, persistence_instance: FlowPersistence) -> None:
6262
"""Persist flow state with proper error handling and logging.
63-
63+
6464
This method handles the persistence of flow state data, including proper
6565
error handling and colored console output for status updates.
66-
66+
6767
Args:
6868
flow_instance: The flow instance whose state to persist
6969
method_name: Name of the method that triggered persistence
7070
persistence_instance: The persistence backend to use
71-
71+
7272
Raises:
7373
ValueError: If flow has no state or state lacks an ID
7474
RuntimeError: If state persistence fails
7575
AttributeError: If flow instance lacks required state attributes
76-
77-
Note:
78-
Uses bold_yellow color for success messages and red for errors.
79-
All operations are logged at appropriate levels (info/error).
80-
81-
Example:
82-
```python
83-
@persist
84-
def my_flow_method(self):
85-
# Method implementation
86-
pass
87-
# State will be automatically persisted after method execution
88-
```
8976
"""
9077
try:
9178
state = getattr(flow_instance, 'state', None)
9279
if state is None:
9380
raise ValueError("Flow instance has no state")
94-
81+
9582
flow_uuid: Optional[str] = None
9683
if isinstance(state, dict):
9784
flow_uuid = state.get('id')
9885
elif isinstance(state, BaseModel):
9986
flow_uuid = getattr(state, 'id', None)
100-
87+
10188
if not flow_uuid:
10289
raise ValueError("Flow state must have an 'id' field for persistence")
103-
90+
10491
# Log state saving with consistent message
105-
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="bold_yellow")
92+
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan")
10693
logger.info(LOG_MESSAGES["save_state"].format(flow_uuid))
107-
94+
10895
try:
10996
persistence_instance.save_state(
11097
flow_uuid=flow_uuid,
@@ -154,44 +141,79 @@ class MyFlow(Flow[MyState]):
154141
def begin(self):
155142
pass
156143
"""
144+
157145
def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
158146
"""Decorator that handles both class and method decoration."""
159147
actual_persistence = persistence or SQLiteFlowPersistence()
160148

161149
if isinstance(target, type):
162150
# Class decoration
163-
class_methods = {}
151+
original_init = getattr(target, "__init__")
152+
153+
@functools.wraps(original_init)
154+
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
155+
if 'persistence' not in kwargs:
156+
kwargs['persistence'] = actual_persistence
157+
original_init(self, *args, **kwargs)
158+
159+
setattr(target, "__init__", new_init)
160+
161+
# Store original methods to preserve their decorators
162+
original_methods = {}
163+
164164
for name, method in target.__dict__.items():
165-
if callable(method) and hasattr(method, "__is_flow_method__"):
166-
# Wrap each flow method with persistence
167-
if asyncio.iscoroutinefunction(method):
168-
@functools.wraps(method)
169-
async def class_async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
170-
method_coro = method(self, *args, **kwargs)
171-
if asyncio.iscoroutine(method_coro):
172-
result = await method_coro
173-
else:
174-
result = method_coro
175-
PersistenceDecorator.persist_state(self, method.__name__, actual_persistence)
165+
if callable(method) and (
166+
hasattr(method, "__is_start_method__") or
167+
hasattr(method, "__trigger_methods__") or
168+
hasattr(method, "__condition_type__") or
169+
hasattr(method, "__is_flow_method__") or
170+
hasattr(method, "__is_router__")
171+
):
172+
original_methods[name] = method
173+
174+
# Create wrapped versions of the methods that include persistence
175+
for name, method in original_methods.items():
176+
if asyncio.iscoroutinefunction(method):
177+
# Create a closure to capture the current name and method
178+
def create_async_wrapper(method_name: str, original_method: Callable):
179+
@functools.wraps(original_method)
180+
async def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
181+
result = await original_method(self, *args, **kwargs)
182+
PersistenceDecorator.persist_state(self, method_name, actual_persistence)
176183
return result
177-
class_methods[name] = class_async_wrapper
178-
else:
179-
@functools.wraps(method)
180-
def class_sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
181-
result = method(self, *args, **kwargs)
182-
PersistenceDecorator.persist_state(self, method.__name__, actual_persistence)
184+
return method_wrapper
185+
186+
wrapped = create_async_wrapper(name, method)
187+
188+
# Preserve all original decorators and attributes
189+
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
190+
if hasattr(method, attr):
191+
setattr(wrapped, attr, getattr(method, attr))
192+
setattr(wrapped, "__is_flow_method__", True)
193+
194+
# Update the class with the wrapped method
195+
setattr(target, name, wrapped)
196+
else:
197+
# Create a closure to capture the current name and method
198+
def create_sync_wrapper(method_name: str, original_method: Callable):
199+
@functools.wraps(original_method)
200+
def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
201+
result = original_method(self, *args, **kwargs)
202+
PersistenceDecorator.persist_state(self, method_name, actual_persistence)
183203
return result
184-
class_methods[name] = class_sync_wrapper
204+
return method_wrapper
205+
206+
wrapped = create_sync_wrapper(name, method)
185207

186-
# Preserve flow-specific attributes
208+
# Preserve all original decorators and attributes
187209
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
188210
if hasattr(method, attr):
189-
setattr(class_methods[name], attr, getattr(method, attr))
190-
setattr(class_methods[name], "__is_flow_method__", True)
211+
setattr(wrapped, attr, getattr(method, attr))
212+
setattr(wrapped, "__is_flow_method__", True)
213+
214+
# Update the class with the wrapped method
215+
setattr(target, name, wrapped)
191216

192-
# Update class with wrapped methods
193-
for name, method in class_methods.items():
194-
setattr(target, name, method)
195217
return target
196218
else:
197219
# Method decoration
@@ -208,6 +230,7 @@ async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) ->
208230
result = method_coro
209231
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence)
210232
return result
233+
211234
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
212235
if hasattr(method, attr):
213236
setattr(method_async_wrapper, attr, getattr(method, attr))
@@ -219,6 +242,7 @@ def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
219242
result = method(flow_instance, *args, **kwargs)
220243
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence)
221244
return result
245+
222246
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
223247
if hasattr(method, attr):
224248
setattr(method_sync_wrapper, attr, getattr(method, attr))

0 commit comments

Comments
 (0)