Skip to content

Commit

Permalink
Implement Redis-backed Entity Memory
Browse files Browse the repository at this point in the history
  • Loading branch information
alexiri committed Apr 6, 2023
1 parent be8b550 commit be72f4b
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 2 deletions.
6 changes: 5 additions & 1 deletion langchain/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
from langchain.memory.combined import CombinedMemory
from langchain.memory.entity import ConversationEntityMemory
from langchain.memory.entity import (
ConversationEntityMemory,
ConversationEntityRedisMemory,
)
from langchain.memory.kg import ConversationKGMemory
from langchain.memory.readonly import ReadOnlySharedMemory
from langchain.memory.simple import SimpleMemory
Expand All @@ -23,6 +26,7 @@
"ConversationSummaryBufferMemory",
"ConversationKGMemory",
"ConversationEntityMemory",
"ConversationEntityRedisMemory",
"ConversationSummaryMemory",
"ChatMessageHistory",
"ConversationStringBufferMemory",
Expand Down
87 changes: 86 additions & 1 deletion langchain/memory/entity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
from abc import abstractmethod
from typing import Any, Dict, List, Optional
from itertools import islice
from typing import Any, Dict, Iterable, List, Optional

from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
Expand All @@ -11,6 +13,8 @@
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string

logger = logging.getLogger(__name__)


class BaseConversationEntityMemory(BaseChatMemory):
"""Entity extractor & summarizer to memory."""
Expand Down Expand Up @@ -146,3 +150,84 @@ def store_exists(self, key: str) -> bool:

def store_clear(self) -> None:
return self.store.clear()


class ConversationEntityRedisMemory(BaseConversationEntityMemory):
"""Redis-backed Entity store. Entities get a TTL of 1 day by default, and
that TTL is extended by 3 days every time the entity is read back.
"""

redis_client: Any
session_id: str = "default"
key_prefix: str = "memory_store"
ttl: Optional[int] = 60 * 60 * 24
recall_ttl: Optional[int] = 60 * 60 * 24 * 3

def __init__(
self,
session_id: str = "default",
url: str = "redis://localhost:6379/0",
key_prefix: str = "memory_store",
ttl: Optional[int] = 60 * 60 * 24,
recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
*args: Any,
**kwargs: Any,
):
try:
import redis
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)

super().__init__(*args, **kwargs)

try:
self.redis_client = redis.Redis.from_url(url=url, decode_responses=True)
except redis.exceptions.ConnectionError as error:
logger.error(error)

self.session_id = session_id
self.key_prefix = key_prefix
self.ttl = ttl
self.recall_ttl = recall_ttl or ttl

@property
def full_key_prefix(self) -> str:
return f"{self.key_prefix}:{self.session_id}"

def store_get(self, key: str, default: Optional[str] = None) -> Optional[str]:
res = (
self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
or default
or ""
)
logger.debug(f"REDIS MEM get '{self.full_key_prefix}:{key}': '{res}'")
return res

def store_set(self, key: str, value: Optional[str]) -> None:
if not value:
return self.store_del(key)
self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
logger.debug(
f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
)

def store_del(self, key: str) -> None:
self.redis_client.delete(f"{self.full_key_prefix}:{key}")

def store_exists(self, key: str) -> bool:
return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1

def store_clear(self) -> None:
# iterate a list in batches of size batch_size
def batched(iterable: Iterable[Any], batch_size: int) -> Iterable[Any]:
iterator = iter(iterable)
while batch := list(islice(iterator, batch_size)):
yield batch

for keybatch in batched(
self.redis_client.scan_iter(f"{self.full_key_prefix}:*"), 500
):
self.redis_client.delete(*keybatch)

0 comments on commit be72f4b

Please sign in to comment.