@@ -54,57 +54,44 @@ async def async_method(self):
54
54
55
55
class PersistenceDecorator :
56
56
"""Class to handle flow state persistence with consistent logging."""
57
-
57
+
58
58
_printer = Printer () # Class-level printer instance
59
-
59
+
60
60
@classmethod
61
61
def persist_state (cls , flow_instance : Any , method_name : str , persistence_instance : FlowPersistence ) -> None :
62
62
"""Persist flow state with proper error handling and logging.
63
-
63
+
64
64
This method handles the persistence of flow state data, including proper
65
65
error handling and colored console output for status updates.
66
-
66
+
67
67
Args:
68
68
flow_instance: The flow instance whose state to persist
69
69
method_name: Name of the method that triggered persistence
70
70
persistence_instance: The persistence backend to use
71
-
71
+
72
72
Raises:
73
73
ValueError: If flow has no state or state lacks an ID
74
74
RuntimeError: If state persistence fails
75
75
AttributeError: If flow instance lacks required state attributes
76
-
77
- Note:
78
- Uses bold_yellow color for success messages and red for errors.
79
- All operations are logged at appropriate levels (info/error).
80
-
81
- Example:
82
- ```python
83
- @persist
84
- def my_flow_method(self):
85
- # Method implementation
86
- pass
87
- # State will be automatically persisted after method execution
88
- ```
89
76
"""
90
77
try :
91
78
state = getattr (flow_instance , 'state' , None )
92
79
if state is None :
93
80
raise ValueError ("Flow instance has no state" )
94
-
81
+
95
82
flow_uuid : Optional [str ] = None
96
83
if isinstance (state , dict ):
97
84
flow_uuid = state .get ('id' )
98
85
elif isinstance (state , BaseModel ):
99
86
flow_uuid = getattr (state , 'id' , None )
100
-
87
+
101
88
if not flow_uuid :
102
89
raise ValueError ("Flow state must have an 'id' field for persistence" )
103
-
90
+
104
91
# Log state saving with consistent message
105
- cls ._printer .print (LOG_MESSAGES ["save_state" ].format (flow_uuid ), color = "bold_yellow " )
92
+ cls ._printer .print (LOG_MESSAGES ["save_state" ].format (flow_uuid ), color = "cyan " )
106
93
logger .info (LOG_MESSAGES ["save_state" ].format (flow_uuid ))
107
-
94
+
108
95
try :
109
96
persistence_instance .save_state (
110
97
flow_uuid = flow_uuid ,
@@ -154,44 +141,79 @@ class MyFlow(Flow[MyState]):
154
141
def begin(self):
155
142
pass
156
143
"""
144
+
157
145
def decorator (target : Union [Type , Callable [..., T ]]) -> Union [Type , Callable [..., T ]]:
158
146
"""Decorator that handles both class and method decoration."""
159
147
actual_persistence = persistence or SQLiteFlowPersistence ()
160
148
161
149
if isinstance (target , type ):
162
150
# Class decoration
163
- class_methods = {}
151
+ original_init = getattr (target , "__init__" )
152
+
153
+ @functools .wraps (original_init )
154
+ def new_init (self : Any , * args : Any , ** kwargs : Any ) -> None :
155
+ if 'persistence' not in kwargs :
156
+ kwargs ['persistence' ] = actual_persistence
157
+ original_init (self , * args , ** kwargs )
158
+
159
+ setattr (target , "__init__" , new_init )
160
+
161
+ # Store original methods to preserve their decorators
162
+ original_methods = {}
163
+
164
164
for name , method in target .__dict__ .items ():
165
- if callable (method ) and hasattr (method , "__is_flow_method__" ):
166
- # Wrap each flow method with persistence
167
- if asyncio .iscoroutinefunction (method ):
168
- @functools .wraps (method )
169
- async def class_async_wrapper (self : Any , * args : Any , ** kwargs : Any ) -> Any :
170
- method_coro = method (self , * args , ** kwargs )
171
- if asyncio .iscoroutine (method_coro ):
172
- result = await method_coro
173
- else :
174
- result = method_coro
175
- PersistenceDecorator .persist_state (self , method .__name__ , actual_persistence )
165
+ if callable (method ) and (
166
+ hasattr (method , "__is_start_method__" ) or
167
+ hasattr (method , "__trigger_methods__" ) or
168
+ hasattr (method , "__condition_type__" ) or
169
+ hasattr (method , "__is_flow_method__" ) or
170
+ hasattr (method , "__is_router__" )
171
+ ):
172
+ original_methods [name ] = method
173
+
174
+ # Create wrapped versions of the methods that include persistence
175
+ for name , method in original_methods .items ():
176
+ if asyncio .iscoroutinefunction (method ):
177
+ # Create a closure to capture the current name and method
178
+ def create_async_wrapper (method_name : str , original_method : Callable ):
179
+ @functools .wraps (original_method )
180
+ async def method_wrapper (self : Any , * args : Any , ** kwargs : Any ) -> Any :
181
+ result = await original_method (self , * args , ** kwargs )
182
+ PersistenceDecorator .persist_state (self , method_name , actual_persistence )
176
183
return result
177
- class_methods [name ] = class_async_wrapper
178
- else :
179
- @functools .wraps (method )
180
- def class_sync_wrapper (self : Any , * args : Any , ** kwargs : Any ) -> Any :
181
- result = method (self , * args , ** kwargs )
182
- PersistenceDecorator .persist_state (self , method .__name__ , actual_persistence )
184
+ return method_wrapper
185
+
186
+ wrapped = create_async_wrapper (name , method )
187
+
188
+ # Preserve all original decorators and attributes
189
+ for attr in ["__is_start_method__" , "__trigger_methods__" , "__condition_type__" , "__is_router__" ]:
190
+ if hasattr (method , attr ):
191
+ setattr (wrapped , attr , getattr (method , attr ))
192
+ setattr (wrapped , "__is_flow_method__" , True )
193
+
194
+ # Update the class with the wrapped method
195
+ setattr (target , name , wrapped )
196
+ else :
197
+ # Create a closure to capture the current name and method
198
+ def create_sync_wrapper (method_name : str , original_method : Callable ):
199
+ @functools .wraps (original_method )
200
+ def method_wrapper (self : Any , * args : Any , ** kwargs : Any ) -> Any :
201
+ result = original_method (self , * args , ** kwargs )
202
+ PersistenceDecorator .persist_state (self , method_name , actual_persistence )
183
203
return result
184
- class_methods [name ] = class_sync_wrapper
204
+ return method_wrapper
205
+
206
+ wrapped = create_sync_wrapper (name , method )
185
207
186
- # Preserve flow-specific attributes
208
+ # Preserve all original decorators and attributes
187
209
for attr in ["__is_start_method__" , "__trigger_methods__" , "__condition_type__" , "__is_router__" ]:
188
210
if hasattr (method , attr ):
189
- setattr (class_methods [name ], attr , getattr (method , attr ))
190
- setattr (class_methods [name ], "__is_flow_method__" , True )
211
+ setattr (wrapped , attr , getattr (method , attr ))
212
+ setattr (wrapped , "__is_flow_method__" , True )
213
+
214
+ # Update the class with the wrapped method
215
+ setattr (target , name , wrapped )
191
216
192
- # Update class with wrapped methods
193
- for name , method in class_methods .items ():
194
- setattr (target , name , method )
195
217
return target
196
218
else :
197
219
# Method decoration
@@ -208,6 +230,7 @@ async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) ->
208
230
result = method_coro
209
231
PersistenceDecorator .persist_state (flow_instance , method .__name__ , actual_persistence )
210
232
return result
233
+
211
234
for attr in ["__is_start_method__" , "__trigger_methods__" , "__condition_type__" , "__is_router__" ]:
212
235
if hasattr (method , attr ):
213
236
setattr (method_async_wrapper , attr , getattr (method , attr ))
@@ -219,6 +242,7 @@ def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
219
242
result = method (flow_instance , * args , ** kwargs )
220
243
PersistenceDecorator .persist_state (flow_instance , method .__name__ , actual_persistence )
221
244
return result
245
+
222
246
for attr in ["__is_start_method__" , "__trigger_methods__" , "__condition_type__" , "__is_router__" ]:
223
247
if hasattr (method , attr ):
224
248
setattr (method_sync_wrapper , attr , getattr (method , attr ))
0 commit comments