Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🧠 Small memory updates #718

Merged
merged 3 commits into from
Jun 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/docs/development/memory.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ the [Weaviate docs](https://weaviate.io/developers/weaviate).
Essentially, vector databases allow us to save task and task execution history externally, allowing agents to access
memory from many loops prior. This is done through similarity search over text.

Intuitively, when we as humans want to remember something, we try to think of something related. Eventually,
we find a collection of information related that topic in our head and act upon it.
Intuitively, when we as humans want to remember something, we try to think of related words or phrases. Eventually,
we find a collection of information related to that topic in our head and act upon it.
This framework is similar to how a Vector DB operates.

## Weaviate
Expand Down
538 changes: 344 additions & 194 deletions platform/poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions platform/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ replicate = "^0.8.3"
lanarky = "^0.7.6"
weaviate-client = "^3.19.2"
tiktoken = "^0.4.0"
numpy = "^1.24.3"


[tool.poetry.dev-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion platform/reworkd_platform/db/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def validate_task_count(self, run_id: str, type_: str) -> None:

if task_count >= max_:
raise PlatformaticError(
StopIteration, f"Max loops of {max_} exceeded, shutting down.", 429
StopIteration(), f"Max loops of {max_} exceeded, shutting down.", 429
)


Expand Down
69 changes: 0 additions & 69 deletions platform/reworkd_platform/web/api/agent/memory/memory.py

This file was deleted.

4 changes: 4 additions & 0 deletions platform/reworkd_platform/web/api/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""API for handling agent memory"""
from reworkd_platform.web.api.memory.views import router

__all__ = ["router"]
23 changes: 23 additions & 0 deletions platform/reworkd_platform/web/api/memory/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from abc import ABC
from typing import List, Tuple

SimilarTasks = List[Tuple[str, float]]


class AgentMemory(ABC):
"""
Base class for AgentMemory
Expose __enter__ and __exit__ to ensure connections get closed within requests
"""

def __enter__(self):
pass

def __exit__(self, exc_type, exc_value, traceback) -> None:
pass

def add_tasks(self, task: List[str]) -> None:
pass

def get_similar_tasks(self, query: str, score_threshold: float) -> List[str]:
pass
22 changes: 22 additions & 0 deletions platform/reworkd_platform/web/api/memory/null.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import List

from reworkd_platform.web.api.memory.memory import AgentMemory


class NullAgentMemory(AgentMemory):
"""
NullObjectPattern for AgentMemory
Used when database connections cannot be established
"""

def __enter__(self):
pass

def __exit__(self, exc_type, exc_value, traceback) -> None:
pass

def add_tasks(self, task: List[str]) -> None:
pass

def get_similar_tasks(self, query: str, score_threshold: float) -> List[str]:
return []
36 changes: 36 additions & 0 deletions platform/reworkd_platform/web/api/memory/views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import List

from fastapi import APIRouter
from pydantic import BaseModel

from reworkd_platform.web.api.memory.memory import SimilarTasks
from reworkd_platform.web.api.memory.weaviate import WeaviateMemory

router = APIRouter()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This router should have an admin only dependancy



class MemoryAdd(BaseModel):
class_name: str
tasks: List[str]


@router.post("/memory/add")
def add_task_memory(req_body: MemoryAdd) -> List[str]:
with WeaviateMemory(req_body.class_name) as memory:
ids = memory.add_tasks(req_body.tasks)
return ids


@router.get("/memory/get")
def get_task_memory(
class_name: str, query: str, score_threshold: float = 0.7
) -> SimilarTasks:
with WeaviateMemory(class_name) as memory:
similar_tasks = memory.get_similar_tasks(query, score_threshold)
return similar_tasks


@router.delete("/memory/delete")
def delete_class(class_name: str) -> None:
with WeaviateMemory(class_name) as memory:
memory.delete_class()
121 changes: 121 additions & 0 deletions platform/reworkd_platform/web/api/memory/weaviate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from __future__ import annotations

from typing import List, Dict, cast, Tuple

import numpy as np
import weaviate
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Weaviate
from loguru import logger
from weaviate import UnexpectedStatusCodeException

from reworkd_platform.settings import settings
from reworkd_platform.web.api.memory.memory import SimilarTasks, AgentMemory


def _default_schema(index_name: str, text_key: str) -> Dict:
return {
"class": index_name,
"properties": [
{
"name": text_key,
"dataType": ["text"],
}
],
}


CLASS_PREFIX = "Reworkd_AgentGPT_"


class WeaviateMemory(AgentMemory):
"""
Wrapper around the Weaviate vector database
"""

db: Weaviate = None

def __init__(self, index_name: str):
self.index_name = CLASS_PREFIX + index_name
self.text_key = "agent_memory"

def __enter__(self):
# If the database requires authentication, retrieve the API key
auth = (
weaviate.auth.AuthApiKey(api_key=settings.vector_db_api_key)
if settings.vector_db_api_key is not None
and settings.vector_db_api_key != ""
else None
)
self.client = weaviate.Client(settings.vector_db_url, auth_client_secret=auth)

# Create the schema if it doesn't already exist
schema = _default_schema(self.index_name, self.text_key)
if not self.client.schema.contains(schema):
self.client.schema.create_class(schema)

# Instantiate client with embedding provider
self.embeddings = OpenAIEmbeddings(openai_api_key=settings.openai_api_key)
self.db = Weaviate(
self.client,
self.index_name,
self.text_key,
embedding=self.embeddings,
by_text=False,
)

return self

def __exit__(self, exc_type, exc_value, traceback):
self.client.__del__()

def add_tasks(self, tasks: List[str]) -> List[str]:
return self.db.add_texts(tasks)

def get_similar_tasks(
self, query: str, score_threshold: float = 0.7
) -> SimilarTasks:
# Get similar tasks
results = self._similarity_search_with_score(query)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
results = self._similarity_search_with_score(query)
results: Tuple[str, float] = self._similarity_search_with_score(query)


def get_score(result: Tuple[str, float]) -> float:
return result[1]

results.sort(key=get_score, reverse=True)

# Return formatted response
return [(text, score) for [text, score] in results if score >= score_threshold]

def reset_class(self):
try:
self.client.schema.delete_class(self.index_name)
except UnexpectedStatusCodeException as error:
logger.error(error)

def _similarity_search_with_score(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should add this to langchain!

self, query: str, k: int = 4
) -> List[Tuple[str, float]]:
"""
A remake of _similarity_search_with_score from langchain to use a near vector
"""
# Build query
query_obj = self.client.query.get(self.index_name, [self.text_key])
embedding = self.embeddings.embed_query(query)
vector = {"vector": embedding}

result = (
query_obj.with_near_vector(vector)
.with_limit(k)
.with_additional("vector")
.do()
)

if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")

docs_and_scores: list[tuple[str, float]] = []
for res in result["data"]["Get"][self.index_name]:
text = cast(str, res.pop(self.text_key))
score = float(np.dot(res["_additional"]["vector"], embedding))
docs_and_scores.append((text, score))
return docs_and_scores
2 changes: 2 additions & 0 deletions platform/reworkd_platform/web/api/router.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from fastapi.routing import APIRouter

from reworkd_platform.web.api import agent
from reworkd_platform.web.api import memory
from reworkd_platform.web.api import monitoring

api_router = APIRouter()
api_router.include_router(monitoring.router, prefix="/monitoring", tags=["monitoring"])
api_router.include_router(agent.router, prefix="/agent", tags=["agent"])
api_router.include_router(memory.router, prefix="/memory", tags=["memory"])