Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lancedb #455

Merged
merged 16 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion memgpt/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def configure_cli(config: MemGPTConfig):

def configure_archival_storage(config: MemGPTConfig):
# Configure archival storage backend
archival_storage_options = ["local", "postgres"]
archival_storage_options = ["local", "lancedb", "postgres"]
archival_storage_type = questionary.select(
"Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type
).ask()
Expand All @@ -220,8 +220,16 @@ def configure_archival_storage(config: MemGPTConfig):
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
default=config.archival_storage_uri if config.archival_storage_uri else "",
).ask()

if archival_storage_type == "lancedb":
archival_storage_uri = questionary.text(
"Enter lancedb connection string (e.g. ./.lancedb):",
default=config.archival_storage_uri if config.archival_storage_uri else "./.lancedb",
).ask()

return archival_storage_type, archival_storage_uri

# TODO: allow configuring embedding model

@app.command()
def configure():
Expand Down
133 changes: 133 additions & 0 deletions memgpt/connectors/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Optional, List, Iterator
import numpy as np
from tqdm import tqdm
import pandas as pd

from memgpt.config import MemGPTConfig
from memgpt.connectors.storage import StorageConnector, Passage
Expand Down Expand Up @@ -181,3 +182,135 @@ def generate_table_name_agent(self, agent_config: AgentConfig):

def generate_table_name(self, name: str):
    """Return the namespaced table name: ``"memgpt_"`` + sanitized *name*."""
    return f"memgpt_{self.sanitize_table_name(name)}"


