@@ -64,7 +64,7 @@ def __init__(
6464
6565 Args:
6666 name (str): The name of the session manager index.
67- session_tag (str): Tag to be added to entries to link to a specific
67+ session_tag (Optional[ str] ): Tag to be added to entries to link to a specific
6868 session. Defaults to instance uuid.
6969 prefix (Optional[str]): Prefix for the keys for this session data.
7070 Defaults to None and will be replaced with the index name.
@@ -105,7 +105,7 @@ def __init__(
105105
106106 self ._index .create (overwrite = False )
107107
108- self ._default_tag_filter = Tag (self .session_field_name ) == self ._session_tag
108+ self ._default_session_filter = Tag (self .session_field_name ) == self ._session_tag
109109
110110 def clear (self ) -> None :
111111 """Clears the chat session history."""
@@ -115,24 +115,23 @@ def delete(self) -> None:
115115 """Clear all conversation keys and remove the search index."""
116116 self ._index .delete (drop = True )
117117
118- def drop (self , id_field : Optional [str ] = None ) -> None :
118+ def drop (self , id : Optional [str ] = None ) -> None :
119119 """Remove a specific exchange from the conversation history.
120120
121121 Args:
122- id_field (Optional[str]): The id_field of the entry to delete.
122+ id (Optional[str]): The id of the session entry to delete.
123123 If None then the last entry is deleted.
124124 """
125- if id_field :
126- sep = self ._index .key_separator
127- key = sep .join ([self ._index .schema .index .name , id_field ])
128- else :
129- key = self .get_recent (top_k = 1 , raw = True )[0 ]["id" ] # type: ignore
130- self ._index .client .delete (key ) # type: ignore
125+ if id is None :
126+ id = self .get_recent (top_k = 1 , raw = True )[0 ][self .id_field_name ] # type: ignore
127+
128+ self ._index .client .delete (self ._index .key (id )) # type: ignore
131129
132130 @property
133131 def messages (self ) -> Union [List [str ], List [Dict [str , str ]]]:
134132 """Returns the full chat history."""
135133 # TODO raw or as_text?
134+ # TODO refactor method to use get_recent and support other session tags
136135 return_fields = [
137136 self .id_field_name ,
138137 self .session_field_name ,
@@ -143,7 +142,7 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]:
143142 ]
144143
145144 query = FilterQuery (
146- filter_expression = self ._default_tag_filter ,
145+ filter_expression = self ._default_session_filter ,
147146 return_fields = return_fields ,
148147 )
149148
@@ -159,7 +158,7 @@ def get_relevant(
159158 as_text : bool = False ,
160159 top_k : int = 5 ,
161160 fall_back : bool = False ,
162- tag_filter : Optional [FilterExpression ] = None ,
161+ session_tag : Optional [str ] = None ,
163162 raw : bool = False ,
164163 ) -> Union [List [str ], List [Dict [str , str ]]]:
165164 """Searches the chat history for information semantically related to
@@ -177,8 +176,8 @@ def get_relevant(
177176 top_k (int): The number of previous messages to return. Default is 5.
178177 fallback (bool): Whether to drop back to recent conversation history
179178 if no relevant context is found.
180- tag_filter (Optional[FilterExpression ]): The tag filter to filter results
181- by . Defaults to None and all messages will be searched .
179+ session_tag (Optional[str ]): Tag to be added to entries to link to a specific
180+ session . Defaults to instance uuid .
182181 raw (bool): Whether to return the full Redis hash entry or just the
183182 message.
184183
@@ -202,14 +201,20 @@ def get_relevant(
202201 self .vector_field_name ,
203202 ]
204203
204+ session_filter = (
205+ Tag (self .session_field_name ) == session_tag
206+ if session_tag
207+ else self ._default_session_filter
208+ )
209+
205210 query = RangeQuery (
206211 vector = self ._vectorizer .embed (prompt ),
207212 vector_field_name = self .vector_field_name ,
208213 return_fields = return_fields ,
209214 distance_threshold = self ._distance_threshold ,
210215 num_results = top_k ,
211216 return_score = True ,
212- filter_expression = tag_filter or self . _default_tag_filter ,
217+ filter_expression = session_filter ,
213218 )
214219 hits = self ._index .query (query )
215220
@@ -225,7 +230,7 @@ def get_recent(
225230 top_k : int = 5 ,
226231 as_text : bool = False ,
227232 raw : bool = False ,
228- tag_filter : Optional [FilterExpression ] = None ,
233+ session_tag : Optional [str ] = None ,
229234 ) -> Union [List [str ], List [Dict [str , str ]]]:
230235 """Retreive the recent conversation history in sequential order.
231236
@@ -235,8 +240,8 @@ def get_recent(
235240 or list of alternating prompts and responses.
236241 raw (bool): Whether to return the full Redis hash entry or just the
237242 prompt and response
238- tag_filter (Optional[FilterExpression ]): The tag filter to filter
239- results by . Defaults to None and all messages will be searched .
243+ session_tag (Optional[str ]): Tag to be added to entries to link to a specific
244+ session . Defaults to instance uuid .
240245
241246 Returns:
242247 Union[str, List[str]]: A single string transcription of the session
@@ -257,8 +262,14 @@ def get_recent(
257262 self .timestamp_field_name ,
258263 ]
259264
265+ session_filter = (
266+ Tag (self .session_field_name ) == session_tag
267+ if session_tag
268+ else self ._default_session_filter
269+ )
270+
260271 query = FilterQuery (
261- filter_expression = tag_filter or self . _default_tag_filter ,
272+ filter_expression = session_filter ,
262273 return_fields = return_fields ,
263274 num_results = top_k ,
264275 )
@@ -288,6 +299,8 @@ def store(
288299 Args:
289300 prompt (str): The user prompt to the LLM.
290301 response (str): The corresponding LLM response.
302+ session_tag (Optional[str]): Tag to be added to entries to link to a specific
303+ session. Defaults to instance uuid.
291304 """
292305 self .add_messages (
293306 [
@@ -306,6 +319,8 @@ def add_messages(
306319
307320 Args:
308321 messages (List[Dict[str, str]]): The list of user prompts and LLM responses.
322+ session_tag (Optional[str]): Tag to be added to entries to link to a specific
323+ session. Defaults to instance uuid.
309324 """
310325 sep = self ._index .key_separator
311326 session_tag = session_tag or self ._session_tag
@@ -337,5 +352,7 @@ def add_message(
337352
338353 Args:
339354 message (Dict[str,str]): The user prompt or LLM response.
355+ session_tag (Optional[str]): Tag to be added to entries to link to a specific
356+ session. Defaults to instance uuid.
340357 """
341358 self .add_messages ([message ], session_tag )
0 commit comments