From 5e83016443ecd6e3f9d89247b57f79276245587f Mon Sep 17 00:00:00 2001 From: Alex Iribarren Date: Fri, 7 Apr 2023 07:34:10 +0200 Subject: [PATCH] Redis-backed Entity Memory (#2397) I wanted to be able to persist Entity Memory in a Redis database, so I abstracted `ConversationEntityMemory` to allow for pluggable Entity stores (d06f90d). Then I implemented a Entity store that... erm... stores Entities in Redis. By default, Entities will expire from memory after 24 hours, but they'll be persisted for another 3 days every time they're recalled. The idea is to give the AIs a bit of a spaced-repetition memory, but I have yet to see if this is useful. The memory is partitioned by `session_id` (user ID? chat channel? whatever, really) so entities from one user don't leak to another. While developing this, I did notice that the Entity summaries are kind of buggy (they summarize AI-generated content and not just information the human gave them, sometimes they add things like "No new information provided. Existing summary remains: As stated previously, X", etc.), but I'll tackle that later. First I wanted to get some input on this idea. --- langchain/memory/__init__.py | 6 +- langchain/memory/entity.py | 145 +++++++++++++++++++++++++++++++++-- 2 files changed, 143 insertions(+), 8 deletions(-) diff --git a/langchain/memory/__init__.py b/langchain/memory/__init__.py index 5799e734b1ad7..1938089f810d8 100644 --- a/langchain/memory/__init__.py +++ b/langchain/memory/__init__.py @@ -7,7 +7,10 @@ from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory from langchain.memory.combined import CombinedMemory -from langchain.memory.entity import ConversationEntityMemory +from langchain.memory.entity import ( + ConversationEntityMemory, + ConversationEntityRedisMemory, +) from langchain.memory.kg import ConversationKGMemory from langchain.memory.readonly import ReadOnlySharedMemory from langchain.memory.simple import SimpleMemory @@ -23,6 +26,7 @@ "ConversationSummaryBufferMemory", "ConversationKGMemory", "ConversationEntityMemory", + "ConversationEntityRedisMemory", "ConversationSummaryMemory", "ChatMessageHistory", "ConversationStringBufferMemory", diff --git a/langchain/memory/entity.py b/langchain/memory/entity.py index 95aac811a26f9..7c4b7c74e1dee 100644 --- a/langchain/memory/entity.py +++ b/langchain/memory/entity.py @@ -1,4 +1,7 @@ -from typing import Any, Dict, List, Optional +import logging +from abc import abstractmethod +from itertools import islice +from typing import Any, Dict, Iterable, List, Optional from langchain.chains.llm import LLMChain from langchain.memory.chat_memory import BaseChatMemory @@ -10,8 +13,10 @@ from langchain.prompts.base import BasePromptTemplate from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string +logger = logging.getLogger(__name__) -class ConversationEntityMemory(BaseChatMemory): + +class BaseConversationEntityMemory(BaseChatMemory): """Entity extractor & summarizer to memory.""" human_prefix: str = "Human" @@ -19,11 +24,35 @@ class ConversationEntityMemory(BaseChatMemory): llm: BaseLanguageModel entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT - store: Dict[str, Optional[str]] = {} entity_cache: List[str] = [] k: int = 3 chat_history_key: str = "history" + @abstractmethod + def store_get(self, key: str, default: Optional[str] = None) -> Optional[str]: + """Get entity value from store.""" + pass + + @abstractmethod + def store_set(self, key: str, value: Optional[str]) -> None: + """Set entity value in store.""" + pass + + @abstractmethod + def store_del(self, key: str) -> None: + """Delete entity value from store.""" + pass + + @abstractmethod + def store_exists(self, key: str) -> bool: + """Check if entity exists in store.""" + pass + + @abstractmethod + def store_clear(self) -> None: + """Delete all entities from store.""" + pass + @property def buffer(self) -> List[BaseMessage]: return self.chat_memory.messages @@ -58,7 +87,7 @@ def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: entities = [w.strip() for w in output.split(",")] entity_summaries = {} for entity in entities: - entity_summaries[entity] = self.store.get(entity, "") + entity_summaries[entity] = self.store_get(entity, "") self.entity_cache = entities if self.return_messages: buffer: Any = self.buffer[-self.k * 2 :] @@ -87,16 +116,118 @@ def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt) for entity in self.entity_cache: - existing_summary = self.store.get(entity, "") + existing_summary = self.store_get(entity, "") output = chain.predict( summary=existing_summary, entity=entity, history=buffer_string, input=input_data, ) - self.store[entity] = output.strip() + self.store_set(entity, output.strip()) def clear(self) -> None: """Clear memory contents.""" self.chat_memory.clear() - self.store = {} + self.store_clear() + + +class ConversationEntityMemory(BaseConversationEntityMemory): + """Basic in-memory entity store.""" + + store: Dict[str, Optional[str]] = {} + + def store_get(self, key: str, default: Optional[str] = None) -> Optional[str]: + return self.store.get(key, default) + + def store_set(self, key: str, value: Optional[str]) -> None: + self.store[key] = value + + def store_del(self, key: str) -> None: + del self.store[key] + + def store_exists(self, key: str) -> bool: + return key in self.store + + def store_clear(self) -> None: + return self.store.clear() + + +class ConversationEntityRedisMemory(BaseConversationEntityMemory): + """Redis-backed Entity store. Entities get a TTL of 1 day by default, and + that TTL is extended by 3 days every time the entity is read back. + """ + + redis_client: Any + session_id: str = "default" + key_prefix: str = "memory_store" + ttl: Optional[int] = 60 * 60 * 24 + recall_ttl: Optional[int] = 60 * 60 * 24 * 3 + + def __init__( + self, + session_id: str = "default", + url: str = "redis://localhost:6379/0", + key_prefix: str = "memory_store", + ttl: Optional[int] = 60 * 60 * 24, + recall_ttl: Optional[int] = 60 * 60 * 24 * 3, + *args: Any, + **kwargs: Any, + ): + try: + import redis + except ImportError: + raise ValueError( + "Could not import redis python package. " + "Please install it with `pip install redis`." + ) + + super().__init__(*args, **kwargs) + + try: + self.redis_client = redis.Redis.from_url(url=url, decode_responses=True) + except redis.exceptions.ConnectionError as error: + logger.error(error) + + self.session_id = session_id + self.key_prefix = key_prefix + self.ttl = ttl + self.recall_ttl = recall_ttl or ttl + + @property + def full_key_prefix(self) -> str: + return f"{self.key_prefix}:{self.session_id}" + + def store_get(self, key: str, default: Optional[str] = None) -> Optional[str]: + res = ( + self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl) + or default + or "" + ) + logger.debug(f"REDIS MEM get '{self.full_key_prefix}:{key}': '{res}'") + return res + + def store_set(self, key: str, value: Optional[str]) -> None: + if not value: + return self.store_del(key) + self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl) + logger.debug( + f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}" + ) + + def store_del(self, key: str) -> None: + self.redis_client.delete(f"{self.full_key_prefix}:{key}") + + def store_exists(self, key: str) -> bool: + return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1 + + def store_clear(self) -> None: + # iterate a list in batches of size batch_size + def batched(iterable: Iterable[Any], batch_size: int) -> Iterable[Any]: + iterator = iter(iterable) + while batch := list(islice(iterator, batch_size)): + yield batch + + for keybatch in batched( + self.redis_client.scan_iter(f"{self.full_key_prefix}:*"), 500 + ): + self.redis_client.delete(*keybatch)