Skip to content

standard session uses hash implementation #191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 33 additions & 25 deletions redisvl/extensions/session_manager/semantic_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,34 @@
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer


class SemanticSessionIndexSchema(IndexSchema):

@classmethod
def from_params(cls, name: str, prefix: str, vectorizer_dims: int):

return cls(
index={"name": name, "prefix": prefix}, # type: ignore
fields=[ # type: ignore
{"name": "role", "type": "text"},
{"name": "content", "type": "text"},
{"name": "tool_call_id", "type": "text"},
{"name": "timestamp", "type": "numeric"},
{"name": "session_tag", "type": "tag"},
{"name": "user_tag", "type": "tag"},
{
"name": "vector_field",
"type": "vector",
"attrs": {
"dims": vectorizer_dims,
"datatype": "float32",
"distance_metric": "cosine",
"algorithm": "flat",
},
},
],
)


class SemanticSessionManager(BaseSessionManager):
session_field_name: str = "session_tag"
user_field_name: str = "user_tag"
Expand Down Expand Up @@ -68,27 +96,8 @@ def __init__(

self.set_distance_threshold(distance_threshold)

schema = IndexSchema.from_dict({"index": {"name": name, "prefix": prefix}})

schema.add_fields(
[
{"name": "role", "type": "text"},
{"name": "content", "type": "text"},
{"name": "tool_call_id", "type": "text"},
{"name": "timestamp", "type": "numeric"},
{"name": "session_tag", "type": "tag"},
{"name": "user_tag", "type": "tag"},
{
"name": "vector_field",
"type": "vector",
"attrs": {
"dims": self._vectorizer.dims,
"datatype": "float32",
"distance_metric": "cosine",
"algorithm": "flat",
},
},
]
schema = SemanticSessionIndexSchema.from_params(
name, prefix, self._vectorizer.dims
)

self._index = SearchIndex(schema=schema)
Expand Down Expand Up @@ -260,19 +269,18 @@ def get_recent(
"""Retreive the recent conversation history in sequential order.

Args:
as_text (bool): Whether to return the conversation as a single string,
or list of alternating prompts and responses.
top_k (int): The number of previous exchanges to return. Default is 5.
Note that one exchange contains both a prompt and a respoonse.
session_tag (str): Tag to be added to entries to link to a specific
session.
user_tag (str): Tag to be added to entries to link to a specific user.
as_text (bool): Whether to return the conversation as a single string,
or list of alternating prompts and responses.
raw (bool): Whether to return the full Redis hash entry or just the
prompt and response

Returns:
Union[str, List[str]]: A single string transcription of the session
or list of strings if as_text is false.
or list of strings if as_text is false.

Raises:
ValueError: if top_k is not an integer greater than or equal to 0.
Expand Down
162 changes: 115 additions & 47 deletions redisvl/extensions/session_manager/standard_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,40 @@
from redis import Redis

from redisvl.extensions.session_manager import BaseSessionManager
from redisvl.redis.connection import RedisConnectionFactory
from redisvl.index import SearchIndex
from redisvl.query import FilterQuery
from redisvl.query.filter import Tag
from redisvl.schema.schema import IndexSchema


class StandardSessionIndexSchema(IndexSchema):

@classmethod
def from_params(cls, name: str, prefix: str):

return cls(
index={"name": name, "prefix": prefix}, # type: ignore
fields=[ # type: ignore
{"name": "role", "type": "text"},
{"name": "content", "type": "text"},
{"name": "tool_call_id", "type": "text"},
{"name": "timestamp", "type": "numeric"},
{"name": "session_tag", "type": "tag"},
{"name": "user_tag", "type": "tag"},
],
)


class StandardSessionManager(BaseSessionManager):
session_field_name: str = "session_tag"
user_field_name: str = "user_tag"

def __init__(
self,
name: str,
session_tag: str,
user_tag: str,
prefix: Optional[str] = None,
redis_client: Optional[Redis] = None,
redis_url: str = "redis://localhost:6379",
connection_kwargs: Dict[str, Any] = {},
Expand All @@ -29,9 +53,11 @@ def __init__(

Args:
name (str): The name of the session manager index.
session_tag (str): Tag to be added to entries to link to a specific
session_tag (Optional[str]): Tag to be added to entries to link to a specific
session.
user_tag (str): Tag to be added to entries to link to a specific user.
user_tag (Optional[str]): Tag to be added to entries to link to a specific user.
prefix (Optional[str]): Prefix for the keys for this session data.
Defaults to None and will be replaced with the index name.
redis_client (Optional[Redis]): A Redis client instance. Defaults to
None.
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
Expand All @@ -44,14 +70,18 @@ def __init__(
"""
super().__init__(name, session_tag, user_tag)

prefix = prefix or name

schema = StandardSessionIndexSchema.from_params(name, prefix)
self._index = SearchIndex(schema=schema)

# handle redis connection
if redis_client:
self._client = redis_client
self._index.set_client(redis_client)
elif redis_url:
self._client = RedisConnectionFactory.get_redis_connection(
redis_url, **connection_kwargs
)
RedisConnectionFactory.validate_sync_redis(self._client)
self._index.connect(redis_url=redis_url, **connection_kwargs)

self._index.create(overwrite=False)

self.set_scope(session_tag, user_tag)

Expand All @@ -63,27 +93,35 @@ def set_scope(
"""Set the filter to apply to queries based on the desired scope.

This new scope persists until another call to set_scope is made, or if
scope is specified in calls to get_recent.
scope specified in calls to get_recent or get_relevant.

Args:
session_tag (str): Id of the specific session to filter to. Default is
None, which means session_tag will be unchanged.
None, which means all sessions will be in scope.
user_tag (str): Id of the specific user to filter to. Default is None,
which means user_tag will be unchanged.
which means all users will be in scope.
"""
if not (session_tag or user_tag):
return

self._session_tag = session_tag or self._session_tag
self._user_tag = user_tag or self._user_tag
tag_filter = Tag(self.user_field_name) == []
if user_tag:
tag_filter = tag_filter & (Tag(self.user_field_name) == self._user_tag)
if session_tag:
tag_filter = tag_filter & (
Tag(self.session_field_name) == self._session_tag
)

self._tag_filter = tag_filter

def clear(self) -> None:
"""Clears the chat session history."""
self._client.delete(self.key)
self._index.clear()

def delete(self) -> None:
"""Clears the chat session history."""
self._client.delete(self.key)
"""Clear all conversation keys and remove the search index."""
self._index.delete(drop=True)

def drop(self, id_field: Optional[str] = None) -> None:
"""Remove a specific exchange from the conversation history.
Expand All @@ -93,19 +131,36 @@ def drop(self, id_field: Optional[str] = None) -> None:
If None then the last entry is deleted.
"""
if id_field:
messages = self._client.lrange(self.key, 0, -1)
messages = [json.loads(msg) for msg in messages]
messages = [msg for msg in messages if msg["id_field"] != id_field]
messages = [json.dumps(msg) for msg in messages]
self.clear()
self._client.rpush(self.key, *messages)
sep = self._index.key_separator
key = sep.join([self._index.schema.index.name, id_field])
else:
self._client.rpop(self.key)
key = self.get_recent(top_k=1, raw=True)[0]["id"] # type: ignore
self._index.client.delete(key) # type: ignore

@property
def messages(self) -> Union[List[str], List[Dict[str, str]]]:
"""Returns the full chat history."""
return self.get_recent(top_k=-1)
# TODO raw or as_text?
return_fields = [
self.id_field_name,
self.session_field_name,
self.user_field_name,
self.role_field_name,
self.content_field_name,
self.tool_field_name,
self.timestamp_field_name,
]

query = FilterQuery(
filter_expression=self._tag_filter,
return_fields=return_fields,
)

sorted_query = query.query
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We added sort by support to the redisvl query classes -- so you should be able to do this there?

sorted_query.sort_by(self.timestamp_field_name, asc=True)
hits = self._index.search(sorted_query, query.params).docs

return self._format_context(hits, as_text=False)

def get_recent(
self,
Expand All @@ -119,7 +174,6 @@ def get_recent(

Args:
top_k (int): The number of previous messages to return. Default is 5.
To get all messages set top_k = -1.
session_tag (str): Tag to be added to entries to link to a specific
session.
user_tag (str): Tag to be added to entries to link to a specific user.
Expand All @@ -133,24 +187,35 @@ def get_recent(
or list of strings if as_text is false.

Raises:
ValueError: if top_k is not an integer greater than or equal to -1.
ValueError: if top_k is not an integer greater than or equal to 0.
"""
if type(top_k) != int or top_k < -1:
raise ValueError("top_k must be an integer greater than or equal to -1")
if top_k == 0:
return []
elif top_k == -1:
top_k = 0
if type(top_k) != int or top_k < 0:
raise ValueError("top_k must be an integer greater than or equal to 0")

self.set_scope(session_tag, user_tag)
messages = self._client.lrange(self.key, -top_k, -1)
messages = [json.loads(msg) for msg in messages]
if raw:
return messages
return self._format_context(messages, as_text)
return_fields = [
self.id_field_name,
self.session_field_name,
self.user_field_name,
self.role_field_name,
self.content_field_name,
self.tool_field_name,
self.timestamp_field_name,
]

query = FilterQuery(
filter_expression=self._tag_filter,
return_fields=return_fields,
num_results=top_k,
)

@property
def key(self):
return ":".join([self._name, self._user_tag, self._session_tag])
sorted_query = query.query
sorted_query.sort_by(self.timestamp_field_name, asc=False)
hits = self._index.search(sorted_query, query.params).docs

if raw:
return hits[::-1]
return self._format_context(hits[::-1], as_text)

def store(self, prompt: str, response: str) -> None:
"""Insert a prompt:response pair into the session memory. A timestamp
Expand All @@ -162,7 +227,10 @@ def store(self, prompt: str, response: str) -> None:
response (str): The corresponding LLM response.
"""
self.add_messages(
[{"role": "user", "content": prompt}, {"role": "llm", "content": response}]
[
{self.role_field_name: "user", self.content_field_name: prompt},
{self.role_field_name: "llm", self.content_field_name: response},
]
)

def add_messages(self, messages: List[Dict[str, str]]) -> None:
Expand All @@ -173,23 +241,23 @@ def add_messages(self, messages: List[Dict[str, str]]) -> None:
Args:
messages (List[Dict[str, str]]): The list of user prompts and LLM responses.
"""
sep = self._index.key_separator
payloads = []
for message in messages:
timestamp = time()
id_field = sep.join([self._user_tag, self._session_tag, str(timestamp)])
payload = {
self.id_field_name: ":".join(
[self._user_tag, self._session_tag, str(timestamp)]
),
self.id_field_name: id_field,
self.role_field_name: message[self.role_field_name],
self.content_field_name: message[self.content_field_name],
self.timestamp_field_name: timestamp,
self.session_field_name: self._session_tag,
self.user_field_name: self._user_tag,
}
if self.tool_field_name in message:
payload.update({self.tool_field_name: message[self.tool_field_name]})

payloads.append(json.dumps(payload))

self._client.rpush(self.key, *payloads)
payloads.append(payload)
self._index.load(data=payloads, id_field=self.id_field_name)

def add_message(self, message: Dict[str, str]) -> None:
"""Insert a single prompt or response into the session memory.
Expand Down
Loading
Loading