diff --git a/agentmemory/client.py b/agentmemory/client.py index 9276185..ed6633b 100644 --- a/agentmemory/client.py +++ b/agentmemory/client.py @@ -11,7 +11,7 @@ CLIENT_TYPE = os.environ.get("CLIENT_TYPE", DEFAULT_CLIENT_TYPE) STORAGE_PATH = os.environ.get("STORAGE_PATH", "./memory") POSTGRES_CONNECTION_STRING = os.environ.get("POSTGRES_CONNECTION_STRING") - +POSTGRES_MODEL_NAME = os.environ.get("POSTGRES_MODEL_NAME", "all-MiniLM-L6-v2") client = None @@ -28,7 +28,7 @@ def get_client(client_type=None, *args, **kwargs): raise EnvironmentError( "Postgres connection string not set in environment variables!" ) - client = PostgresClient(POSTGRES_CONNECTION_STRING) + client = PostgresClient(POSTGRES_CONNECTION_STRING, model_name=POSTGRES_MODEL_NAME) else: client = chromadb.PersistentClient(path=STORAGE_PATH, *args, **kwargs) diff --git a/agentmemory/clustering.py b/agentmemory/clustering.py index 77cdd00..afe86ba 100644 --- a/agentmemory/clustering.py +++ b/agentmemory/clustering.py @@ -11,7 +11,6 @@ def cluster(epsilon, min_samples, category, filter_metadata=None, novel=False): cluster_id = 0 for memory in memories: memory_id = memory["id"] - print("Memory ID: ", memory_id) if visited[memory_id]: continue visited[memory_id] = True diff --git a/agentmemory/main.py b/agentmemory/main.py index 49a325c..e65b45a 100644 --- a/agentmemory/main.py +++ b/agentmemory/main.py @@ -1,4 +1,7 @@ import datetime +import os + +os.environ["TOKENIZERS_PARALLELISM"] = "false" from agentmemory.helpers import ( chroma_collection_to_list, @@ -10,7 +13,6 @@ from agentmemory.client import get_client - def create_memory(category, text, metadata={}, embedding=None, id=None): """ Create a new memory in a collection. @@ -344,8 +346,6 @@ def update_memory(category, id, text=None, metadata=None, embedding=None): documents = [text] if text is not None else None metadatas = [metadata] if metadata is not None else None embeddings = [embedding] if embedding is not None else None - print('********************** UPDATE') - print(id, documents, metadatas, embeddings) # Update the memory with the new text and/or metadata memories.update( ids=[str(id)], documents=documents, metadatas=metadatas, embeddings=embeddings diff --git a/agentmemory/postgres.py b/agentmemory/postgres.py index 184c368..f302f68 100644 --- a/agentmemory/postgres.py +++ b/agentmemory/postgres.py @@ -1,10 +1,44 @@ -import json from pathlib import Path import psycopg2 from agentmemory.check_model import check_model, infer_embeddings +def handle_and_condition(and_conditions): + conditions = [] + params = [] + for condition in and_conditions: + for key, value in condition.items(): + for operator, operand in value.items(): + sql_operator = get_sql_operator(operator) + conditions.append(f"{key} {sql_operator} %s") + params.append(operand) + return conditions, params + + +def handle_or_condition(or_conditions): + or_groups = [] + params = [] + for condition in or_conditions: + conditions, new_params = handle_and_condition([condition]) + or_groups.append(" AND ".join(conditions)) + params.extend(new_params) + return f"({') OR ('.join(or_groups)})", params + + +def get_sql_operator(operator): + if operator == "$eq": + return "=" + elif operator == "$ne": + return "!=" + elif operator == "$gt": + return ">" + elif operator == "$lt": + return "<" + else: + raise ValueError(f"Operator {operator} not supported") + + class PostgresCollection: def __init__(self, category, client): self.category = category @@ -40,37 +74,61 @@ def get( ): category = self.category table_name = self.client._table_name(category) + conditions = [] + params = [] - if not ids: - if limit is None: - limit = 100 # or another default value - if offset is None: - offset = 0 + if where_document is not None: + if where_document.get("$contains", None) is not None: + where_document = where_document["$contains"] + conditions.append("document LIKE %s") + params.append(f"%{where_document}%") - query = f"SELECT * FROM {table_name} LIMIT %s OFFSET %s" - params = (limit, offset) + if where: + for key, value in where.items(): + if key == "$and": + new_conditions, new_params = handle_and_condition(value) + conditions.extend(new_conditions) + params.extend(new_params) + elif key == "$or": + or_condition, new_params = handle_or_condition(value) + conditions.append(or_condition) + params.extend(new_params) + elif key == "$contains": + conditions.append(f"document LIKE %s") + params.append(f"%{value}%") + else: + conditions.append(f"{key}=%s") + params.append(str(value)) - else: + if ids: if not all(isinstance(i, str) or isinstance(i, int) for i in ids): raise Exception( "ids must be a list of integers or strings representing integers" ) + ids = [int(i) for i in ids] + conditions.append("id=ANY(%s)") + params.append(ids) - if limit is None: - limit = len(ids) - if offset is None: - offset = 0 + if limit is None: + limit = 100 # or another default value + if offset is None: + offset = 0 - ids = [int(i) for i in ids] - query = f"SELECT * FROM {table_name} WHERE id=ANY(%s) LIMIT %s OFFSET %s" - params = (ids, limit, offset) + query = f"SELECT * FROM {table_name}" + if conditions: + query += " WHERE " + " AND ".join(conditions) + query += " LIMIT %s OFFSET %s" + params.extend([limit, offset]) + + self.client.cur.execute(query, tuple(params)) - self.client.cur.execute(query, params) rows = self.client.cur.fetchall() # Convert rows to list of dictionaries columns = [desc[0] for desc in self.client.cur.description] - metadata_columns = [col for col in columns if col not in ["id", "document", "embedding"]] + metadata_columns = [ + col for col in columns if col not in ["id", "document", "embedding"] + ] result = [] for row in rows: @@ -85,7 +143,6 @@ def get( "metadatas": [row["metadata"] for row in result], } - def peek(self, limit=10): return self.get(limit=limit) @@ -98,7 +155,9 @@ def query( where_document=None, include=["metadatas", "documents", "distances"], ): - return self.client.query(self.category, query_texts, n_results) + return self.client.query( + self.category, query_texts, n_results, where, where_document + ) def update(self, ids, documents=None, metadatas=None, embeddings=None): self.client.ensure_table_exists(self.category) @@ -107,8 +166,6 @@ def update(self, ids, documents=None, metadatas=None, embeddings=None): if documents is None: documents = [None] * len(ids) for id_, document, metadata in zip(ids, documents, metadatas): - print("updating") - print(id_, document, metadata) self.client.update(self.category, id_, document, metadata) else: for id_, document, metadata, emb in zip( @@ -121,37 +178,43 @@ def upsert(self, ids, documents=None, metadatas=None, embeddings=None): def delete(self, ids=None, where=None, where_document=None): table_name = self.client._table_name(self.category) - # check if table exists - self.client.ensure_table_exists(self.category) - - # Base of the query - query = f"DELETE FROM {table_name}" + conditions = [] params = [] - conditions = [] + if where_document is not None: + if where_document.get("$contains", None) is not None: + where_document = where_document["$contains"] + conditions.append("document LIKE %s") + params.append(f"%{where_document}%") - if ids is not None: - if not all(isinstance(i, (int, str)) and str(i).isdigit() for i in ids): + if ids: + if not all(isinstance(i, str) or isinstance(i, int) for i in ids): raise Exception( "ids must be a list of integers or strings representing integers" ) ids = [int(i) for i in ids] - conditions.append("id=ANY(%s::int[])") + conditions.append("id=ANY(%s::int[])") # Added explicit type casting params.append(ids) - if where_document is not None: - if "$contains" in where_document: - conditions.append("document LIKE %s") - params.append(f"%{where_document['$contains']}%") - # You can add more operators for 'where_document' here if needed - - if where is not None: + if where: for key, value in where.items(): - conditions.append(f"{key}=%s") - params.append(value) + if key == "$and": + new_conditions, new_params = handle_and_condition(value) + conditions.extend(new_conditions) + params.extend(new_params) + elif key == "$or": + or_condition, new_params = handle_or_condition(value) + conditions.append(or_condition) + params.extend(new_params) + elif key == "$contains": + conditions.append(f"document LIKE %s") + params.append(f"%{value}%") + else: + conditions.append(f"{key}=%s") + params.append(str(value)) if conditions: - query += " WHERE " + " AND ".join(conditions) + query = f"DELETE FROM {table_name} WHERE " + " AND ".join(conditions) else: raise Exception("No valid conditions provided for deletion.") @@ -286,9 +349,40 @@ def add(self, category, documents, metadatas, ids): cur.execute(query, tuple(values)) self.connection.commit() - def query(self, category, query_texts, n_results=5): + def query( + self, category, query_texts, n_results=5, where=None, where_document=None + ): self.ensure_table_exists(category) table_name = self._table_name(category) + conditions = [] + params = [] + + # Check if where_document is given + if where_document: + if where_document.get("$contains", None) is not None: + where_document = where_document["$contains"] + conditions.append("document LIKE %s") + params.append(f"%{where_document}%") + + if where: + for key, value in where.items(): + if key == "$and": + new_conditions, new_params = handle_and_condition(value) + conditions.extend(new_conditions) + params.extend(new_params) + elif key == "$or": + or_condition, new_params = handle_or_condition(value) + conditions.append(or_condition) + params.extend(new_params) + elif key == "$contains": + conditions.append(f"document LIKE %s") + params.append(f"%{value}%") + else: + conditions.append(f"{key}=%s") + params.append(str(value)) + + where_clause = " WHERE " + " AND ".join(conditions) if conditions else "" + results = { "ids": [], "documents": [], @@ -299,14 +393,17 @@ def query(self, category, query_texts, n_results=5): with self.connection.cursor() as cur: for emb in query_texts: query_emb = self.create_embedding(emb) - cur.execute( - f""" + params_with_emb = [query_emb] + params + [query_emb, n_results] + string = f""" SELECT id, document, embedding, embedding <-> %s AS distance, * FROM {table_name} + {where_clause} ORDER BY embedding <-> %s LIMIT %s - """, - (query_emb, query_emb, n_results), + """ + cur.execute( + string, + tuple(params_with_emb), ) rows = cur.fetchall() columns = [desc[0] for desc in cur.description] @@ -363,4 +460,4 @@ def update(self, category, id_, document=None, metadata=None, embedding=None): def close(self): self.cur.close() - self.connection.close() + self.connection.close() diff --git a/agentmemory/tests/events.py b/agentmemory/tests/events.py index 511bbe5..c15e1ff 100644 --- a/agentmemory/tests/events.py +++ b/agentmemory/tests/events.py @@ -34,7 +34,6 @@ def test_create_event(): event = get_events()[0] assert event["document"] == "test event" assert event["metadata"]["test"] == "test" - print(event["metadata"]) assert int(event["metadata"]["epoch"]) == 1 wipe_category("events") wipe_category("epoch") diff --git a/setup.py b/setup.py index f321699..c923d74 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ setup( name='agentmemory', - version='0.4.4', + version='0.4.5', description='Easy-to-use memory for agents, document search, knowledge graphing and more.', long_description=long_description, # added this line long_description_content_type="text/markdown", # and this line