Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use sessions search #197

Merged
merged 3 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions graphiti_core/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async def delete(self, driver: AsyncDriver):
DELETE e
""",
uuid=self.uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

logger.debug(f'Deleted Edge: {self.uuid}')
Expand Down Expand Up @@ -82,7 +82,7 @@ async def save(self, driver: AsyncDriver):
uuid=self.uuid,
group_id=self.group_id,
created_at=self.created_at,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

logger.debug(f'Saved edge to neo4j: {self.uuid}')
Expand All @@ -102,7 +102,7 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
e.created_at AS created_at
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

edges = [get_episodic_edge_from_record(record) for record in records]
Expand All @@ -125,7 +125,7 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
e.created_at AS created_at
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

edges = [get_episodic_edge_from_record(record) for record in records]
Expand All @@ -148,7 +148,7 @@ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
e.created_at AS created_at
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

edges = [get_episodic_edge_from_record(record) for record in records]
Expand Down Expand Up @@ -202,7 +202,7 @@ async def save(self, driver: AsyncDriver):
expired_at=self.expired_at,
valid_at=self.valid_at,
invalid_at=self.invalid_at,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

logger.debug(f'Saved edge to neo4j: {self.uuid}')
Expand All @@ -229,7 +229,7 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
e.invalid_at AS invalid_at
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

edges = [get_entity_edge_from_record(record) for record in records]
Expand Down Expand Up @@ -259,7 +259,7 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
e.invalid_at AS invalid_at
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

edges = [get_entity_edge_from_record(record) for record in records]
Expand Down Expand Up @@ -289,7 +289,7 @@ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
e.invalid_at AS invalid_at
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

edges = [get_entity_edge_from_record(record) for record in records]
Expand All @@ -308,7 +308,7 @@ async def save(self, driver: AsyncDriver):
uuid=self.uuid,
group_id=self.group_id,
created_at=self.created_at,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

logger.debug(f'Saved edge to neo4j: {self.uuid}')
Expand All @@ -328,7 +328,7 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
e.created_at AS created_at
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

edges = [get_community_edge_from_record(record) for record in records]
Expand All @@ -349,7 +349,7 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
e.created_at AS created_at
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

edges = [get_community_edge_from_record(record) for record in records]
Expand All @@ -370,7 +370,7 @@ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
e.created_at AS created_at
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

edges = [get_community_edge_from_record(record) for record in records]
Expand Down
26 changes: 13 additions & 13 deletions graphiti_core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ async def delete(self, driver: AsyncDriver):
DETACH DELETE n
""",
uuid=self.uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

logger.debug(f'Deleted Node: {self.uuid}')
Expand Down Expand Up @@ -136,7 +136,7 @@ async def save(self, driver: AsyncDriver):
created_at=self.created_at,
valid_at=self.valid_at,
source=self.source.value,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

logger.debug(f'Saved Node to neo4j: {self.uuid}')
Expand All @@ -158,7 +158,7 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
e.source AS source
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

episodes = [get_episodic_node_from_record(record) for record in records]
Expand All @@ -184,7 +184,7 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
e.source AS source
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

episodes = [get_episodic_node_from_record(record) for record in records]
Expand All @@ -207,7 +207,7 @@ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
e.source AS source
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

episodes = [get_episodic_node_from_record(record) for record in records]
Expand Down Expand Up @@ -237,7 +237,7 @@ async def save(self, driver: AsyncDriver):
summary=self.summary,
name_embedding=self.name_embedding,
created_at=self.created_at,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

logger.debug(f'Saved Node to neo4j: {self.uuid}')
Expand All @@ -258,7 +258,7 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
n.summary AS summary
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

nodes = [get_entity_node_from_record(record) for record in records]
Expand All @@ -282,7 +282,7 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
n.summary AS summary
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

nodes = [get_entity_node_from_record(record) for record in records]
Expand All @@ -303,7 +303,7 @@ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
n.summary AS summary
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

nodes = [get_entity_node_from_record(record) for record in records]
Expand All @@ -324,7 +324,7 @@ async def save(self, driver: AsyncDriver):
summary=self.summary,
name_embedding=self.name_embedding,
created_at=self.created_at,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

logger.debug(f'Saved Node to neo4j: {self.uuid}')
Expand Down Expand Up @@ -354,7 +354,7 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
n.summary AS summary
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

nodes = [get_community_node_from_record(record) for record in records]
Expand All @@ -378,7 +378,7 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
n.summary AS summary
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

communities = [get_community_node_from_record(record) for record in records]
Expand All @@ -399,7 +399,7 @@ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
n.summary AS summary
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
database_=DEFAULT_DATABASE,
)

communities = [get_community_node_from_record(record) for record in records]
Expand Down
Loading