refactor: move API to standardized pydantic schemas across CLI, Python client, REST server (#1579)

Co-authored-by: cpacker <packercharles@gmail.com>
Co-authored-by: matthew zhou <matthewzhou@matthews-MacBook-Pro.local>
Co-authored-by: Zack Field <field.zackery@gmail.com>
4 people authored Aug 17, 2024
1 parent 41581e3 commit 7ff033f
Showing 112 changed files with 8,942 additions and 8,049 deletions.
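The shape of the refactor is easiest to see in the memgpt/agent_store/chroma.py diff below: records and uuid.UUID IDs from memgpt.data_types give way to pydantic models imported from memgpt.schemas, with plain-string IDs. As rough orientation, the new schemas plausibly look something like the sketch below; the fields shown are assumptions inferred from how the diff uses them, not the repository's actual definitions.

```python
# Hypothetical sketch of the standardized pydantic schemas this commit moves to.
# Field names are inferred from usage in the diff (Passage(text=..., embedding=...,
# id=..., embedding_config=...)); the real definitions live under memgpt/schemas/.
from typing import List, Optional

from pydantic import BaseModel


class EmbeddingConfig(BaseModel):
    """Embedding settings carried on each record (assumed fields)."""

    embedding_model: Optional[str] = None
    embedding_dim: Optional[int] = None


class Passage(BaseModel):
    """A chunk of text plus its embedding; IDs are now plain strings, not uuid.UUID."""

    id: Optional[str] = None
    text: str
    embedding: Optional[List[float]] = None
    embedding_config: Optional[EmbeddingConfig] = None
```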
1 change: 0 additions & 1 deletion .github/workflows/docker-integration-tests.yaml
@@ -58,7 +58,6 @@ jobs:
pipx install poetry==1.8.2
poetry install -E dev -E postgres
poetry run pytest -s tests/test_client.py
- poetry run pytest -s tests/test_concurrent_connections.py
- name: Print docker logs if tests fail
if: failure()
16 changes: 3 additions & 13 deletions .github/workflows/tests.yml
@@ -2,6 +2,7 @@ name: Run All pytest Tests

env:
MEMGPT_PGURI: ${{ secrets.MEMGPT_PGURI }}
+ OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

on:
push:
@@ -32,10 +33,10 @@ jobs:
with:
python-version: "3.12"
poetry-version: "1.8.2"
- install-args: "-E dev -E postgres -E milvus"
+ install-args: "-E dev -E postgres -E milvus -E crewai-tools"

- name: Initialize credentials
- run: poetry run memgpt quickstart --backend memgpt
+ run: poetry run memgpt quickstart --backend openai

#- name: Run docker compose server
# env:
@@ -69,14 +70,3 @@ jobs:
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
run: |
poetry run pytest -s -vv -k "not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client" tests
- - name: Run storage tests
- env:
- MEMGPT_PG_PORT: 8888
- MEMGPT_PG_USER: memgpt
- MEMGPT_PG_PASSWORD: memgpt
- MEMGPT_PG_HOST: localhost
- MEMGPT_PG_DB: memgpt
- MEMGPT_SERVER_PASS: test_server_token
- run: |
- poetry run pytest -s -vv tests/test_storage.py
1 change: 1 addition & 0 deletions .gitignore
@@ -1012,6 +1012,7 @@ FodyWeavers.xsd
## cached db data
pgdata/
!pgdata/.gitkeep
+ .persist/

## pytest mirrors
memgpt/.pytest_cache/
2 changes: 1 addition & 1 deletion Dockerfile
@@ -14,7 +14,7 @@ WORKDIR /app
COPY pyproject.toml poetry.lock ./
RUN poetry lock --no-update
RUN if [ "$MEMGPT_ENVIRONMENT" = "DEVELOPMENT" ] ; then \
poetry install --no-root -E "postgres server dev autogen" ; \
poetry install --no-root -E "postgres server dev" ; \
else \
poetry install --no-root -E "postgres server" && \
rm -rf $POETRY_CACHE_DIR ; \
287 changes: 186 additions & 101 deletions memgpt/agent.py

Large diffs are not rendered by default.

125 changes: 86 additions & 39 deletions memgpt/agent_store/chroma.py
@@ -1,12 +1,12 @@
import uuid
- from typing import Dict, Iterator, List, Optional, Tuple, cast
+ from typing import Dict, List, Optional, Tuple, cast

import chromadb
from chromadb.api.types import Include

from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.config import MemGPTConfig
- from memgpt.data_types import Passage, Record, RecordType
+ from memgpt.schemas.embedding_config import EmbeddingConfig
+ from memgpt.schemas.passage import Passage
from memgpt.utils import datetime_to_timestamp, printd, timestamp_to_datetime


@@ -34,9 +34,6 @@ def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None
self.collection = self.client.get_or_create_collection(self.table_name)
self.include: Include = ["documents", "embeddings", "metadatas"]

- # need to be converted to strings
- self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]

def get_filters(self, filters: Optional[Dict] = {}) -> Tuple[list, dict]:
# get all filters for query
if filters is not None:
@@ -54,10 +51,7 @@ def get_filters(self, filters: Optional[Dict] = {}) -> Tuple[list, dict]:
continue

# filter by other keys
- if key in self.uuid_fields:
- chroma_filters.append({key: {"$eq": str(value)}})
- else:
- chroma_filters.append({key: {"$eq": value}})
+ chroma_filters.append({key: {"$eq": value}})

if len(chroma_filters) > 1:
chroma_filters = {"$and": chroma_filters}
@@ -67,7 +61,7 @@ def get_filters(self, filters: Optional[Dict] = {}) -> Tuple[list, dict]:
chroma_filters = chroma_filters[0]
return ids, chroma_filters

- def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000, offset: int = 0) -> Iterator[List[RecordType]]:
+ def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000, offset: int = 0):
ids, filters = self.get_filters(filters)
while True:
# Retrieve a chunk of records with the given page_size
@@ -84,29 +78,50 @@ def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000,
# Increment the offset to get the next chunk in the next iteration
offset += page_size

- def results_to_records(self, results) -> List[RecordType]:
+ def results_to_records(self, results):
# convert timestamps to datetime
for metadata in results["metadatas"]:
if "created_at" in metadata:
metadata["created_at"] = timestamp_to_datetime(metadata["created_at"])
- for key, value in metadata.items():
- if key in self.uuid_fields:
- metadata[key] = uuid.UUID(value)
if results["embeddings"]: # may not be returned, depending on table type
- return [
- cast(RecordType, self.type(text=text, embedding=embedding, id=uuid.UUID(record_id), **metadatas)) # type: ignore
- for (text, record_id, embedding, metadatas) in zip(
- results["documents"], results["ids"], results["embeddings"], results["metadatas"]
- )
- ]
+ passages = []
+ for text, record_id, embedding, metadata in zip(
+ results["documents"], results["ids"], results["embeddings"], results["metadatas"]
+ ):
+ args = {}
+ for field in EmbeddingConfig.__fields__.keys():
+ if field in metadata:
+ args[field] = metadata[field]
+ del metadata[field]
+ embedding_config = EmbeddingConfig(**args)
+ passages.append(Passage(text=text, embedding=embedding, id=record_id, embedding_config=embedding_config, **metadata))
+ # return [
+ # Passage(text=text, embedding=embedding, id=record_id, embedding_config=EmbeddingConfig(), **metadatas)
+ # for (text, record_id, embedding, metadatas) in zip(
+ # results["documents"], results["ids"], results["embeddings"], results["metadatas"]
+ # )
+ # ]
+ return passages
else:
# no embeddings
- return [
- cast(RecordType, self.type(text=text, id=uuid.UUID(id), **metadatas)) # type: ignore
- for (text, id, metadatas) in zip(results["documents"], results["ids"], results["metadatas"])
- ]

- def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[RecordType]:
+ passages = []
+ for text, id, metadata in zip(results["documents"], results["ids"], results["metadatas"]):
+ args = {}
+ for field in EmbeddingConfig.__fields__.keys():
+ if field in metadata:
+ args[field] = metadata[field]
+ del metadata[field]
+ embedding_config = EmbeddingConfig(**args)
+ passages.append(Passage(text=text, embedding=None, id=id, embedding_config=embedding_config, **metadata))
+ return passages

+ # return [
+ # #cast(Passage, self.type(text=text, id=uuid.UUID(id), **metadatas)) # type: ignore
+ # Passage(text=text, embedding=None, id=id, **metadatas)
+ # for (text, id, metadatas) in zip(results["documents"], results["ids"], results["metadatas"])
+ # ]

+ def get_all(self, filters: Optional[Dict] = {}, limit=None):
ids, filters = self.get_filters(filters)
if self.collection.count() == 0:
return []
@@ -116,13 +131,13 @@ def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[RecordType]:
results = self.collection.get(ids=ids, include=self.include, where=filters)
return self.results_to_records(results)

- def get(self, id: uuid.UUID) -> Optional[RecordType]:
+ def get(self, id):
results = self.collection.get(ids=[str(id)])
if len(results["ids"]) == 0:
return None
return self.results_to_records(results)[0]

- def format_records(self, records: List[RecordType]):
+ def format_records(self, records):
assert all([isinstance(r, Passage) for r in records])

recs = []
@@ -145,34 +160,36 @@ def format_records(self, records: List[RecordType]):
# collect/format record metadata
metadatas = []
for record in recs:
+ embedding_config = vars(record.embedding_config)
metadata = vars(record)
metadata.pop("id")
metadata.pop("text")
metadata.pop("embedding")
metadata.pop("embedding_config")
metadata.pop("metadata_")
if "created_at" in metadata:
metadata["created_at"] = datetime_to_timestamp(metadata["created_at"])
if "metadata_" in metadata and metadata["metadata_"] is not None:
record_metadata = dict(metadata["metadata_"])
metadata.pop("metadata_")
else:
record_metadata = {}
metadata = {key: value for key, value in metadata.items() if value is not None} # null values not allowed

+ metadata = {**metadata, **record_metadata} # merge with metadata
+ metadata = {**metadata, **embedding_config} # merge with embedding config
+ metadata = {key: value for key, value in metadata.items() if value is not None} # null values not allowed

- # convert uuids to strings
- for key, value in metadata.items():
- if key in self.uuid_fields:
- metadata[key] = str(value)
metadatas.append(metadata)
return ids, documents, embeddings, metadatas

- def insert(self, record: Record):
+ def insert(self, record):
ids, documents, embeddings, metadatas = self.format_records([record])
if any([e is None for e in embeddings]):
raise ValueError("Embeddings must be provided to chroma")
self.collection.upsert(documents=documents, embeddings=[e for e in embeddings if e is not None], ids=ids, metadatas=metadatas)

- def insert_many(self, records: List[RecordType], show_progress=False):
+ def insert_many(self, records, show_progress=False):
ids, documents, embeddings, metadatas = self.format_records(records)
if any([e is None for e in embeddings]):
raise ValueError("Embeddings must be provided to chroma")
@@ -198,7 +215,7 @@ def size(self, filters: Optional[Dict] = {}) -> int:
def list_data_sources(self):
raise NotImplementedError

- def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]:
+ def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
ids, filters = self.get_filters(filters)
results = self.collection.query(query_embeddings=[query_vec], n_results=top_k, include=self.include, where=filters)

@@ -239,10 +256,40 @@ def query_text(self, query, count=None, start=None, filters: Optional[Dict] = {}
def get_all_cursor(
self,
filters: Optional[Dict] = {},
- after: uuid.UUID = None,
- before: uuid.UUID = None,
+ after: str = None,
+ before: str = None,
limit: Optional[int] = 1000,
order_by: str = "created_at",
reverse: bool = False,
):
raise ValueError("Cannot run get_all_cursor with chroma")
records = self.get_all(filters=filters)

# WARNING: very hacky and slow implementation
def get_index(id, record_list):
for i in range(len(record_list)):
if record_list[i].id == id:
return i
assert False, f"Could not find id {id} in record list"

# sort by custom field
records = sorted(records, key=lambda x: getattr(x, order_by), reverse=reverse)
if after:
index = get_index(after, records)
if index + 1 >= len(records):
return None, []
records = records[index + 1 :]
if before:
index = get_index(before, records)
if index == 0:
return None, []

# TODO: not sure if this is correct
records = records[:index]

if len(records) == 0:
return None, []

# enforce limit
if limit:
records = records[:limit]
return records[-1].id, records
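Two things in the chroma.py changes above are worth spelling out. First, because Chroma metadata must be a flat dict of scalars, format_records now merges the record's EmbeddingConfig fields into each metadata dict on insert, and results_to_records splits them back out into an EmbeddingConfig on read. A minimal sketch of that round-trip, with assumed EmbeddingConfig fields:

```python
# Sketch of the flatten/reconstruct round-trip the new chroma.py performs.
# The EmbeddingConfig fields below are illustrative assumptions.
from typing import Any, Dict, Tuple

from pydantic import BaseModel


class EmbeddingConfig(BaseModel):
    embedding_model: str = "text-embedding-ada-002"  # assumed field
    embedding_dim: int = 1536  # assumed field


def flatten(metadata: Dict[str, Any], config: EmbeddingConfig) -> Dict[str, Any]:
    """On insert: merge the config's fields into the flat Chroma metadata dict."""
    merged = {**metadata, **vars(config)}  # mirrors format_records
    # Chroma rejects null metadata values, so drop them
    return {key: value for key, value in merged.items() if value is not None}


def unflatten(metadata: Dict[str, Any]) -> Tuple[Dict[str, Any], EmbeddingConfig]:
    """On read: pull the config's fields back out, as results_to_records does."""
    args = {field: metadata.pop(field) for field in EmbeddingConfig.__fields__ if field in metadata}
    return metadata, EmbeddingConfig(**args)


# Round trip: what goes in comes back out, split into metadata + config.
meta = flatten({"user_id": "user-123", "created_at": 1723900000}, EmbeddingConfig())
rest, config = unflatten(dict(meta))
assert config.embedding_dim == 1536 and rest["user_id"] == "user-123"
```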
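Second, get_all_cursor no longer raises on Chroma, so callers can page through an archive by feeding the returned cursor back in as after. A hedged usage sketch, where connector stands for an already-initialized ChromaStorageConnector (construction elided):

```python
# Hypothetical paging loop over the new get_all_cursor, which returns
# (last_id, records) per page; `connector` is assumed to be initialized.
after = None
while True:
    last_id, records = connector.get_all_cursor(after=after, limit=100, order_by="created_at")
    if not records:
        break
    for passage in records:
        print(passage.id, passage.text[:40])  # process each page of passages
    after = last_id  # resume from the last record returned
```

Note that the implementation re-fetches and sorts every record on each call (the code itself flags this as hacky and slow), so each page costs a full scan on large archives.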
