Bulk add nodes and edges (#205)
* test

* only use parallel runtime if set to true

* add and test bulk add

* remove group_ids

* format

* bump version

* update readme
prasmussen15 authored Oct 31, 2024
1 parent 63a1b11 commit b8f5267
Showing 8 changed files with 135 additions and 32 deletions.
15 changes: 13 additions & 2 deletions README.md
@@ -173,6 +173,17 @@ The `server` directory contains an API service for interacting with the Graphiti

Please see the [server README](./server/README.md) for more information.

## Optional Environment Variables

In addition to the Neo4j and OpenAI-compatible credentials, Graphiti also has a few optional environment variables.
If you are using one of our supported models, such as Anthropic or Voyage models, the necessary environment variables
must be set.

`USE_PARALLEL_RUNTIME` is an optional boolean variable that can be set to true if you wish
to enable Neo4j's parallel runtime feature for several of our search queries.
Note that this feature is not supported for Neo4j Community edition or for smaller AuraDB instances;
as such, it is off by default.
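
For example, a minimal sketch of enabling the flag from Python (the connection details below are placeholders; the flag must be set before `graphiti_core` is imported, since it is read at import time):

```python
import os

# Set the flag before importing graphiti_core, because the helper module reads
# the environment variable at import time.
os.environ['USE_PARALLEL_RUNTIME'] = 'true'

from graphiti_core import Graphiti  # noqa: E402

# Placeholder connection details for a local Neo4j Enterprise instance.
graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
```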

## Documentation

- [Guides and API documentation](https://help.getzep.com/graphiti).
@@ -186,11 +197,11 @@ Graphiti is under active development. We aim to maintain API stability while wor
- [x] Implementing node and edge CRUD operations
- [ ] Improving performance and scalability
- [ ] Achieving good performance with different LLM and embedding models
- [ ] Creating a dedicated embedder interface
- [x] Creating a dedicated embedder interface
- [ ] Supporting custom graph schemas:
- Allow developers to provide their own defined node and edge classes when ingesting episodes
- Enable more flexible knowledge representation tailored to specific use cases
- [ ] Enhancing retrieval capabilities with more robust and configurable options
- [x] Enhancing retrieval capabilities with more robust and configurable options
- [ ] Expanding test coverage to ensure reliability and catch edge cases

## Contributing
8 changes: 4 additions & 4 deletions graphiti_core/graphiti.py
@@ -51,6 +51,7 @@
)
from graphiti_core.utils.bulk_utils import (
RawEpisode,
add_nodes_and_edges_bulk,
dedupe_edges_bulk,
dedupe_nodes_bulk,
extract_edge_dates_bulk,
@@ -451,10 +452,9 @@ async def add_episode_endpoint(episode_data: EpisodeData):
if not self.store_raw_episode_content:
episode.content = ''

await episode.save(self.driver)
await asyncio.gather(*[node.save(self.driver) for node in nodes])
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
await add_nodes_and_edges_bulk(
self.driver, [episode], episodic_edges, nodes, entity_edges
)

# Update any communities
if update_communities:
1 change: 1 addition & 0 deletions graphiti_core/helpers.py
@@ -24,6 +24,7 @@
load_dotenv()

DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
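
One subtlety: `bool()` of any non-empty string is `True`, so this line enables the flag for any set value of `USE_PARALLEL_RUNTIME`, including `"false"` or `"0"`; only an unset or empty variable leaves it off. If stricter parsing were ever wanted, a hypothetical helper (not part of this commit) could look like:

```python
import os


def env_flag(name: str, default: bool = False) -> bool:
    """Interpret common truthy spellings; everything else, including 'false', is False."""
    value = os.getenv(name)
    if value is None or value.strip() == '':
        return default
    return value.strip().lower() in ('1', 'true', 'yes', 'on')


USE_PARALLEL_RUNTIME = env_flag('USE_PARALLEL_RUNTIME')
```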


def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
20 changes: 20 additions & 0 deletions graphiti_core/models/edges/edge_db_queries.py
@@ -5,6 +5,15 @@
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN r.uuid AS uuid"""

EPISODIC_EDGE_SAVE_BULK = """
UNWIND $episodic_edges AS edge
MATCH (episode:Episodic {uuid: edge.source_node_uuid})
MATCH (node:Entity {uuid: edge.target_node_uuid})
MERGE (episode)-[r:MENTIONS {uuid: edge.uuid}]->(node)
SET r = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
RETURN r.uuid AS uuid
"""

ENTITY_EDGE_SAVE = """
MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid})
@@ -14,6 +23,17 @@
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
RETURN r.uuid AS uuid"""

ENTITY_EDGE_SAVE_BULK = """
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
RETURN r.uuid AS uuid
"""

COMMUNITY_EDGE_SAVE = """
MATCH (community:Community {uuid: $community_uuid})
MATCH (node:Entity | Community {uuid: $entity_uuid})
17 changes: 17 additions & 0 deletions graphiti_core/models/nodes/node_db_queries.py
@@ -4,12 +4,29 @@
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid"""

EPISODIC_NODE_SAVE_BULK = """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
source: episode.source, content: episode.content,
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""

ENTITY_NODE_SAVE = """
MERGE (n:Entity {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid"""

ENTITY_NODE_SAVE_BULK = """
UNWIND $nodes AS node
MERGE (n:Entity {uuid: node.uuid})
SET n = {uuid: node.uuid, name: node.name, group_id: node.group_id, summary: node.summary, created_at: node.created_at}
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
RETURN n.uuid AS uuid
"""

COMMUNITY_NODE_SAVE = """
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
65 changes: 41 additions & 24 deletions graphiti_core/search/search_utils.py
@@ -21,9 +21,15 @@

import numpy as np
from neo4j import AsyncDriver, Query
from typing_extensions import LiteralString

from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.helpers import DEFAULT_DATABASE, lucene_sanitize, normalize_l2
from graphiti_core.helpers import (
DEFAULT_DATABASE,
USE_PARALLEL_RUNTIME,
lucene_sanitize,
normalize_l2,
)
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
@@ -38,7 +44,7 @@
DEFAULT_MIN_SCORE = 0.6
DEFAULT_MMR_LAMBDA = 0.5
MAX_SEARCH_DEPTH = 3
MAX_QUERY_LENGTH = 128
MAX_QUERY_LENGTH = 32


def fulltext_query(query: str, group_ids: list[str] | None = None):
@@ -187,8 +193,11 @@ async def edge_similarity_search(
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityEdge]:
# vector similarity search over embedded facts
query = Query("""
CYPHER runtime = parallel parallelRuntimeSupport=all
runtime_query: LiteralString = (
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)

query: LiteralString = """
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
Expand All @@ -210,10 +219,10 @@ async def edge_similarity_search(
r.invalid_at AS invalid_at
ORDER BY score DESC
LIMIT $limit
""")
"""

records, _, _ = await driver.execute_query(
query,
runtime_query + query,
search_vector=search_vector,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
@@ -318,9 +327,13 @@ async def node_similarity_search(
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityNode]:
# vector similarity search over entity names
runtime_query: LiteralString = (
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)

records, _, _ = await driver.execute_query(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all
runtime_query
+ """
MATCH (n:Entity)
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
@@ -425,23 +438,27 @@ async def community_similarity_search(
min_score=DEFAULT_MIN_SCORE,
) -> list[CommunityNode]:
# vector similarity search over entity names
runtime_query: LiteralString = (
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)

records, _, _ = await driver.execute_query(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (comm:Community)
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
comm.uuid As uuid,
comm.group_id AS group_id,
comm.name AS name,
comm.name_embedding AS name_embedding,
comm.created_at AS created_at,
comm.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
runtime_query
+ """
MATCH (comm:Community)
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
comm.uuid As uuid,
comm.group_id AS group_id,
comm.name AS name,
comm.name_embedding AS name_embedding,
comm.created_at AS created_at,
comm.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
search_vector=search_vector,
group_ids=group_ids,
limit=limit,
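
A note on the pattern repeated in this file: the `CYPHER runtime = parallel` prefix is kept as a separate `LiteralString` and concatenated onto the query rather than interpolated with an f-string, which keeps the full query text a literal as far as type checkers are concerned (the driver's type hints favor literal query strings to discourage string interpolation). A stripped-down sketch of the same idea, with an illustrative flag and query rather than the real ones:

```python
from typing_extensions import LiteralString

USE_PARALLEL_RUNTIME = False  # illustrative; the real flag lives in graphiti_core.helpers

runtime_query: LiteralString = (
    'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)

query: LiteralString = """
MATCH (n:Entity)
RETURN n.uuid AS uuid
LIMIT $limit
"""

# Per PEP 675, concatenating two literal strings yields another LiteralString,
# so the combined text can still be passed where literal query text is expected.
full_query: LiteralString = runtime_query + query
```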
39 changes: 38 additions & 1 deletion graphiti_core/utils/bulk_utils.py
@@ -21,12 +21,20 @@
from datetime import datetime
from math import ceil

from neo4j import AsyncDriver
from neo4j import AsyncDriver, AsyncManagedTransaction
from numpy import dot, sqrt
from pydantic import BaseModel

from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
from graphiti_core.llm_client import LLMClient
from graphiti_core.models.edges.edge_db_queries import (
ENTITY_EDGE_SAVE_BULK,
EPISODIC_EDGE_SAVE_BULK,
)
from graphiti_core.models.nodes.node_db_queries import (
ENTITY_NODE_SAVE_BULK,
EPISODIC_NODE_SAVE_BULK,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
from graphiti_core.utils import retrieve_episodes
@@ -75,6 +83,35 @@ async def retrieve_previous_episodes_bulk(
return episode_tuples


async def add_nodes_and_edges_bulk(
driver: AsyncDriver,
episodic_nodes: list[EpisodicNode],
episodic_edges: list[EpisodicEdge],
entity_nodes: list[EntityNode],
entity_edges: list[EntityEdge],
):
async with driver.session() as session:
await session.execute_write(
add_nodes_and_edges_bulk_tx, episodic_nodes, episodic_edges, entity_nodes, entity_edges
)


async def add_nodes_and_edges_bulk_tx(
tx: AsyncManagedTransaction,
episodic_nodes: list[EpisodicNode],
episodic_edges: list[EpisodicEdge],
entity_nodes: list[EntityNode],
entity_edges: list[EntityEdge],
):
episodes = [dict(episode) for episode in episodic_nodes]
for episode in episodes:
episode['source'] = str(episode['source'].value)
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=[dict(entity) for entity in entity_nodes])
await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=[dict(edge) for edge in episodic_edges])
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[dict(edge) for edge in entity_edges])
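
Since all four bulk statements run inside a single `execute_write` transaction, they commit or roll back together; the call in `graphiti.py` above relies on this to replace the previous per-object saves. A minimal sketch of calling it from elsewhere, assuming the nodes and edges have already been extracted (the wrapper function name here is illustrative):

```python
from neo4j import AsyncDriver

from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.utils.bulk_utils import add_nodes_and_edges_bulk


async def persist_graph(
    driver: AsyncDriver,
    episodes: list[EpisodicNode],
    episodic_edges: list[EpisodicEdge],
    entity_nodes: list[EntityNode],
    entity_edges: list[EntityEdge],
) -> None:
    # All nodes and edges are written in one transaction, so a failure midway
    # leaves the graph unchanged rather than partially populated.
    await add_nodes_and_edges_bulk(driver, episodes, episodic_edges, entity_nodes, entity_edges)
```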


async def extract_nodes_and_edges_bulk(
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
version = "0.3.19"
version = "0.3.20"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk <paul@getzep.com>",