Skip to content

Commit 0044981

Browse files
use session_tag nomenclature
1 parent a5ad671 commit 0044981

File tree

5 files changed

+82
-77
lines changed

5 files changed

+82
-77
lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
class SemanticCache(BaseLLMCache):
1414
"""Semantic Cache for Large Language Models."""
1515

16-
entry_id_field_name: str = "id"
16+
entry_id_field_name: str = "_id"
1717
prompt_field_name: str = "prompt"
1818
vector_field_name: str = "prompt_vector"
1919
response_field_name: str = "response"

redisvl/extensions/session_manager/base_session.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
class BaseSessionManager:
10-
id_field_name: str = "id_field"
10+
id_field_name: str = "_id"
1111
role_field_name: str = "role"
1212
content_field_name: str = "content"
1313
tool_field_name: str = "tool_call_id"
@@ -61,7 +61,7 @@ def get_recent(
6161
top_k: int = 5,
6262
as_text: bool = False,
6363
raw: bool = False,
64-
tag_filter: Optional[FilterExpression] = None,
64+
session_tag: Optional[str] = None,
6565
) -> Union[List[str], List[Dict[str, str]]]:
6666
"""Retreive the recent conversation history in sequential order.
6767
@@ -72,8 +72,8 @@ def get_recent(
7272
or list of alternating prompts and responses.
7373
raw (bool): Whether to return the full Redis hash entry or just the
7474
prompt and response
75-
tag_filter (Optional[FilterExpression]) : The tag filter to filter
76-
results by. Default is None and all sessions are searched.
75+
session_tag (str): Tag to be added to entries to link to a specific
76+
session. Defaults to instance uuid.
7777
7878
Returns:
7979
Union[str, List[str]]: A single string transcription of the session

redisvl/extensions/session_manager/semantic_session.py

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(
6464
6565
Args:
6666
name (str): The name of the session manager index.
67-
session_tag (str): Tag to be added to entries to link to a specific
67+
session_tag (Optional[str]): Tag to be added to entries to link to a specific
6868
session. Defaults to instance uuid.
6969
prefix (Optional[str]): Prefix for the keys for this session data.
7070
Defaults to None and will be replaced with the index name.
@@ -105,7 +105,7 @@ def __init__(
105105

106106
self._index.create(overwrite=False)
107107

108-
self._default_tag_filter = Tag(self.session_field_name) == self._session_tag
108+
self._default_session_filter = Tag(self.session_field_name) == self._session_tag
109109

110110
def clear(self) -> None:
111111
"""Clears the chat session history."""
@@ -115,24 +115,23 @@ def delete(self) -> None:
115115
"""Clear all conversation keys and remove the search index."""
116116
self._index.delete(drop=True)
117117

118-
def drop(self, id_field: Optional[str] = None) -> None:
118+
def drop(self, id: Optional[str] = None) -> None:
119119
"""Remove a specific exchange from the conversation history.
120120
121121
Args:
122-
id_field (Optional[str]): The id_field of the entry to delete.
122+
id (Optional[str]): The id of the session entry to delete.
123123
If None then the last entry is deleted.
124124
"""
125-
if id_field:
126-
sep = self._index.key_separator
127-
key = sep.join([self._index.schema.index.name, id_field])
128-
else:
129-
key = self.get_recent(top_k=1, raw=True)[0]["id"] # type: ignore
130-
self._index.client.delete(key) # type: ignore
125+
if id is None:
126+
id = self.get_recent(top_k=1, raw=True)[0][self.id_field_name] # type: ignore
127+
128+
self._index.client.delete(self._index.key(id)) # type: ignore
131129

132130
@property
133131
def messages(self) -> Union[List[str], List[Dict[str, str]]]:
134132
"""Returns the full chat history."""
135133
# TODO raw or as_text?
134+
# TODO refactor method to use get_recent and support other session tags
136135
return_fields = [
137136
self.id_field_name,
138137
self.session_field_name,
@@ -143,7 +142,7 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]:
143142
]
144143

145144
query = FilterQuery(
146-
filter_expression=self._default_tag_filter,
145+
filter_expression=self._default_session_filter,
147146
return_fields=return_fields,
148147
)
149148

@@ -159,7 +158,7 @@ def get_relevant(
159158
as_text: bool = False,
160159
top_k: int = 5,
161160
fall_back: bool = False,
162-
tag_filter: Optional[FilterExpression] = None,
161+
session_tag: Optional[str] = None,
163162
raw: bool = False,
164163
) -> Union[List[str], List[Dict[str, str]]]:
165164
"""Searches the chat history for information semantically related to
@@ -177,8 +176,8 @@ def get_relevant(
177176
top_k (int): The number of previous messages to return. Default is 5.
178177
fallback (bool): Whether to drop back to recent conversation history
179178
if no relevant context is found.
180-
tag_filter (Optional[FilterExpression]): The tag filter to filter results
181-
by. Defaults to None and all messages will be searched.
179+
session_tag (Optional[str]): Tag to be added to entries to link to a specific
180+
session. Defaults to instance uuid.
182181
raw (bool): Whether to return the full Redis hash entry or just the
183182
message.
184183
@@ -202,14 +201,20 @@ def get_relevant(
202201
self.vector_field_name,
203202
]
204203

204+
session_filter = (
205+
Tag(self.session_field_name) == session_tag
206+
if session_tag
207+
else self._default_session_filter
208+
)
209+
205210
query = RangeQuery(
206211
vector=self._vectorizer.embed(prompt),
207212
vector_field_name=self.vector_field_name,
208213
return_fields=return_fields,
209214
distance_threshold=self._distance_threshold,
210215
num_results=top_k,
211216
return_score=True,
212-
filter_expression=tag_filter or self._default_tag_filter,
217+
filter_expression=session_filter,
213218
)
214219
hits = self._index.query(query)
215220

@@ -225,7 +230,7 @@ def get_recent(
225230
top_k: int = 5,
226231
as_text: bool = False,
227232
raw: bool = False,
228-
tag_filter: Optional[FilterExpression] = None,
233+
session_tag: Optional[str] = None,
229234
) -> Union[List[str], List[Dict[str, str]]]:
230235
"""Retreive the recent conversation history in sequential order.
231236
@@ -235,8 +240,8 @@ def get_recent(
235240
or list of alternating prompts and responses.
236241
raw (bool): Whether to return the full Redis hash entry or just the
237242
prompt and response
238-
tag_filter (Optional[FilterExpression]): The tag filter to filter
239-
results by. Defaults to None and all messages will be searched.
243+
session_tag (Optional[str]): Tag to be added to entries to link to a specific
244+
session. Defaults to instance uuid.
240245
241246
Returns:
242247
Union[str, List[str]]: A single string transcription of the session
@@ -257,8 +262,14 @@ def get_recent(
257262
self.timestamp_field_name,
258263
]
259264

265+
session_filter = (
266+
Tag(self.session_field_name) == session_tag
267+
if session_tag
268+
else self._default_session_filter
269+
)
270+
260271
query = FilterQuery(
261-
filter_expression=tag_filter or self._default_tag_filter,
272+
filter_expression=session_filter,
262273
return_fields=return_fields,
263274
num_results=top_k,
264275
)
@@ -288,6 +299,8 @@ def store(
288299
Args:
289300
prompt (str): The user prompt to the LLM.
290301
response (str): The corresponding LLM response.
302+
session_tag (Optional[str]): Tag to be added to entries to link to a specific
303+
session. Defaults to instance uuid.
291304
"""
292305
self.add_messages(
293306
[
@@ -306,6 +319,8 @@ def add_messages(
306319
307320
Args:
308321
messages (List[Dict[str, str]]): The list of user prompts and LLM responses.
322+
session_tag (Optional[str]): Tag to be added to entries to link to a specific
323+
session. Defaults to instance uuid.
309324
"""
310325
sep = self._index.key_separator
311326
session_tag = session_tag or self._session_tag
@@ -337,5 +352,7 @@ def add_message(
337352
338353
Args:
339354
message (Dict[str,str]): The user prompt or LLM response.
355+
session_tag (Optional[str]): Tag to be added to entries to link to a specific
356+
session. Defaults to instance uuid.
340357
"""
341358
self.add_messages([message], session_tag)

redisvl/extensions/session_manager/standard_session.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from redisvl.extensions.session_manager import BaseSessionManager
77
from redisvl.index import SearchIndex
88
from redisvl.query import FilterQuery
9-
from redisvl.query.filter import FilterExpression, Tag
9+
from redisvl.query.filter import Tag
1010
from redisvl.schema.schema import IndexSchema
1111

1212

@@ -29,7 +29,6 @@ def from_params(cls, name: str, prefix: str):
2929

3030
class StandardSessionManager(BaseSessionManager):
3131
session_field_name: str = "session_tag"
32-
user_field_name: str = "user_tag"
3332

3433
def __init__(
3534
self,
@@ -51,7 +50,7 @@ def __init__(
5150
Args:
5251
name (str): The name of the session manager index.
5352
session_tag (Optional[str]): Tag to be added to entries to link to a specific
54-
session.
53+
session. Defaults to instance uuid.
5554
prefix (Optional[str]): Prefix for the keys for this session data.
5655
Defaults to None and will be replaced with the index name.
5756
redis_client (Optional[Redis]): A Redis client instance. Defaults to
@@ -69,28 +68,17 @@ def __init__(
6968
prefix = prefix or name
7069

7170
schema = StandardSessionIndexSchema.from_params(name, prefix)
72-
self._index = SearchIndex(schema=schema)
73-
if redis_client:
74-
self._index.set_client(redis_client)
75-
else:
76-
self._index.connect(redis_url=redis_url)
7771

78-
self._index.create(overwrite=False)
79-
80-
prefix = prefix or name
81-
82-
schema = StandardSessionIndexSchema.from_params(name, prefix)
8372
self._index = SearchIndex(schema=schema)
8473

85-
# handle redis connection
8674
if redis_client:
8775
self._index.set_client(redis_client)
88-
elif redis_url:
89-
self._index.connect(redis_url=redis_url, **connection_kwargs)
76+
else:
77+
self._index.connect(redis_url=redis_url)
9078

9179
self._index.create(overwrite=False)
9280

93-
self._default_tag_filter = Tag(self.session_field_name) == self._session_tag
81+
self._default_session_filter = Tag(self.session_field_name) == self._session_tag
9482

9583
def clear(self) -> None:
9684
"""Clears the chat session history."""
@@ -100,24 +88,23 @@ def delete(self) -> None:
10088
"""Clear all conversation keys and remove the search index."""
10189
self._index.delete(drop=True)
10290

103-
def drop(self, id_field: Optional[str] = None) -> None:
91+
def drop(self, id: Optional[str] = None) -> None:
10492
"""Remove a specific exchange from the conversation history.
10593
10694
Args:
107-
id_field (Optional[str]): The id_field of the entry to delete.
95+
id (Optional[str]): The id of the session entry to delete.
10896
If None then the last entry is deleted.
10997
"""
110-
if id_field:
111-
sep = self._index.key_separator
112-
key = sep.join([self._index.schema.index.name, id_field])
113-
else:
114-
key = self.get_recent(top_k=1, raw=True)[0]["id"] # type: ignore
115-
self._index.client.delete(key) # type: ignore
98+
if id is None:
99+
id = self.get_recent(top_k=1, raw=True)[0][self.id_field_name] # type: ignore
100+
101+
self._index.client.delete(self._index.key(id)) # type: ignore
116102

117103
@property
118104
def messages(self) -> Union[List[str], List[Dict[str, str]]]:
119105
"""Returns the full chat history."""
120106
# TODO raw or as_text?
107+
# TODO refactor this method to use get_recent and support other session tags?
121108
return_fields = [
122109
self.id_field_name,
123110
self.session_field_name,
@@ -128,7 +115,7 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]:
128115
]
129116

130117
query = FilterQuery(
131-
filter_expression=self._default_tag_filter,
118+
filter_expression=self._default_session_filter,
132119
return_fields=return_fields,
133120
)
134121

@@ -143,7 +130,7 @@ def get_recent(
143130
top_k: int = 5,
144131
as_text: bool = False,
145132
raw: bool = False,
146-
tag_filter: Optional[FilterExpression] = None,
133+
session_tag: Optional[str] = None,
147134
) -> Union[List[str], List[Dict[str, str]]]:
148135
"""Retreive the recent conversation history in sequential order.
149136
@@ -153,8 +140,8 @@ def get_recent(
153140
or list of alternating prompts and responses.
154141
raw (bool): Whether to return the full Redis hash entry or just the
155142
prompt and response
156-
tag_filter (Optional[FilterExpression]) : The tag filter to filter
157-
results by. Default is None and all sessions are searched.
143+
session_tag (Optional[str]): Tag to be added to entries to link to a specific
144+
session. Defaults to instance uuid.
158145
159146
Returns:
160147
Union[str, List[str]]: A single string transcription of the session
@@ -175,8 +162,14 @@ def get_recent(
175162
self.timestamp_field_name,
176163
]
177164

165+
session_filter = (
166+
Tag(self.session_field_name) == session_tag
167+
if session_tag
168+
else self._default_session_filter
169+
)
170+
178171
query = FilterQuery(
179-
filter_expression=tag_filter or self._default_tag_filter,
172+
filter_expression=session_filter,
180173
return_fields=return_fields,
181174
num_results=top_k,
182175
)
@@ -199,7 +192,8 @@ def store(
199192
Args:
200193
prompt (str): The user prompt to the LLM.
201194
response (str): The corresponding LLM response.
202-
session_tag (Optional[str]): The tag to mark the message with. Defaults to None.
195+
session_tag (Optional[str]): Tag to be added to entries to link to a specific
196+
session. Defaults to instance uuid.
203197
"""
204198
self.add_messages(
205199
[
@@ -218,7 +212,8 @@ def add_messages(
218212
219213
Args:
220214
messages (List[Dict[str, str]]): The list of user prompts and LLM responses.
221-
session_tag (Optional[str]): The tag to mark the messages with. Defaults to None.
215+
session_tag (Optional[str]): Tag to be added to entries to link to a specific
216+
session. Defaults to instance uuid.
222217
"""
223218
sep = self._index.key_separator
224219
session_tag = session_tag or self._session_tag
@@ -249,6 +244,7 @@ def add_message(
249244
250245
Args:
251246
message (Dict[str,str]): The user prompt or LLM response.
252-
session_tag (Optional[str]): The tag to mark the message with. Defaults to None.
247+
session_tag (Optional[str]): Tag to be added to entries to link to a specific
248+
session. Defaults to instance uuid.
253249
"""
254250
self.add_messages([message], session_tag)

0 commit comments

Comments
 (0)