Skip to content
Merged
7 changes: 0 additions & 7 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,3 @@ def clear_db(redis):
def app_name():
return "test_app"

@pytest.fixture
def session_tag():
return "123"

@pytest.fixture
def user_tag():
return "abc"
2 changes: 1 addition & 1 deletion docs/user_guide/session_manager_07.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
],
"source": [
"from redisvl.extensions.session_manager import SemanticSessionManager\n",
"user_session = SemanticSessionManager(name='llm_chef', session_tag='123', user_tag='abc')\n",
"user_session = SemanticSessionManager(name='llm_chef')\n",
"user_session.add_message({\"role\":\"system\", \"content\":\"You are a helpful chef, assisting people in making delicious meals\"})\n",
"\n",
"client = CohereClient()"
Expand Down
5 changes: 3 additions & 2 deletions redisvl/extensions/llmcache/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
class SemanticCache(BaseLLMCache):
"""Semantic Cache for Large Language Models."""

entry_id_field_name: str = "id"
entry_id_field_name: str = "_id"
prompt_field_name: str = "prompt"
vector_field_name: str = "prompt_vector"
response_field_name: str = "response"
Expand Down Expand Up @@ -222,7 +222,8 @@ def _search_cache(
cache_hits: List[Dict[str, Any]] = self._index.query(query)
# Process cache hits
for hit in cache_hits:
self._refresh_ttl(hit[self.entry_id_field_name])
key = hit["id"]
self._refresh_ttl(key)
# Check for metadata and deserialize
if self.metadata_field_name in hit:
hit[self.metadata_field_name] = self.deserialize(
Expand Down
56 changes: 24 additions & 32 deletions redisvl/extensions/session_manager/base_session.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
from typing import Any, Dict, List, Optional, Union
from uuid import uuid4

from redis import Redis

from redisvl.query.filter import FilterExpression


class BaseSessionManager:
id_field_name: str = "id_field"
id_field_name: str = "_id"
role_field_name: str = "role"
content_field_name: str = "content"
tool_field_name: str = "tool_call_id"
timestamp_field_name: str = "timestamp"
session_field_name: str = "session_tag"

def __init__(
self,
name: str,
session_tag: str,
user_tag: str,
session_tag: Optional[str] = None,
):
"""Initialize session memory with index

Expand All @@ -26,29 +29,10 @@ def __init__(
Args:
name (str): The name of the session manager index.
session_tag (str): Tag to be added to entries to link to a specific
session.
user_tag (str): Tag to be added to entries to link to a specific user.
session. Defaults to instance uuid.
"""
self._name = name
self._user_tag = user_tag
self._session_tag = session_tag

def set_scope(
self,
session_tag: Optional[str] = None,
user_tag: Optional[str] = None,
) -> None:
"""Set the filter to apply to querries based on the desired scope.

This new scope persists until another call to set_scope is made, or if
scope specified in calls to get_recent.

Args:
session_tag (str): Id of the specific session to filter to. Default is
None.
user_tag (str): Id of the specific user to filter to. Default is None.
"""
raise NotImplementedError
self._session_tag = session_tag or uuid4().hex

def clear(self) -> None:
"""Clears the chat session history."""
Expand All @@ -75,23 +59,21 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]:
def get_recent(
self,
top_k: int = 5,
session_tag: Optional[str] = None,
user_tag: Optional[str] = None,
as_text: bool = False,
raw: bool = False,
session_tag: Optional[str] = None,
) -> Union[List[str], List[Dict[str, str]]]:
"""Retreive the recent conversation history in sequential order.

Args:
top_k (int): The number of previous exchanges to return. Default is 5.
Note that one exchange contains both a prompt and response.
session_tag (str): Tag to be added to entries to link to a specific
session.
user_tag (str): Tag to be added to entries to link to a specific user.
as_text (bool): Whether to return the conversation as a single string,
or list of alternating prompts and responses.
raw (bool): Whether to return the full Redis hash entry or just the
prompt and response
session_tag (str): Tag to be added to entries to link to a specific
session. Defaults to instance uuid.

Returns:
Union[str, List[str]]: A single string transcription of the session
Expand All @@ -113,6 +95,7 @@ def _format_context(
recent conversation history.
as_text (bool): Whether to return the conversation as a single string,
or list of alternating prompts and responses.

Returns:
Union[str, List[str]]: A single string transcription of the session
or list of strings if as_text is false.
Expand Down Expand Up @@ -141,33 +124,42 @@ def _format_context(
)
return statements

def store(self, prompt: str, response: str) -> None:
def store(
self, prompt: str, response: str, session_tag: Optional[str] = None
) -> None:
"""Insert a prompt:response pair into the session memory. A timestamp
is associated with each exchange so that they can be later sorted
in sequential ordering after retrieval.

Args:
prompt (str): The user prompt to the LLM.
response (str): The corresponding LLM response.
session_tag (Optional[str]): The tag to mark the message with. Defaults to None.
"""
raise NotImplementedError

def add_messages(self, messages: List[Dict[str, str]]) -> None:
def add_messages(
self, messages: List[Dict[str, str]], session_tag: Optional[str] = None
) -> None:
"""Insert a list of prompts and responses into the session memory.
A timestamp is associated with each so that they can be later sorted
in sequential ordering after retrieval.

Args:
messages (List[Dict[str, str]]): The list of user prompts and LLM responses.
session_tag (Optional[str]): The tag to mark the messages with. Defaults to None.
"""
raise NotImplementedError

def add_message(self, message: Dict[str, str]) -> None:
def add_message(
self, message: Dict[str, str], session_tag: Optional[str] = None
) -> None:
"""Insert a single prompt or response into the session memory.
A timestamp is associated with it so that it can be later sorted
in sequential ordering after retrieval.

Args:
message (Dict[str,str]): The user prompt or LLM response.
session_tag (Optional[str]): The tag to mark the message with. Defaults to None.
"""
raise NotImplementedError
Loading