Skip to content

Commit

Permalink
fix: insert_many checks exists_ok (#1861)
Browse files Browse the repository at this point in the history
  • Loading branch information
Vivek Verma authored Oct 9, 2024
1 parent c6368f3 commit 6506a52
Showing 1 changed file with 23 additions and 7 deletions.
30 changes: 23 additions & 7 deletions letta/agent_store/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,6 @@ def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Op
return records

def insert_many(self, records, exists_ok=True, show_progress=False):
pass

# TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
if len(records) == 0:
return
Expand Down Expand Up @@ -506,18 +504,36 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None)
# sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))

def insert_many(self, records, exists_ok=True, show_progress=False):
pass

# TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
if len(records) == 0:
return

added_ids = [] # avoid adding duplicates
# NOTE: this has not great performance due to the excessive commits
with self.session_maker() as session:
iterable = tqdm(records) if show_progress else records
for record in iterable:
# db_record = self.db_model(**vars(record))
db_record = self.db_model(**record.dict())
session.add(db_record)
session.commit()

if record.id in added_ids:
continue

existing_record = session.query(self.db_model).filter_by(id=record.id).first()
if existing_record:
if exists_ok:
fields = record.model_dump()
fields.pop("id")
session.query(self.db_model).filter(self.db_model.id == record.id).update(fields)
session.commit()
else:
raise ValueError(f"Record with id {record.id} already exists.")

else:
db_record = self.db_model(**record.dict())
session.add(db_record)
session.commit()

added_ids.append(record.id)

def insert(self, record, exists_ok=True):
self.insert_many([record], exists_ok=exists_ok)
Expand Down

0 comments on commit 6506a52

Please sign in to comment.