Skip to content

Makes setting scope fully optional in session managers #193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 31, 2024
7 changes: 0 additions & 7 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,3 @@ def clear_db(redis):
def app_name():
return "test_app"

@pytest.fixture
def session_tag():
return "123"

@pytest.fixture
def user_tag():
return "abc"
2 changes: 1 addition & 1 deletion docs/user_guide/session_manager_07.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
],
"source": [
"from redisvl.extensions.session_manager import SemanticSessionManager\n",
"user_session = SemanticSessionManager(name='llm_chef', session_tag='123', user_tag='abc')\n",
"user_session = SemanticSessionManager(name='llm_chef')\n",
"user_session.add_message({\"role\":\"system\", \"content\":\"You are a helpful chef, assisting people in making delicious meals\"})\n",
"\n",
"client = CohereClient()"
Expand Down
5 changes: 3 additions & 2 deletions redisvl/extensions/llmcache/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
class SemanticCache(BaseLLMCache):
"""Semantic Cache for Large Language Models."""

entry_id_field_name: str = "id"
entry_id_field_name: str = "_id"
prompt_field_name: str = "prompt"
vector_field_name: str = "prompt_vector"
response_field_name: str = "response"
Expand Down Expand Up @@ -222,7 +222,8 @@ def _search_cache(
cache_hits: List[Dict[str, Any]] = self._index.query(query)
# Process cache hits
for hit in cache_hits:
self._refresh_ttl(hit[self.entry_id_field_name])
key = hit["id"]
self._refresh_ttl(key)
# Check for metadata and deserialize
if self.metadata_field_name in hit:
hit[self.metadata_field_name] = self.deserialize(
Expand Down
56 changes: 24 additions & 32 deletions redisvl/extensions/session_manager/base_session.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
from typing import Any, Dict, List, Optional, Union
from uuid import uuid4

from redis import Redis

from redisvl.query.filter import FilterExpression


class BaseSessionManager:
id_field_name: str = "id_field"
id_field_name: str = "_id"
role_field_name: str = "role"
content_field_name: str = "content"
tool_field_name: str = "tool_call_id"
timestamp_field_name: str = "timestamp"
session_field_name: str = "session_tag"

def __init__(
self,
name: str,
session_tag: str,
user_tag: str,
session_tag: Optional[str] = None,
):
"""Initialize session memory with index

Expand All @@ -26,29 +29,10 @@ def __init__(
Args:
name (str): The name of the session manager index.
session_tag (str): Tag to be added to entries to link to a specific
session.
user_tag (str): Tag to be added to entries to link to a specific user.
session. Defaults to instance uuid.
"""
self._name = name
self._user_tag = user_tag
self._session_tag = session_tag

def set_scope(
self,
session_tag: Optional[str] = None,
user_tag: Optional[str] = None,
) -> None:
"""Set the filter to apply to querries based on the desired scope.

This new scope persists until another call to set_scope is made, or if
scope specified in calls to get_recent.

Args:
session_tag (str): Id of the specific session to filter to. Default is
None.
user_tag (str): Id of the specific user to filter to. Default is None.
"""
raise NotImplementedError
self._session_tag = session_tag or uuid4().hex

def clear(self) -> None:
"""Clears the chat session history."""
Expand All @@ -75,23 +59,21 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]:
def get_recent(
self,
top_k: int = 5,
session_tag: Optional[str] = None,
user_tag: Optional[str] = None,
as_text: bool = False,
raw: bool = False,
session_tag: Optional[str] = None,
) -> Union[List[str], List[Dict[str, str]]]:
"""Retreive the recent conversation history in sequential order.

Args:
top_k (int): The number of previous exchanges to return. Default is 5.
Note that one exchange contains both a prompt and response.
session_tag (str): Tag to be added to entries to link to a specific
session.
user_tag (str): Tag to be added to entries to link to a specific user.
as_text (bool): Whether to return the conversation as a single string,
or list of alternating prompts and responses.
raw (bool): Whether to return the full Redis hash entry or just the
prompt and response
session_tag (str): Tag to be added to entries to link to a specific
session. Defaults to instance uuid.

Returns:
Union[str, List[str]]: A single string transcription of the session
Expand All @@ -113,6 +95,7 @@ def _format_context(
recent conversation history.
as_text (bool): Whether to return the conversation as a single string,
or list of alternating prompts and responses.

Returns:
Union[str, List[str]]: A single string transcription of the session
or list of strings if as_text is false.
Expand Down Expand Up @@ -141,33 +124,42 @@ def _format_context(
)
return statements

def store(self, prompt: str, response: str) -> None:
def store(
self, prompt: str, response: str, session_tag: Optional[str] = None
) -> None:
"""Insert a prompt:response pair into the session memory. A timestamp
is associated with each exchange so that they can be later sorted
in sequential ordering after retrieval.

Args:
prompt (str): The user prompt to the LLM.
response (str): The corresponding LLM response.
session_tag (Optional[str]): The tag to mark the message with. Defaults to None.
"""
raise NotImplementedError

def add_messages(self, messages: List[Dict[str, str]]) -> None:
def add_messages(
self, messages: List[Dict[str, str]], session_tag: Optional[str] = None
) -> None:
"""Insert a list of prompts and responses into the session memory.
A timestamp is associated with each so that they can be later sorted
in sequential ordering after retrieval.

Args:
messages (List[Dict[str, str]]): The list of user prompts and LLM responses.
session_tag (Optional[str]): The tag to mark the messages with. Defaults to None.
"""
raise NotImplementedError

def add_message(self, message: Dict[str, str]) -> None:
def add_message(
self, message: Dict[str, str], session_tag: Optional[str] = None
) -> None:
"""Insert a single prompt or response into the session memory.
A timestamp is associated with it so that it can be later sorted
in sequential ordering after retrieval.

Args:
message (Dict[str,str]): The user prompt or LLM response.
session_tag (Optional[str]): The tag to mark the message with. Defaults to None.
"""
raise NotImplementedError
Loading
Loading