class LanceDBConnector(StorageConnector):
    """Archival storage backed by LanceDB.

    Each passage is stored as one row with keys ``doc_id``, ``text``,
    ``passage_id`` and ``vector`` (the embedding). The table is created
    lazily on first insert and re-opened if it already exists.
    """

    # TODO: this should probably eventually be moved into a parent DB class

    def __init__(self, name: Optional[str] = None):
        """Connect to the LanceDB database configured in MemGPTConfig.

        :param name: optional logical name; mapped to a sanitized,
            ``memgpt_``-prefixed table name. Defaults to ``lancedb_tbl``.
        :raises ValueError: if ``archival_storage_uri`` is not configured.
        """
        config = MemGPTConfig.load()

        # determine table name
        if name:
            self.table_name = self.generate_table_name(name)
        else:
            self.table_name = "lancedb_tbl"

        printd(f"Using table name {self.table_name}")

        self.uri = config.archival_storage_uri
        if config.archival_storage_uri is None:
            # fix: "specifiy" typo in the original error message
            raise ValueError(f"Must specify archival_storage_uri in config {config.config_path}")
        import lancedb

        self.db = lancedb.connect(self.uri)
        # fix: re-open an existing table instead of always starting from None;
        # the original would recreate the table with mode="overwrite" on the
        # first insert, silently destroying previously stored passages.
        if self.table_name in self.db.table_names():
            self.table = self.db.open_table(self.table_name)
        else:
            self.table = None

    @staticmethod
    def _to_passage(row: dict) -> Passage:
        """Convert a LanceDB row dict into a Passage."""
        # fix: rows store the embedding under "vector", not "embedding"
        return Passage(text=row["text"], embedding=row["vector"], doc_id=row["doc_id"], passage_id=row["passage_id"])

    @staticmethod
    def _to_row(passage: Passage) -> dict:
        """Convert a Passage into the row dict shape stored in LanceDB."""
        return {"doc_id": passage.doc_id, "text": passage.text, "passage_id": passage.passage_id, "vector": passage.embedding}

    def _add_rows(self, data: List[dict]):
        """Append rows, creating the table on first use."""
        if self.table is not None:
            self.table.add(data)
        else:
            self.table = self.db.create_table(self.table_name, data=data, mode="overwrite")

    def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]:
        """Yield all stored passages in lists of at most ``page_size``.

        fix: the original referenced a nonexistent ``self.Session`` and
        looped forever, because the incremented offset was never applied
        to the query (every iteration fetched the same first page).
        """
        if self.table is None:
            return
        rows = self.table.search().to_list()
        for offset in range(0, len(rows), page_size):
            yield [self._to_passage(row) for row in rows[offset : offset + page_size]]

    def get_all(self, limit=10) -> List[Passage]:
        """Return up to ``limit`` passages from the table."""
        db_passages = self.table.search().limit(limit).to_list()
        return [self._to_passage(row) for row in db_passages]

    def get(self, id: str) -> Optional[Passage]:
        """Return the passage with the given ``passage_id``, or None.

        fix: the original called ``.where`` directly on the table, indexed
        the result list like a dict, and read a nonexistent "embedding" key.
        """
        # NOTE(review): assumes passage_id values are numeric; string ids
        # would need SQL quoting in the filter — confirm against callers
        rows = self.table.search().where(f"passage_id = {id}").to_list()
        if len(rows) == 0:
            return None
        return self._to_passage(rows[0])

    def size(self) -> int:
        """Return the number of rows in the table (0 if it doesn't exist)."""
        if self.table is not None:
            return len(self.table.search().to_list())
        print(f"Table with name {self.table_name} not present")
        return 0

    def insert(self, passage: Passage):
        """Insert a single passage."""
        self._add_rows([self._to_row(passage)])

    def insert_many(self, passages: List[Passage], show_progress=True):
        """Insert many passages, optionally showing a tqdm progress bar."""
        iterable = tqdm(passages) if show_progress else passages
        self._add_rows([self._to_row(passage) for passage in iterable])

    def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]:
        """Vector-similarity search; return the ``top_k`` closest passages.

        ``query`` (the raw text) is unused; only ``query_vec`` is searched.
        fix: the original never materialized the search builder (missing
        ``.to_list()``) and read a nonexistent "embedding" key per row.
        """
        results = self.table.search(query_vec).limit(top_k).to_list()
        return [self._to_passage(row) for row in results]

    def delete(self):
        """Drop this connector's table from the database."""
        self.db.drop_table(self.table_name)

    def save(self):
        """No-op: LanceDB persists writes as they happen."""
        return

    @staticmethod
    def list_loaded_data():
        """List the logical names of all memgpt tables in the database."""
        config = MemGPTConfig.load()
        import lancedb

        db = lancedb.connect(config.archival_storage_uri)
        tables = db.table_names()
        return [table.replace("memgpt_", "") for table in tables if table.startswith("memgpt_")]

    def sanitize_table_name(self, name: str) -> str:
        """Normalize *name* into a safe, lowercase table identifier."""
        # Remove leading and trailing whitespace
        name = name.strip()
        # Replace spaces and invalid characters with underscores
        name = re.sub(r"\s+|\W+", "_", name)
        # Truncate to the maximum identifier length
        max_length = 63
        if len(name) > max_length:
            name = name[:max_length].rstrip("_")
        # Convert to lowercase
        return name.lower()

    def generate_table_name(self, name: str):
        """Return the namespaced table name: ``"memgpt_"`` + sanitized *name*."""
        return f"memgpt_{self.sanitize_table_name(name)}"
