@@ -64,7 +64,7 @@ def __init__(
64
64
65
65
Args:
66
66
name (str): The name of the session manager index.
67
- session_tag (str): Tag to be added to entries to link to a specific
67
+ session_tag (Optional[ str] ): Tag to be added to entries to link to a specific
68
68
session. Defaults to instance uuid.
69
69
prefix (Optional[str]): Prefix for the keys for this session data.
70
70
Defaults to None and will be replaced with the index name.
@@ -105,7 +105,7 @@ def __init__(
105
105
106
106
self ._index .create (overwrite = False )
107
107
108
- self ._default_tag_filter = Tag (self .session_field_name ) == self ._session_tag
108
+ self ._default_session_filter = Tag (self .session_field_name ) == self ._session_tag
109
109
110
110
def clear (self ) -> None :
111
111
"""Clears the chat session history."""
@@ -115,24 +115,23 @@ def delete(self) -> None:
115
115
"""Clear all conversation keys and remove the search index."""
116
116
self ._index .delete (drop = True )
117
117
118
- def drop (self , id_field : Optional [str ] = None ) -> None :
118
+ def drop (self , id : Optional [str ] = None ) -> None :
119
119
"""Remove a specific exchange from the conversation history.
120
120
121
121
Args:
122
- id_field (Optional[str]): The id_field of the entry to delete.
122
+ id (Optional[str]): The id of the session entry to delete.
123
123
If None then the last entry is deleted.
124
124
"""
125
- if id_field :
126
- sep = self ._index .key_separator
127
- key = sep .join ([self ._index .schema .index .name , id_field ])
128
- else :
129
- key = self .get_recent (top_k = 1 , raw = True )[0 ]["id" ] # type: ignore
130
- self ._index .client .delete (key ) # type: ignore
125
+ if id is None :
126
+ id = self .get_recent (top_k = 1 , raw = True )[0 ][self .id_field_name ] # type: ignore
127
+
128
+ self ._index .client .delete (self ._index .key (id )) # type: ignore
131
129
132
130
@property
133
131
def messages (self ) -> Union [List [str ], List [Dict [str , str ]]]:
134
132
"""Returns the full chat history."""
135
133
# TODO raw or as_text?
134
+ # TODO refactor method to use get_recent and support other session tags
136
135
return_fields = [
137
136
self .id_field_name ,
138
137
self .session_field_name ,
@@ -143,7 +142,7 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]:
143
142
]
144
143
145
144
query = FilterQuery (
146
- filter_expression = self ._default_tag_filter ,
145
+ filter_expression = self ._default_session_filter ,
147
146
return_fields = return_fields ,
148
147
)
149
148
@@ -159,7 +158,7 @@ def get_relevant(
159
158
as_text : bool = False ,
160
159
top_k : int = 5 ,
161
160
fall_back : bool = False ,
162
- tag_filter : Optional [FilterExpression ] = None ,
161
+ session_tag : Optional [str ] = None ,
163
162
raw : bool = False ,
164
163
) -> Union [List [str ], List [Dict [str , str ]]]:
165
164
"""Searches the chat history for information semantically related to
@@ -177,8 +176,8 @@ def get_relevant(
177
176
top_k (int): The number of previous messages to return. Default is 5.
178
177
fallback (bool): Whether to drop back to recent conversation history
179
178
if no relevant context is found.
180
- tag_filter (Optional[FilterExpression ]): The tag filter to filter results
181
- by . Defaults to None and all messages will be searched .
179
+ session_tag (Optional[str ]): Tag to be added to entries to link to a specific
180
+ session . Defaults to instance uuid .
182
181
raw (bool): Whether to return the full Redis hash entry or just the
183
182
message.
184
183
@@ -202,14 +201,20 @@ def get_relevant(
202
201
self .vector_field_name ,
203
202
]
204
203
204
+ session_filter = (
205
+ Tag (self .session_field_name ) == session_tag
206
+ if session_tag
207
+ else self ._default_session_filter
208
+ )
209
+
205
210
query = RangeQuery (
206
211
vector = self ._vectorizer .embed (prompt ),
207
212
vector_field_name = self .vector_field_name ,
208
213
return_fields = return_fields ,
209
214
distance_threshold = self ._distance_threshold ,
210
215
num_results = top_k ,
211
216
return_score = True ,
212
- filter_expression = tag_filter or self . _default_tag_filter ,
217
+ filter_expression = session_filter ,
213
218
)
214
219
hits = self ._index .query (query )
215
220
@@ -225,7 +230,7 @@ def get_recent(
225
230
top_k : int = 5 ,
226
231
as_text : bool = False ,
227
232
raw : bool = False ,
228
- tag_filter : Optional [FilterExpression ] = None ,
233
+ session_tag : Optional [str ] = None ,
229
234
) -> Union [List [str ], List [Dict [str , str ]]]:
230
235
"""Retreive the recent conversation history in sequential order.
231
236
@@ -235,8 +240,8 @@ def get_recent(
235
240
or list of alternating prompts and responses.
236
241
raw (bool): Whether to return the full Redis hash entry or just the
237
242
prompt and response
238
- tag_filter (Optional[FilterExpression ]): The tag filter to filter
239
- results by . Defaults to None and all messages will be searched .
243
+ session_tag (Optional[str ]): Tag to be added to entries to link to a specific
244
+ session . Defaults to instance uuid .
240
245
241
246
Returns:
242
247
Union[str, List[str]]: A single string transcription of the session
@@ -257,8 +262,14 @@ def get_recent(
257
262
self .timestamp_field_name ,
258
263
]
259
264
265
+ session_filter = (
266
+ Tag (self .session_field_name ) == session_tag
267
+ if session_tag
268
+ else self ._default_session_filter
269
+ )
270
+
260
271
query = FilterQuery (
261
- filter_expression = tag_filter or self . _default_tag_filter ,
272
+ filter_expression = session_filter ,
262
273
return_fields = return_fields ,
263
274
num_results = top_k ,
264
275
)
@@ -288,6 +299,8 @@ def store(
288
299
Args:
289
300
prompt (str): The user prompt to the LLM.
290
301
response (str): The corresponding LLM response.
302
+ session_tag (Optional[str]): Tag to be added to entries to link to a specific
303
+ session. Defaults to instance uuid.
291
304
"""
292
305
self .add_messages (
293
306
[
@@ -306,6 +319,8 @@ def add_messages(
306
319
307
320
Args:
308
321
messages (List[Dict[str, str]]): The list of user prompts and LLM responses.
322
+ session_tag (Optional[str]): Tag to be added to entries to link to a specific
323
+ session. Defaults to instance uuid.
309
324
"""
310
325
sep = self ._index .key_separator
311
326
session_tag = session_tag or self ._session_tag
@@ -337,5 +352,7 @@ def add_message(
337
352
338
353
Args:
339
354
message (Dict[str,str]): The user prompt or LLM response.
355
+ session_tag (Optional[str]): Tag to be added to entries to link to a specific
356
+ session. Defaults to instance uuid.
340
357
"""
341
358
self .add_messages ([message ], session_tag )
0 commit comments