Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bulk add nodes and edges #205

Merged
merged 7 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/podcast/podcast_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ async def main(use_bulk: bool = True):
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
group_id='test',
)

# build communities
Expand Down
8 changes: 4 additions & 4 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
)
from graphiti_core.utils.bulk_utils import (
RawEpisode,
add_nodes_and_edges_bulk,
dedupe_edges_bulk,
dedupe_nodes_bulk,
extract_edge_dates_bulk,
Expand Down Expand Up @@ -451,10 +452,9 @@ async def add_episode_endpoint(episode_data: EpisodeData):
if not self.store_raw_episode_content:
episode.content = ''

await episode.save(self.driver)
await asyncio.gather(*[node.save(self.driver) for node in nodes])
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
await add_nodes_and_edges_bulk(
self.driver, [episode], episodic_edges, nodes, entity_edges
)

# Update any communities
if update_communities:
Expand Down
1 change: 1 addition & 0 deletions graphiti_core/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
load_dotenv()

# Name of the database to target; None when the variable is unset.
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
# Opt-in switch that makes the similarity searches prepend the
# 'CYPHER runtime = parallel' directive to their queries.
# Environment variables are always strings, so a plain bool() would treat
# 'false' or '0' as enabled. Compare against the usual "off" spellings instead;
# any other non-empty value still enables the flag, as before.
USE_PARALLEL_RUNTIME = os.getenv('USE_PARALLEL_RUNTIME', '').strip().lower() not in (
    '',
    '0',
    'false',
    'no',
    'off',
)


def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
Expand Down
20 changes: 20 additions & 0 deletions graphiti_core/models/edges/edge_db_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN r.uuid AS uuid"""

# Bulk save for episodic edges. Expects a $episodic_edges list parameter whose
# entries carry uuid, group_id, created_at, source_node_uuid and target_node_uuid.
# For each entry it MERGEs a MENTIONS relationship (keyed by uuid) from the
# matching Episodic node to the matching Entity node and then replaces the
# relationship's properties. Both endpoints are looked up with MATCH, so an entry
# whose nodes do not exist yet produces no relationship.
EPISODIC_EDGE_SAVE_BULK = """
UNWIND $episodic_edges AS edge
MATCH (episode:Episodic {uuid: edge.source_node_uuid})
MATCH (node:Entity {uuid: edge.target_node_uuid})
MERGE (episode)-[r:MENTIONS {uuid: edge.uuid}]->(node)
SET r = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
RETURN r.uuid AS uuid
"""

ENTITY_EDGE_SAVE = """
MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid})
Expand All @@ -14,6 +23,17 @@
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
RETURN r.uuid AS uuid"""

# Bulk save for entity edges. Expects a $entity_edges list parameter; each entry
# supplies the relationship properties plus source_node_uuid / target_node_uuid
# and a fact_embedding. For each entry it MERGEs a RELATES_TO relationship (keyed
# by uuid) between the two matched Entity nodes, replaces its properties, and then
# stores fact_embedding through db.create.setRelationshipVectorProperty. Both
# endpoints are looked up with MATCH, so they must already exist.
ENTITY_EDGE_SAVE_BULK = """
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
RETURN r.uuid AS uuid
"""

COMMUNITY_EDGE_SAVE = """
MATCH (community:Community {uuid: $community_uuid})
MATCH (node:Entity | Community {uuid: $entity_uuid})
Expand Down
17 changes: 17 additions & 0 deletions graphiti_core/models/nodes/node_db_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,29 @@
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid"""

# Bulk save for episodic nodes. Expects a $episodes list parameter; for each
# entry it MERGEs an Episodic node keyed by uuid and replaces all of the node's
# properties with the values listed in the SET clause.
EPISODIC_NODE_SAVE_BULK = """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
source: episode.source, content: episode.content,
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""

ENTITY_NODE_SAVE = """
MERGE (n:Entity {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid"""

# Bulk save for entity nodes. Expects a $nodes list parameter; for each entry it
# MERGEs an Entity node keyed by uuid, replaces the node's properties, and then
# stores name_embedding through db.create.setNodeVectorProperty.
ENTITY_NODE_SAVE_BULK = """
UNWIND $nodes AS node
MERGE (n:Entity {uuid: node.uuid})
SET n = {uuid: node.uuid, name: node.name, group_id: node.group_id, summary: node.summary, created_at: node.created_at}
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
RETURN n.uuid AS uuid
"""

COMMUNITY_NODE_SAVE = """
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
Expand Down
65 changes: 41 additions & 24 deletions graphiti_core/search/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,15 @@

import numpy as np
from neo4j import AsyncDriver, Query
from typing_extensions import LiteralString

from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.helpers import DEFAULT_DATABASE, lucene_sanitize, normalize_l2
from graphiti_core.helpers import (
DEFAULT_DATABASE,
USE_PARALLEL_RUNTIME,
lucene_sanitize,
normalize_l2,
)
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
Expand All @@ -38,7 +44,7 @@
DEFAULT_MIN_SCORE = 0.6
DEFAULT_MMR_LAMBDA = 0.5
MAX_SEARCH_DEPTH = 3
MAX_QUERY_LENGTH = 128
MAX_QUERY_LENGTH = 32


def fulltext_query(query: str, group_ids: list[str] | None = None):
Expand Down Expand Up @@ -187,8 +193,11 @@ async def edge_similarity_search(
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityEdge]:
# vector similarity search over embedded facts
query = Query("""
CYPHER runtime = parallel parallelRuntimeSupport=all
runtime_query: LiteralString = (
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)

