diff --git a/redisvl/extensions/constants.py b/redisvl/extensions/constants.py
new file mode 100644
index 00000000..7d7fb841
--- /dev/null
+++ b/redisvl/extensions/constants.py
@@ -0,0 +1,29 @@
+"""
+Constants used within the extension classes SemanticCache, BaseSessionManager,
+StandardSessionManager, SemanticSessionManager, and SemanticRouter.
+These constants are also used within these classes' corresponding schemas.
+"""
+
+# BaseSessionManager
+ID_FIELD_NAME: str = "entry_id"
+ROLE_FIELD_NAME: str = "role"
+CONTENT_FIELD_NAME: str = "content"
+TOOL_FIELD_NAME: str = "tool_call_id"
+TIMESTAMP_FIELD_NAME: str = "timestamp"
+SESSION_FIELD_NAME: str = "session_tag"
+
+# SemanticSessionManager
+SESSION_VECTOR_FIELD_NAME: str = "vector_field"
+
+# SemanticCache
+REDIS_KEY_FIELD_NAME: str = "key"
+ENTRY_ID_FIELD_NAME: str = "entry_id"
+PROMPT_FIELD_NAME: str = "prompt"
+RESPONSE_FIELD_NAME: str = "response"
+CACHE_VECTOR_FIELD_NAME: str = "prompt_vector"
+INSERTED_AT_FIELD_NAME: str = "inserted_at"
+UPDATED_AT_FIELD_NAME: str = "updated_at"
+METADATA_FIELD_NAME: str = "metadata"
+
+# SemanticRouter
+ROUTE_VECTOR_FIELD_NAME: str = "vector"
diff --git a/redisvl/extensions/llmcache/schema.py b/redisvl/extensions/llmcache/schema.py
index 515b1421..b42fcb52 100644
--- a/redisvl/extensions/llmcache/schema.py
+++ b/redisvl/extensions/llmcache/schema.py
@@ -2,6 +2,13 @@
 
 from pydantic.v1 import BaseModel, Field, root_validator, validator
 
+from redisvl.extensions.constants import (
+    CACHE_VECTOR_FIELD_NAME,
+    INSERTED_AT_FIELD_NAME,
+    PROMPT_FIELD_NAME,
+    RESPONSE_FIELD_NAME,
+    UPDATED_AT_FIELD_NAME,
+)
 from redisvl.redis.utils import array_to_buffer, hashify
 from redisvl.schema import IndexSchema
 from redisvl.utils.utils import current_timestamp, deserialize, serialize
@@ -110,12 +117,12 @@ def from_params(cls, name: str, prefix: str, vector_dims: int):
         return cls(
             index={"name": name, "prefix": prefix},  # type: ignore
             fields=[  # type: ignore
-                {"name": "prompt", "type": "text"},
-                {"name": "response", "type": "text"},
-                {"name": "inserted_at", "type": "numeric"},
-                {"name": "updated_at", "type": "numeric"},
+                {"name": PROMPT_FIELD_NAME, "type": "text"},
+                {"name": RESPONSE_FIELD_NAME, "type": "text"},
+                {"name": INSERTED_AT_FIELD_NAME, "type": "numeric"},
+                {"name": UPDATED_AT_FIELD_NAME, "type": "numeric"},
                 {
-                    "name": "prompt_vector",
+                    "name": CACHE_VECTOR_FIELD_NAME,
                     "type": "vector",
                     "attrs": {
                         "dims": vector_dims,
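For orientation, the constants above are plain module-level strings, so downstream code can key into cache hits without hard-coding field names. A minimal sketch, not part of the patch, assuming a Redis instance at redis://localhost:6379 and the public SemanticCache API with its default vectorizer:

```python
# Hypothetical usage sketch -- not included in this diff.
from redisvl.extensions.constants import (
    METADATA_FIELD_NAME,
    PROMPT_FIELD_NAME,
    RESPONSE_FIELD_NAME,
)
from redisvl.extensions.llmcache import SemanticCache

# Assumes Redis is reachable at this URL and a default vectorizer can be loaded.
cache = SemanticCache(name="llmcache", redis_url="redis://localhost:6379")
cache.store(prompt="What is RedisVL?", response="A Redis vector search library.")

for hit in cache.check(prompt="Tell me about RedisVL", num_results=1):
    # Cache hits are dicts keyed by the same constants the schema above uses.
    print(hit[PROMPT_FIELD_NAME], "->", hit[RESPONSE_FIELD_NAME])
    print("metadata:", hit.get(METADATA_FIELD_NAME))
```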
diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py
index 272df555..d284b602 100644
--- a/redisvl/extensions/llmcache/semantic.py
+++ b/redisvl/extensions/llmcache/semantic.py
@@ -3,6 +3,16 @@
 
 from redis import Redis
 
+from redisvl.extensions.constants import (
+    CACHE_VECTOR_FIELD_NAME,
+    ENTRY_ID_FIELD_NAME,
+    INSERTED_AT_FIELD_NAME,
+    METADATA_FIELD_NAME,
+    PROMPT_FIELD_NAME,
+    REDIS_KEY_FIELD_NAME,
+    RESPONSE_FIELD_NAME,
+    UPDATED_AT_FIELD_NAME,
+)
 from redisvl.extensions.llmcache.base import BaseLLMCache
 from redisvl.extensions.llmcache.schema import (
     CacheEntry,
@@ -19,15 +29,6 @@
 class SemanticCache(BaseLLMCache):
     """Semantic Cache for Large Language Models."""
 
-    redis_key_field_name: str = "key"
-    entry_id_field_name: str = "entry_id"
-    prompt_field_name: str = "prompt"
-    response_field_name: str = "response"
-    vector_field_name: str = "prompt_vector"
-    inserted_at_field_name: str = "inserted_at"
-    updated_at_field_name: str = "updated_at"
-    metadata_field_name: str = "metadata"
-
     _index: SearchIndex
     _aindex: Optional[AsyncSearchIndex] = None
 
@@ -94,12 +95,12 @@ def __init__(
         # Process fields and other settings
         self.set_threshold(distance_threshold)
         self.return_fields = [
-            self.entry_id_field_name,
-            self.prompt_field_name,
-            self.response_field_name,
-            self.inserted_at_field_name,
-            self.updated_at_field_name,
-            self.metadata_field_name,
+            ENTRY_ID_FIELD_NAME,
+            PROMPT_FIELD_NAME,
+            RESPONSE_FIELD_NAME,
+            INSERTED_AT_FIELD_NAME,
+            UPDATED_AT_FIELD_NAME,
+            METADATA_FIELD_NAME,
         ]
 
         # Create semantic cache schema and index
@@ -133,7 +134,7 @@ def __init__(
 
             validate_vector_dims(
                 vectorizer.dims,
-                self._index.schema.fields[self.vector_field_name].attrs.dims,  # type: ignore
+                self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims,  # type: ignore
             )
             self._vectorizer = vectorizer
 
@@ -145,9 +146,7 @@ def _modify_schema(
         """Modify the base cache schema using the provided filterable fields"""
 
         if filterable_fields is not None:
-            protected_field_names = set(
-                self.return_fields + [self.redis_key_field_name]
-            )
+            protected_field_names = set(self.return_fields + [REDIS_KEY_FIELD_NAME])
             for filter_field in filterable_fields:
                 field_name = filter_field["name"]
                 if field_name in protected_field_names:
@@ -300,7 +299,7 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
     def _check_vector_dims(self, vector: List[float]):
         """Checks the size of the provided vector and raises an error if it
         doesn't match the search index vector dimensions."""
-        schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims  # type: ignore
+        schema_vector_dims = self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims  # type: ignore
         validate_vector_dims(len(vector), schema_vector_dims)
 
     def check(
@@ -363,7 +362,7 @@ def check(
 
         query = RangeQuery(
             vector=vector,
-            vector_field_name=self.vector_field_name,
+            vector_field_name=CACHE_VECTOR_FIELD_NAME,
             return_fields=self.return_fields,
             distance_threshold=distance_threshold,
             num_results=num_results,
@@ -444,7 +443,7 @@ async def acheck(
 
         query = RangeQuery(
             vector=vector,
-            vector_field_name=self.vector_field_name,
+            vector_field_name=CACHE_VECTOR_FIELD_NAME,
             return_fields=self.return_fields,
             distance_threshold=distance_threshold,
             num_results=num_results,
@@ -479,7 +478,7 @@ def _process_cache_results(
                 cache_hit_dict = {
                     k: v for k, v in cache_hit_dict.items() if k in return_fields
                 }
-            cache_hit_dict[self.redis_key_field_name] = redis_key
+            cache_hit_dict[REDIS_KEY_FIELD_NAME] = redis_key
             cache_hits.append(cache_hit_dict)
         return redis_keys, cache_hits
 
@@ -541,7 +540,7 @@ def store(
         keys = self._index.load(
             data=[cache_entry.to_dict()],
             ttl=ttl,
-            id_field=self.entry_id_field_name,
+            id_field=ENTRY_ID_FIELD_NAME,
         )
         return keys[0]
 
@@ -605,7 +604,7 @@ async def astore(
         keys = await aindex.load(
             data=[cache_entry.to_dict()],
             ttl=ttl,
-            id_field=self.entry_id_field_name,
+            id_field=ENTRY_ID_FIELD_NAME,
         )
         return keys[0]
 
@@ -629,13 +628,11 @@ def update(self, key: str, **kwargs) -> None:
         for k, v in kwargs.items():
 
             # Make sure the item is in the index schema
-            if k not in set(
-                self._index.schema.field_names + [self.metadata_field_name]
-            ):
+            if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]):
                 raise ValueError(f"{k} is not a valid field within the cache entry")
 
             # Check for metadata and deserialize
-            if k == self.metadata_field_name:
+            if k == METADATA_FIELD_NAME:
                 if isinstance(v, dict):
                     kwargs[k] = serialize(v)
                 else:
                     raise TypeError(
                         "If specified, cached metadata must be a dictionary."
                     )
 
-        kwargs.update({self.updated_at_field_name: current_timestamp()})
+        kwargs.update({UPDATED_AT_FIELD_NAME: current_timestamp()})
 
         self._index.client.hset(key, mapping=kwargs)  # type: ignore
 
@@ -674,13 +671,11 @@ async def aupdate(self, key: str, **kwargs) -> None:
         for k, v in kwargs.items():
 
             # Make sure the item is in the index schema
-            if k not in set(
-                self._index.schema.field_names + [self.metadata_field_name]
-            ):
+            if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]):
                 raise ValueError(f"{k} is not a valid field within the cache entry")
 
             # Check for metadata and deserialize
-            if k == self.metadata_field_name:
+            if k == METADATA_FIELD_NAME:
                 if isinstance(v, dict):
                     kwargs[k] = serialize(v)
                 else:
@@ -688,7 +683,7 @@ async def aupdate(self, key: str, **kwargs) -> None:
                         "If specified, cached metadata must be a dictionary."
                     )
 
-        kwargs.update({self.updated_at_field_name: current_timestamp()})
+        kwargs.update({UPDATED_AT_FIELD_NAME: current_timestamp()})
 
         await aindex.load(data=[kwargs], keys=[key])
 
diff --git a/redisvl/extensions/router/schema.py b/redisvl/extensions/router/schema.py
index 11b88dc6..04272e2a 100644
--- a/redisvl/extensions/router/schema.py
+++ b/redisvl/extensions/router/schema.py
@@ -3,6 +3,7 @@
 
 from pydantic.v1 import BaseModel, Field, validator
 
+from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
 from redisvl.schema import IndexInfo, IndexSchema
 
@@ -104,7 +105,7 @@ def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema":
                 {"name": "route_name", "type": "tag"},
                 {"name": "reference", "type": "text"},
                 {
-                    "name": "vector",
+                    "name": ROUTE_VECTOR_FIELD_NAME,
                     "type": "vector",
                     "attrs": {
                         "algorithm": "flat",
diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py
index bab69578..f1e235aa 100644
--- a/redisvl/extensions/router/semantic.py
+++ b/redisvl/extensions/router/semantic.py
@@ -8,6 +8,7 @@
 from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer
 from redis.exceptions import ResponseError
 
+from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
 from redisvl.extensions.router.schema import (
     DistanceAggregationMethod,
     Route,
@@ -226,7 +227,7 @@ def _classify_route(
         """Classify to a single route using a vector."""
         vector_range_query = RangeQuery(
             vector=vector,
-            vector_field_name="vector",
+            vector_field_name=ROUTE_VECTOR_FIELD_NAME,
             distance_threshold=distance_threshold,
             return_fields=["route_name"],
         )
@@ -278,7 +279,7 @@ def _classify_multi_route(
         """Classify to multiple routes, up to max_k (int), using a vector."""
         vector_range_query = RangeQuery(
             vector=vector,
-            vector_field_name="vector",
+            vector_field_name=ROUTE_VECTOR_FIELD_NAME,
             distance_threshold=distance_threshold,
             return_fields=["route_name"],
         )
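As a quick illustration of the router change (a sketch under assumptions, not part of the patch), the router schema builder and the router's range queries now name the vector field through the same constant:

```python
# Hypothetical sketch -- the schema and the query agree on the vector field name.
from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
from redisvl.extensions.router.schema import SemanticRouterIndexSchema
from redisvl.query import RangeQuery

schema = SemanticRouterIndexSchema.from_params("topic-router", vector_dims=384)
assert ROUTE_VECTOR_FIELD_NAME in schema.fields  # the "vector" field defined above

# The same constant drives the field used when classifying against route references.
query = RangeQuery(
    vector=[0.0] * 384,  # placeholder embedding
    vector_field_name=ROUTE_VECTOR_FIELD_NAME,
    return_fields=["route_name"],
    distance_threshold=0.5,
)
```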
"timestamp" - session_field_name: str = "session_tag" def __init__( self, @@ -107,11 +106,11 @@ def _format_context( context.append(chat_message.content) else: chat_message_dict = { - self.role_field_name: chat_message.role, - self.content_field_name: chat_message.content, + ROLE_FIELD_NAME: chat_message.role, + CONTENT_FIELD_NAME: chat_message.content, } if chat_message.tool_call_id is not None: - chat_message_dict[self.tool_field_name] = chat_message.tool_call_id + chat_message_dict[TOOL_FIELD_NAME] = chat_message.tool_call_id context.append(chat_message_dict) # type: ignore diff --git a/redisvl/extensions/session_manager/schema.py b/redisvl/extensions/session_manager/schema.py index 0e35edd2..aabc67b4 100644 --- a/redisvl/extensions/session_manager/schema.py +++ b/redisvl/extensions/session_manager/schema.py @@ -2,6 +2,15 @@ from pydantic.v1 import BaseModel, Field, root_validator +from redisvl.extensions.constants import ( + CONTENT_FIELD_NAME, + ID_FIELD_NAME, + ROLE_FIELD_NAME, + SESSION_FIELD_NAME, + SESSION_VECTOR_FIELD_NAME, + TIMESTAMP_FIELD_NAME, + TOOL_FIELD_NAME, +) from redisvl.redis.utils import array_to_buffer from redisvl.schema import IndexSchema from redisvl.utils.utils import current_timestamp @@ -31,18 +40,22 @@ class Config: @root_validator(pre=True) @classmethod def generate_id(cls, values): - if "timestamp" not in values: - values["timestamp"] = current_timestamp() - if "entry_id" not in values: - values["entry_id"] = f'{values["session_tag"]}:{values["timestamp"]}' + if TIMESTAMP_FIELD_NAME not in values: + values[TIMESTAMP_FIELD_NAME] = current_timestamp() + if ID_FIELD_NAME not in values: + values[ID_FIELD_NAME] = ( + f"{values[SESSION_FIELD_NAME]}:{values[TIMESTAMP_FIELD_NAME]}" + ) return values def to_dict(self) -> Dict: data = self.dict(exclude_none=True) # handle optional fields - if "vector_field" in data: - data["vector_field"] = array_to_buffer(data["vector_field"]) + if SESSION_VECTOR_FIELD_NAME in data: + data[SESSION_VECTOR_FIELD_NAME] = array_to_buffer( + data[SESSION_VECTOR_FIELD_NAME] + ) return data @@ -55,11 +68,11 @@ def from_params(cls, name: str, prefix: str): return cls( index={"name": name, "prefix": prefix}, # type: ignore fields=[ # type: ignore - {"name": "role", "type": "tag"}, - {"name": "content", "type": "text"}, - {"name": "tool_call_id", "type": "tag"}, - {"name": "timestamp", "type": "numeric"}, - {"name": "session_tag", "type": "tag"}, + {"name": ROLE_FIELD_NAME, "type": "tag"}, + {"name": CONTENT_FIELD_NAME, "type": "text"}, + {"name": TOOL_FIELD_NAME, "type": "tag"}, + {"name": TIMESTAMP_FIELD_NAME, "type": "numeric"}, + {"name": SESSION_FIELD_NAME, "type": "tag"}, ], ) @@ -72,13 +85,13 @@ def from_params(cls, name: str, prefix: str, vectorizer_dims: int): return cls( index={"name": name, "prefix": prefix}, # type: ignore fields=[ # type: ignore - {"name": "role", "type": "tag"}, - {"name": "content", "type": "text"}, - {"name": "tool_call_id", "type": "tag"}, - {"name": "timestamp", "type": "numeric"}, - {"name": "session_tag", "type": "tag"}, + {"name": ROLE_FIELD_NAME, "type": "tag"}, + {"name": CONTENT_FIELD_NAME, "type": "text"}, + {"name": TOOL_FIELD_NAME, "type": "tag"}, + {"name": TIMESTAMP_FIELD_NAME, "type": "numeric"}, + {"name": SESSION_FIELD_NAME, "type": "tag"}, { - "name": "vector_field", + "name": SESSION_VECTOR_FIELD_NAME, "type": "vector", "attrs": { "dims": vectorizer_dims, diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 
diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py
index f5f4c37b..29474904 100644
--- a/redisvl/extensions/session_manager/semantic_session.py
+++ b/redisvl/extensions/session_manager/semantic_session.py
@@ -2,6 +2,15 @@
 
 from redis import Redis
 
+from redisvl.extensions.constants import (
+    CONTENT_FIELD_NAME,
+    ID_FIELD_NAME,
+    ROLE_FIELD_NAME,
+    SESSION_FIELD_NAME,
+    SESSION_VECTOR_FIELD_NAME,
+    TIMESTAMP_FIELD_NAME,
+    TOOL_FIELD_NAME,
+)
 from redisvl.extensions.session_manager import BaseSessionManager
 from redisvl.extensions.session_manager.schema import (
     ChatMessage,
@@ -80,7 +89,7 @@ def __init__(
 
         self._index.create(overwrite=False)
 
-        self._default_session_filter = Tag(self.session_field_name) == self._session_tag
+        self._default_session_filter = Tag(SESSION_FIELD_NAME) == self._session_tag
 
     def clear(self) -> None:
         """Clears the chat session history."""
@@ -98,7 +107,7 @@ def drop(self, id: Optional[str] = None) -> None:
                 If None then the last entry is deleted.
         """
         if id is None:
-            id = self.get_recent(top_k=1, raw=True)[0][self.id_field_name]  # type: ignore
+            id = self.get_recent(top_k=1, raw=True)[0][ID_FIELD_NAME]  # type: ignore
 
         self._index.client.delete(self._index.key(id))  # type: ignore
 
@@ -108,19 +117,19 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]:
         # TODO raw or as_text?
         # TODO refactor method to use get_recent and support other session tags
         return_fields = [
-            self.id_field_name,
-            self.session_field_name,
-            self.role_field_name,
-            self.content_field_name,
-            self.tool_field_name,
-            self.timestamp_field_name,
+            ID_FIELD_NAME,
+            SESSION_FIELD_NAME,
+            ROLE_FIELD_NAME,
+            CONTENT_FIELD_NAME,
+            TOOL_FIELD_NAME,
+            TIMESTAMP_FIELD_NAME,
         ]
 
         query = FilterQuery(
             filter_expression=self._default_session_filter,
             return_fields=return_fields,
         )
-        query.sort_by(self.timestamp_field_name, asc=True)
+        query.sort_by(TIMESTAMP_FIELD_NAME, asc=True)
         messages = self._index.query(query)
         return self._format_context(messages, as_text=False)
 
@@ -172,22 +181,22 @@ def get_relevant(
         distance_threshold = distance_threshold or self._distance_threshold
 
         return_fields = [
-            self.session_field_name,
-            self.role_field_name,
-            self.content_field_name,
-            self.timestamp_field_name,
-            self.tool_field_name,
+            SESSION_FIELD_NAME,
+            ROLE_FIELD_NAME,
+            CONTENT_FIELD_NAME,
+            TIMESTAMP_FIELD_NAME,
+            TOOL_FIELD_NAME,
         ]
 
         session_filter = (
-            Tag(self.session_field_name) == session_tag
+            Tag(SESSION_FIELD_NAME) == session_tag
            if session_tag
            else self._default_session_filter
         )
 
         query = RangeQuery(
             vector=self._vectorizer.embed(prompt),
-            vector_field_name=self.vector_field_name,
+            vector_field_name=SESSION_VECTOR_FIELD_NAME,
             return_fields=return_fields,
             distance_threshold=distance_threshold,
             num_results=top_k,
@@ -232,16 +241,16 @@ def get_recent(
             raise ValueError("top_k must be an integer greater than or equal to 0")
 
         return_fields = [
-            self.id_field_name,
-            self.session_field_name,
-            self.role_field_name,
-            self.content_field_name,
-            self.tool_field_name,
-            self.timestamp_field_name,
+            ID_FIELD_NAME,
+            SESSION_FIELD_NAME,
+            ROLE_FIELD_NAME,
+            CONTENT_FIELD_NAME,
+            TOOL_FIELD_NAME,
+            TIMESTAMP_FIELD_NAME,
         ]
 
         session_filter = (
-            Tag(self.session_field_name) == session_tag
+            Tag(SESSION_FIELD_NAME) == session_tag
             if session_tag
             else self._default_session_filter
         )
@@ -251,7 +260,7 @@ def get_recent(
             return_fields=return_fields,
             num_results=top_k,
         )
-        query.sort_by(self.timestamp_field_name, asc=False)
+        query.sort_by(TIMESTAMP_FIELD_NAME, asc=False)
         messages = self._index.query(query)
 
         if raw:
@@ -280,8 +289,8 @@ def store(
         """
         self.add_messages(
             [
-                {self.role_field_name: "user", self.content_field_name: prompt},
-                {self.role_field_name: "llm", self.content_field_name: response},
+                {ROLE_FIELD_NAME: "user", CONTENT_FIELD_NAME: prompt},
+                {ROLE_FIELD_NAME: "llm", CONTENT_FIELD_NAME: response},
             ],
             session_tag,
         )
@@ -302,27 +311,25 @@ def add_messages(
         chat_messages: List[Dict[str, Any]] = []
 
         for message in messages:
-
-            content_vector = self._vectorizer.embed(message[self.content_field_name])
-
+            content_vector = self._vectorizer.embed(message[CONTENT_FIELD_NAME])
             validate_vector_dims(
                 len(content_vector),
-                self._index.schema.fields[self.vector_field_name].attrs.dims,  # type: ignore
+                self._index.schema.fields[SESSION_VECTOR_FIELD_NAME].attrs.dims,  # type: ignore
             )
 
             chat_message = ChatMessage(
-                role=message[self.role_field_name],
-                content=message[self.content_field_name],
+                role=message[ROLE_FIELD_NAME],
+                content=message[CONTENT_FIELD_NAME],
                 session_tag=session_tag,
                 vector_field=content_vector,
             )
 
-            if self.tool_field_name in message:
-                chat_message.tool_call_id = message[self.tool_field_name]
+            if TOOL_FIELD_NAME in message:
+                chat_message.tool_call_id = message[TOOL_FIELD_NAME]
 
             chat_messages.append(chat_message.to_dict())
 
-        self._index.load(data=chat_messages, id_field=self.id_field_name)
+        self._index.load(data=chat_messages, id_field=ID_FIELD_NAME)
 
     def add_message(
         self, message: Dict[str, str], session_tag: Optional[str] = None
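The same constants double as the expected keys for the message dicts passed into the session managers. A rough sketch (assumes a local Redis and a default vectorizer; not part of the patch):

```python
# Hypothetical usage sketch for the semantic session manager.
from redisvl.extensions.constants import CONTENT_FIELD_NAME, ROLE_FIELD_NAME
from redisvl.extensions.session_manager import SemanticSessionManager

session = SemanticSessionManager(name="chat", redis_url="redis://localhost:6379")
session.add_messages(
    [
        {ROLE_FIELD_NAME: "user", CONTENT_FIELD_NAME: "What is a semantic cache?"},
        {ROLE_FIELD_NAME: "llm", CONTENT_FIELD_NAME: "A cache keyed on embedding similarity."},
    ]
)
# Messages are retrieved by vector similarity against the stored content.
print(session.get_relevant("semantic caching", top_k=1))
```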
diff --git a/redisvl/extensions/session_manager/standard_session.py b/redisvl/extensions/session_manager/standard_session.py
index 9ecfbb5d..9680133f 100644
--- a/redisvl/extensions/session_manager/standard_session.py
+++ b/redisvl/extensions/session_manager/standard_session.py
@@ -2,6 +2,14 @@
 
 from redis import Redis
 
+from redisvl.extensions.constants import (
+    CONTENT_FIELD_NAME,
+    ID_FIELD_NAME,
+    ROLE_FIELD_NAME,
+    SESSION_FIELD_NAME,
+    TIMESTAMP_FIELD_NAME,
+    TOOL_FIELD_NAME,
+)
 from redisvl.extensions.session_manager import BaseSessionManager
 from redisvl.extensions.session_manager.schema import (
     ChatMessage,
@@ -13,7 +21,6 @@
 
 class StandardSessionManager(BaseSessionManager):
-    session_field_name: str = "session_tag"
 
     def __init__(
         self,
@@ -63,7 +70,7 @@ def __init__(
 
         self._index.create(overwrite=False)
 
-        self._default_session_filter = Tag(self.session_field_name) == self._session_tag
+        self._default_session_filter = Tag(SESSION_FIELD_NAME) == self._session_tag
 
     def clear(self) -> None:
         """Clears the chat session history."""
@@ -81,7 +88,7 @@ def drop(self, id: Optional[str] = None) -> None:
                 If None then the last entry is deleted.
         """
         if id is None:
-            id = self.get_recent(top_k=1, raw=True)[0][self.id_field_name]  # type: ignore
+            id = self.get_recent(top_k=1, raw=True)[0][ID_FIELD_NAME]  # type: ignore
 
         self._index.client.delete(self._index.key(id))  # type: ignore
 
@@ -91,19 +98,19 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]:
         # TODO raw or as_text?
         # TODO refactor this method to use get_recent and support other session tags?
         return_fields = [
-            self.id_field_name,
-            self.session_field_name,
-            self.role_field_name,
-            self.content_field_name,
-            self.tool_field_name,
-            self.timestamp_field_name,
+            ID_FIELD_NAME,
+            SESSION_FIELD_NAME,
+            ROLE_FIELD_NAME,
+            CONTENT_FIELD_NAME,
+            TOOL_FIELD_NAME,
+            TIMESTAMP_FIELD_NAME,
         ]
 
         query = FilterQuery(
             filter_expression=self._default_session_filter,
             return_fields=return_fields,
         )
-        query.sort_by(self.timestamp_field_name, asc=True)
+        query.sort_by(TIMESTAMP_FIELD_NAME, asc=True)
         messages = self._index.query(query)
         return self._format_context(messages, as_text=False)
 
@@ -137,16 +144,16 @@ def get_recent(
             raise ValueError("top_k must be an integer greater than or equal to 0")
 
         return_fields = [
-            self.id_field_name,
-            self.session_field_name,
-            self.role_field_name,
-            self.content_field_name,
-            self.tool_field_name,
-            self.timestamp_field_name,
+            ID_FIELD_NAME,
+            SESSION_FIELD_NAME,
+            ROLE_FIELD_NAME,
+            CONTENT_FIELD_NAME,
+            TOOL_FIELD_NAME,
+            TIMESTAMP_FIELD_NAME,
         ]
 
         session_filter = (
-            Tag(self.session_field_name) == session_tag
+            Tag(SESSION_FIELD_NAME) == session_tag
             if session_tag
             else self._default_session_filter
         )
@@ -156,7 +163,7 @@ def get_recent(
             return_fields=return_fields,
             num_results=top_k,
         )
-        query.sort_by(self.timestamp_field_name, asc=False)
+        query.sort_by(TIMESTAMP_FIELD_NAME, asc=False)
         messages = self._index.query(query)
 
         if raw:
@@ -178,8 +185,8 @@ def store(
         """
         self.add_messages(
             [
-                {self.role_field_name: "user", self.content_field_name: prompt},
-                {self.role_field_name: "llm", self.content_field_name: response},
+                {ROLE_FIELD_NAME: "user", CONTENT_FIELD_NAME: prompt},
+                {ROLE_FIELD_NAME: "llm", CONTENT_FIELD_NAME: response},
             ],
             session_tag,
         )
@@ -202,17 +209,17 @@ def add_messages(
 
         for message in messages:
             chat_message = ChatMessage(
-                role=message[self.role_field_name],
-                content=message[self.content_field_name],
+                role=message[ROLE_FIELD_NAME],
+                content=message[CONTENT_FIELD_NAME],
                 session_tag=session_tag,
             )
 
-            if self.tool_field_name in message:
-                chat_message.tool_call_id = message[self.tool_field_name]
+            if TOOL_FIELD_NAME in message:
+                chat_message.tool_call_id = message[TOOL_FIELD_NAME]
 
             chat_messages.append(chat_message.to_dict())
 
-        self._index.load(data=chat_messages, id_field=self.id_field_name)
+        self._index.load(data=chat_messages, id_field=ID_FIELD_NAME)
 
     def add_message(
         self, message: Dict[str, str], session_tag: Optional[str] = None
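Finally, the test updates below rely on raw entries exposing the generated id under ID_FIELD_NAME. A hedged sketch of that pattern outside the test suite (assumes a local Redis; not part of the patch):

```python
# Hypothetical sketch: drop a specific entry by its generated id.
from redisvl.extensions.constants import ID_FIELD_NAME
from redisvl.extensions.session_manager import StandardSessionManager

chat = StandardSessionManager(name="chat", redis_url="redis://localhost:6379")
chat.store(prompt="hi", response="hello!")

last = chat.get_recent(top_k=1, raw=True)[0]
chat.drop(last[ID_FIELD_NAME])  # delete that specific entry
```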
diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_session_manager.py
index 20c2955d..3f458a50 100644
--- a/tests/integration/test_session_manager.py
+++ b/tests/integration/test_session_manager.py
@@ -1,6 +1,7 @@
 import pytest
 from redis.exceptions import ConnectionError
 
+from redisvl.extensions.constants import ID_FIELD_NAME
 from redisvl.extensions.session_manager import (
     SemanticSessionManager,
     StandardSessionManager,
@@ -261,7 +262,7 @@ def test_standard_drop(standard_session):
 
     # test drop(id) removes the specified element
     context = standard_session.get_recent(top_k=10, raw=True)
-    middle_id = context[3][standard_session.id_field_name]
+    middle_id = context[3][ID_FIELD_NAME]
     standard_session.drop(middle_id)
     context = standard_session.get_recent(top_k=6)
     assert context == [
@@ -527,7 +528,7 @@ def test_semantic_drop(semantic_session):
 
     # test drop(id) removes the specified element
     context = semantic_session.get_recent(top_k=5, raw=True)
-    middle_id = context[2][semantic_session.id_field_name]
+    middle_id = context[2][ID_FIELD_NAME]
     semantic_session.drop(middle_id)
     context = semantic_session.get_recent(top_k=4)
     assert context == [