10 changes: 10 additions & 0 deletions memgpt/connectors/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ def get_storage_connector(name: Optional[str] = None, agent_config: Optional[Age
from memgpt.connectors.db import PostgresStorageConnector

return PostgresStorageConnector(name=name, agent_config=agent_config)

elif storage_type == "lancedb":
from memgpt.connectors.db import LanceDBConnector

return LanceDBConnector(name=name)

else:
raise NotImplementedError(f"Storage type {storage_type} not implemented")
Expand All @@ -62,6 +67,11 @@ def list_loaded_data():
from memgpt.connectors.db import PostgresStorageConnector

return PostgresStorageConnector.list_loaded_data()

elif storage_type == "lancedb":
from memgpt.connectors.db import LanceDBConnector

return LanceDBConnector.list_loaded_data()
else:
raise NotImplementedError(f"Storage type {storage_type} not implemented")

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pg8000 = {version = "^1.30.3", optional = true}
torch = {version = ">=2.0.0, !=2.0.1, !=2.1.0", optional = true}
websockets = "^12.0"
docstring-parser = "^0.15"
lancedb = "^0.3.3"
sarahwooders marked this conversation as resolved.
Show resolved Hide resolved

[tool.poetry.extras]
legacy = ["faiss-cpu", "numpy"]
Expand Down
33 changes: 33 additions & 0 deletions tests/test_load_archival.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,39 @@ def test_postgres():
recursive=True,
)

def test_lancedb():
    """Smoke test: load a Hugging Face dataset into LanceDB archival storage.

    Currently disabled via the early ``return``: it pip-installs lancedb at
    runtime and downloads a dataset over the network, which is too heavy for
    routine CI runs. Everything below the return is unreachable until the
    guard is removed.
    """
    return

    subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"])
    import lancedb  # Try to import again after installing

    # override config path with environment variable
    # TODO: make into temporary file
    os.environ["MEMGPT_CONFIG_PATH"] = "test_config.cfg"
    print("env", os.getenv("MEMGPT_CONFIG_PATH"))
    config = memgpt.config.MemGPTConfig(archival_storage_type="lancedb", config_path=os.getenv("MEMGPT_CONFIG_PATH"))
    print(config)
    config.save()

    # loading dataset from hugging face
    name = "tmp_hf_dataset"

    dataset = load_dataset("MemGPT/example_short_stories")

    cache_dir = os.getenv("HF_DATASETS_CACHE")
    if cache_dir is None:
        # Construct the default path if the environment variable is not set.
        cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")

    config = memgpt.config.MemGPTConfig(archival_storage_type="lancedb")

    load_directory(
        name=name,
        input_dir=cache_dir,
        recursive=True,
    )



def test_chroma():
return
Expand Down
59 changes: 57 additions & 2 deletions tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pgvector # Try to import again after installing

from memgpt.connectors.storage import StorageConnector, Passage
from memgpt.connectors.db import PostgresStorageConnector
from memgpt.connectors.db import PostgresStorageConnector, LanceDBConnector
from memgpt.embeddings import embedding_model
from memgpt.config import MemGPTConfig, AgentConfig

Expand Down Expand Up @@ -56,6 +56,34 @@ def test_postgres_openai():
# db.delete()
# print("...finished")

def test_lancedb_openai():
    """End-to-end LanceDB smoke test using OpenAI embeddings.

    Requires the LANCEDB_TEST_URL environment variable; soft-passes when
    OPENAI_API_KEY is unset.
    """
    assert os.getenv("LANCEDB_TEST_URL") is not None
    if os.getenv("OPENAI_API_KEY") is None:
        return  # soft pass

    config = MemGPTConfig(archival_storage_type="lancedb", archival_storage_uri=os.getenv("LANCEDB_TEST_URL"))
    print(config.config_path)
    assert config.archival_storage_uri is not None
    print(config)

    embed_model = embedding_model()

    # fix: renamed to `passages` — the original wrote `for passage in passage`,
    # shadowing the list with its own loop variable
    passages = ["This is a test passage", "This is another test passage", "Cinderella wept"]

    db = LanceDBConnector(name="test-openai")

    for passage in passages:
        db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))

    print(db.get_all())

    query = "why was she crying"
    query_vec = embed_model.get_text_embedding(query)
    res = db.query(None, query_vec, top_k=2)

    assert len(res) == 2, f"Expected 2 results, got {len(res)}"
    assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"


@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI")
def test_postgres_local():
Expand Down Expand Up @@ -100,5 +128,32 @@ def test_postgres_local():
# db.delete()
# print("...finished")

def test_lancedb_local():
    """End-to-end LanceDB smoke test using a local HF embedding model.

    Requires the LANCEDB_TEST_URL environment variable to point at a
    LanceDB location (e.g. a local directory such as ./.lancedb).
    """
    assert os.getenv("LANCEDB_TEST_URL") is not None

    config = MemGPTConfig(
        archival_storage_type="lancedb",
        archival_storage_uri=os.getenv("LANCEDB_TEST_URL"),
        embedding_model="local",
        embedding_dim=384,  # use HF model
    )
    print(config.config_path)
    assert config.archival_storage_uri is not None

    embed_model = embedding_model()

    # fix: renamed to `passages` — the original wrote `for passage in passage`,
    # shadowing the list with its own loop variable
    passages = ["This is a test passage", "This is another test passage", "Cinderella wept"]

    db = LanceDBConnector(name="test-local")

    for passage in passages:
        db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))

    print(db.get_all())

    query = "why was she crying"
    query_vec = embed_model.get_text_embedding(query)
    res = db.query(None, query_vec, top_k=2)

    assert len(res) == 2, f"Expected 2 results, got {len(res)}"
    assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"