query: LiteralString = """
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
Expand All @@ -210,10 +219,10 @@ async def edge_similarity_search(
r.invalid_at AS invalid_at
ORDER BY score DESC
LIMIT $limit
""")
"""

records, _, _ = await driver.execute_query(
query,
runtime_query + query,
search_vector=search_vector,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
Expand Down Expand Up @@ -318,9 +327,13 @@ async def node_similarity_search(
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityNode]:
# vector similarity search over entity names
runtime_query: LiteralString = (
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)

records, _, _ = await driver.execute_query(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all
runtime_query
+ """
MATCH (n:Entity)
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
Expand Down Expand Up @@ -425,23 +438,27 @@ async def community_similarity_search(
min_score=DEFAULT_MIN_SCORE,
) -> list[CommunityNode]:
# vector similarity search over entity names
runtime_query: LiteralString = (
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)

records, _, _ = await driver.execute_query(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (comm:Community)
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
comm.uuid As uuid,
comm.group_id AS group_id,
comm.name AS name,
comm.name_embedding AS name_embedding,
comm.created_at AS created_at,
comm.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
runtime_query
+ """
MATCH (comm:Community)
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
comm.uuid As uuid,
comm.group_id AS group_id,
comm.name AS name,
comm.name_embedding AS name_embedding,
comm.created_at AS created_at,
comm.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
search_vector=search_vector,
group_ids=group_ids,
limit=limit,
Expand Down
39 changes: 38 additions & 1 deletion graphiti_core/utils/bulk_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,20 @@
from datetime import datetime
from math import ceil

from neo4j import AsyncDriver
from neo4j import AsyncDriver, AsyncManagedTransaction, Query

Check failure on line 24 in graphiti_core/utils/bulk_utils.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

graphiti_core/utils/bulk_utils.py:24:57: F401 `neo4j.Query` imported but unused
from numpy import dot, sqrt
from pydantic import BaseModel

from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
from graphiti_core.llm_client import LLMClient
from graphiti_core.models.edges.edge_db_queries import (
ENTITY_EDGE_SAVE_BULK,
EPISODIC_EDGE_SAVE_BULK,
)
from graphiti_core.models.nodes.node_db_queries import (
ENTITY_NODE_SAVE_BULK,
EPISODIC_NODE_SAVE_BULK,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
from graphiti_core.utils import retrieve_episodes
Expand Down Expand Up @@ -75,6 +83,35 @@
return episode_tuples


async def add_nodes_and_edges_bulk(
    driver: AsyncDriver,
    episodic_nodes: list[EpisodicNode],
    episodic_edges: list[EpisodicEdge],
    entity_nodes: list[EntityNode],
    entity_edges: list[EntityEdge],
):
    """Save episodes, entities and their edges through one managed write transaction.

    Opens a session on ``driver`` and hands the four collections to
    ``add_nodes_and_edges_bulk_tx`` via ``session.execute_write``.

    Args:
        driver: Async Neo4j driver used to open the session.
        episodic_nodes: Episodic nodes to save.
        episodic_edges: Episodic edges to save.
        entity_nodes: Entity nodes to save.
        entity_edges: Entity edges to save.
    """
    # Positional order must match the parameters of add_nodes_and_edges_bulk_tx.
    tx_args = (episodic_nodes, episodic_edges, entity_nodes, entity_edges)
    async with driver.session() as session:
        await session.execute_write(add_nodes_and_edges_bulk_tx, *tx_args)


async def add_nodes_and_edges_bulk_tx(
    tx: AsyncManagedTransaction,
    episodic_nodes: list[EpisodicNode],
    episodic_edges: list[EpisodicEdge],
    entity_nodes: list[EntityNode],
    entity_edges: list[EntityEdge],
):
    """Run the four bulk-save queries inside the given transaction.

    Each model is converted to a plain dict and passed as a list parameter to
    the matching ``*_SAVE_BULK`` query.

    Args:
        tx: Managed transaction the queries are run on.
        episodic_nodes: Episodic nodes to save.
        episodic_edges: Episodic edges to save.
        entity_nodes: Entity nodes to save.
        entity_edges: Entity edges to save.
    """
    episode_rows = []
    for episodic_node in episodic_nodes:
        row = dict(episodic_node)
        # 'source' holds an enum-like object; store its underlying value as a string.
        row['source'] = str(row['source'].value)
        episode_rows.append(row)

    # Nodes go first: the edge queries MATCH their endpoints, so the nodes must exist.
    await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episode_rows)

    entity_rows = [dict(entity_node) for entity_node in entity_nodes]
    await tx.run(ENTITY_NODE_SAVE_BULK, nodes=entity_rows)

    episodic_edge_rows = [dict(episodic_edge) for episodic_edge in episodic_edges]
    await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=episodic_edge_rows)

    entity_edge_rows = [dict(entity_edge) for entity_edge in entity_edges]
    await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=entity_edge_rows)


async def extract_nodes_and_edges_bulk(
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
Expand Down
Loading