Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Redis-backed Entity Memory #2397

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion langchain/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
from langchain.memory.combined import CombinedMemory
from langchain.memory.entity import ConversationEntityMemory
from langchain.memory.entity import (
ConversationEntityMemory,
ConversationEntityRedisMemory,
)
from langchain.memory.kg import ConversationKGMemory
from langchain.memory.readonly import ReadOnlySharedMemory
from langchain.memory.simple import SimpleMemory
Expand All @@ -23,6 +26,7 @@
"ConversationSummaryBufferMemory",
"ConversationKGMemory",
"ConversationEntityMemory",
"ConversationEntityRedisMemory",
"ConversationSummaryMemory",
"ChatMessageHistory",
"ConversationStringBufferMemory",
Expand Down
145 changes: 138 additions & 7 deletions langchain/memory/entity.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Any, Dict, List, Optional
import logging
from abc import abstractmethod
from itertools import islice
from typing import Any, Dict, Iterable, List, Optional

from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
Expand All @@ -10,20 +13,46 @@
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string

logger = logging.getLogger(__name__)

class ConversationEntityMemory(BaseChatMemory):

class BaseConversationEntityMemory(BaseChatMemory):
"""Entity extractor & summarizer to memory."""

human_prefix: str = "Human"
ai_prefix: str = "AI"
llm: BaseLanguageModel
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT
store: Dict[str, Optional[str]] = {}
entity_cache: List[str] = []
k: int = 3
chat_history_key: str = "history"

@abstractmethod
def store_get(self, key: str, default: Optional[str] = None) -> Optional[str]:
"""Get entity value from store."""
pass

@abstractmethod
def store_set(self, key: str, value: Optional[str]) -> None:
"""Set entity value in store."""
pass

@abstractmethod
def store_del(self, key: str) -> None:
"""Delete entity value from store."""
pass

@abstractmethod
def store_exists(self, key: str) -> bool:
"""Check if entity exists in store."""
pass

@abstractmethod
def store_clear(self) -> None:
"""Delete all entities from store."""
pass

@property
def buffer(self) -> List[BaseMessage]:
return self.chat_memory.messages
Expand Down Expand Up @@ -58,7 +87,7 @@ def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
entities = [w.strip() for w in output.split(",")]
entity_summaries = {}
for entity in entities:
entity_summaries[entity] = self.store.get(entity, "")
entity_summaries[entity] = self.store_get(entity, "")
self.entity_cache = entities
if self.return_messages:
buffer: Any = self.buffer[-self.k * 2 :]
Expand Down Expand Up @@ -87,16 +116,118 @@ def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt)

for entity in self.entity_cache:
existing_summary = self.store.get(entity, "")
existing_summary = self.store_get(entity, "")
output = chain.predict(
summary=existing_summary,
entity=entity,
history=buffer_string,
input=input_data,
)
self.store[entity] = output.strip()
self.store_set(entity, output.strip())

def clear(self) -> None:
"""Clear memory contents."""
self.chat_memory.clear()
self.store = {}
self.store_clear()


class ConversationEntityMemory(BaseConversationEntityMemory):
"""Basic in-memory entity store."""

store: Dict[str, Optional[str]] = {}

def store_get(self, key: str, default: Optional[str] = None) -> Optional[str]:
return self.store.get(key, default)

def store_set(self, key: str, value: Optional[str]) -> None:
self.store[key] = value

def store_del(self, key: str) -> None:
del self.store[key]

def store_exists(self, key: str) -> bool:
return key in self.store

def store_clear(self) -> None:
return self.store.clear()


class ConversationEntityRedisMemory(BaseConversationEntityMemory):
"""Redis-backed Entity store. Entities get a TTL of 1 day by default, and
that TTL is extended by 3 days every time the entity is read back.
"""

redis_client: Any
session_id: str = "default"
key_prefix: str = "memory_store"
ttl: Optional[int] = 60 * 60 * 24
recall_ttl: Optional[int] = 60 * 60 * 24 * 3

def __init__(
self,
session_id: str = "default",
url: str = "redis://localhost:6379/0",
key_prefix: str = "memory_store",
ttl: Optional[int] = 60 * 60 * 24,
recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
*args: Any,
**kwargs: Any,
):
try:
import redis
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)

super().__init__(*args, **kwargs)

try:
self.redis_client = redis.Redis.from_url(url=url, decode_responses=True)
except redis.exceptions.ConnectionError as error:
logger.error(error)

self.session_id = session_id
self.key_prefix = key_prefix
self.ttl = ttl
self.recall_ttl = recall_ttl or ttl

@property
def full_key_prefix(self) -> str:
return f"{self.key_prefix}:{self.session_id}"

def store_get(self, key: str, default: Optional[str] = None) -> Optional[str]:
res = (
self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
or default
or ""
)
logger.debug(f"REDIS MEM get '{self.full_key_prefix}:{key}': '{res}'")
return res

def store_set(self, key: str, value: Optional[str]) -> None:
if not value:
return self.store_del(key)
self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
logger.debug(
f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
)

def store_del(self, key: str) -> None:
self.redis_client.delete(f"{self.full_key_prefix}:{key}")

def store_exists(self, key: str) -> bool:
return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1

def store_clear(self) -> None:
# iterate a list in batches of size batch_size
def batched(iterable: Iterable[Any], batch_size: int) -> Iterable[Any]:
iterator = iter(iterable)
while batch := list(islice(iterator, batch_size)):
yield batch

for keybatch in batched(
self.redis_client.scan_iter(f"{self.full_key_prefix}:*"), 500
):
self.redis_client.delete(*keybatch)