Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add fallback memory #736

Merged
merged 3 commits into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from unittest.mock import MagicMock

import pytest

from reworkd_platform.web.api.memory.memory_with_fallback import MemoryWithFallback


@pytest.mark.parametrize(
"method_name, args",
[
("__enter__", ()),
("__exit__", (None, None, None)),
("add_tasks", (["task1", "task2"],)),
("get_similar_tasks", ("task1", 0.8)),
("reset_class", ()),
],
)
def test_memory_with_fallback(method_name: str, args) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤯

primary = MagicMock()
secondary = MagicMock()
memory_with_fallback = MemoryWithFallback(primary, secondary)

# Use getattr() to call the method on the object with args
getattr(memory_with_fallback, method_name)(*args)
getattr(primary, method_name).assert_called_once_with(*args)
getattr(secondary, method_name).assert_not_called()

# Reset mock and make primary raise an exception
getattr(primary, method_name).reset_mock()
getattr(primary, method_name).side_effect = Exception("Primary Failed")

# Call the method again, this time it should fall back to secondary
getattr(memory_with_fallback, method_name)(*args)
getattr(primary, method_name).assert_called_once_with(*args)
getattr(secondary, method_name).assert_called_once_with(*args)
3 changes: 2 additions & 1 deletion platform/reworkd_platform/web/api/agent/dependancies.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from reworkd_platform.settings import settings
from reworkd_platform.web.api.dependencies import get_current_user
from reworkd_platform.web.api.memory.memory_with_fallback import MemoryWithFallback
from reworkd_platform.web.api.memory.null import NullAgentMemory
from reworkd_platform.web.api.memory.weaviate import WeaviateMemory

Expand All @@ -34,7 +35,7 @@ def get_agent_memory(
):
vector_db_exists = settings.vector_db_url and settings.vector_db_url != ""
if vector_db_exists and not settings.ff_mock_mode_enabled:
return WeaviateMemory(user.id)
return MemoryWithFallback(WeaviateMemory(user.id), NullAgentMemory())
else:
return NullAgentMemory()

Expand Down
10 changes: 5 additions & 5 deletions platform/reworkd_platform/web/api/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ class AgentMemory(ABC):
"""

def __enter__(self) -> "AgentMemory":
pass
raise NotImplementedError()

def __exit__(self, exc_type, exc_value, traceback) -> None:
pass
raise NotImplementedError()

def add_tasks(self, tasks: List[str]) -> None:
pass
def add_tasks(self, tasks: List[str]) -> List[str]:
raise NotImplementedError()

def get_similar_tasks(self, query: str, score_threshold: float) -> List[str]:
pass
raise NotImplementedError()

def reset_class(self):
pass
53 changes: 53 additions & 0 deletions platform/reworkd_platform/web/api/memory/memory_with_fallback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

from typing import List

from loguru import logger

from reworkd_platform.web.api.memory.memory import AgentMemory


class MemoryWithFallback(AgentMemory):
"""
Wrap a primary AgentMemory provider and use a fallback in the case that it fails
We do this because we've had issues with Weaviate crashing and causing memory to randomly fail
"""

def __init__(self, primary: AgentMemory, secondary: AgentMemory):
self.primary = primary
self.secondary = secondary

def __enter__(self) -> AgentMemory:
try:
return self.primary.__enter__()
except Exception as e:
logger.exception(e)
return self.secondary.__enter__()

def __exit__(self, exc_type, exc_value, traceback) -> None:
try:
self.primary.__exit__(exc_type, exc_value, traceback)
except Exception as e:
logger.exception(e)
self.secondary.__exit__(exc_type, exc_value, traceback)

def add_tasks(self, tasks: List[str]) -> List[str]:
try:
return self.primary.add_tasks(tasks)
except Exception as e:
logger.exception(e)
return self.secondary.add_tasks(tasks)

def get_similar_tasks(self, query: str, score_threshold) -> List[str]:
try:
return self.primary.get_similar_tasks(query, score_threshold)
except Exception as e:
logger.exception(e)
return self.secondary.get_similar_tasks(query, score_threshold)

def reset_class(self) -> None:
try:
self.primary.reset_class()
except Exception as e:
logger.exception(e)
self.secondary.reset_class()
4 changes: 2 additions & 2 deletions platform/reworkd_platform/web/api/memory/null.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def __enter__(self) -> AgentMemory:
def __exit__(self, exc_type, exc_value, traceback) -> None:
pass

def add_tasks(self, tasks: List[str]) -> None:
pass
def add_tasks(self, tasks: List[str]) -> List[str]:
return []

def get_similar_tasks(self, query: str, score_threshold: float) -> List[str]:
return []
Expand Down
22 changes: 13 additions & 9 deletions platform/reworkd_platform/web/api/memory/weaviate.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from __future__ import annotations

from typing import List, Dict, cast, Tuple
from typing import List, Dict, cast, Tuple, Optional

import numpy as np
import weaviate
import weaviate # type: ignore
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Weaviate
from loguru import logger
from weaviate import UnexpectedStatusCodeException

from reworkd_platform.settings import settings
from reworkd_platform.web.api.memory.memory import SimilarTasks, AgentMemory
from reworkd_platform.web.api.memory.memory import AgentMemory


def _default_schema(index_name: str, text_key: str) -> Dict:
Expand All @@ -33,7 +33,7 @@ class WeaviateMemory(AgentMemory):
Wrapper around the Weaviate vector database
"""

db: Weaviate = None
db: Optional[Weaviate] = None

def __init__(self, index_name: str):
self.index_name = CLASS_PREFIX + index_name
Expand All @@ -52,7 +52,11 @@ def __enter__(self) -> AgentMemory:
self._create_class()

# Instantiate client with embedding provider
self.embeddings = OpenAIEmbeddings(openai_api_key=settings.openai_api_key)
self.embeddings = OpenAIEmbeddings(
client=None, # Meta private value but mypy will complain its missing
openai_api_key=settings.openai_api_key,
)

self.db = Weaviate(
self.client,
self.index_name,
Expand All @@ -73,11 +77,11 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
self.client.__del__()

def add_tasks(self, tasks: List[str]) -> List[str]:
if self.db is None:
raise Exception("WeaviateMemory not initialized")
return self.db.add_texts(tasks)

def get_similar_tasks(
self, query: str, score_threshold: float = 0.7
) -> SimilarTasks:
def get_similar_tasks(self, query: str, score_threshold: float) -> List[str]:
# Get similar tasks
results = self._similarity_search_with_score(query)

Expand All @@ -87,7 +91,7 @@ def get_score(result: Tuple[str, float]) -> float:
results.sort(key=get_score, reverse=True)

# Return formatted response
return [(text, score) for [text, score] in results if score >= score_threshold]
return [text for [text, score] in results if score >= score_threshold]

def reset_class(self) -> None:
try:
Expand Down
26 changes: 26 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import weaviate

auth = (
weaviate.auth.AuthApiKey(api_key="KNaObeDhRVRaI488QkEoprZ3LriotjRIo6Rg")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... you get the gorilla

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deleted 💥

)

client = weaviate.Client("https://zgjbgueysdoxesgb7f8esa.gcp-d.weaviate.cloud", auth)


def _default_schema(index_name: str, text_key: str):
return {
"class": index_name,
"properties": [
{
"name": text_key,
"dataType": ["text"],
}
],
}


schema = _default_schema("testytest", "testytest")
client.schema.create_class(schema)

schema = client.schema.get()
print(schema)