Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bulk add nodes and edges #205

Merged
merged 7 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,17 @@ The `server` directory contains an API service for interacting with the Graphiti

Please see the [server README](./server/README.md) for more information.

## Optional Environment Variables

In addition to the Neo4j and OpenAI-compatible credentials, Graphiti also has a few optional environment variables.
If you are using one of our supported models, such as Anthropic or Voyage models, the necessary environment variables
must be set.

`USE_PARALLEL_RUNTIME` is an optional boolean variable that can be set to true if you wish
to enable Neo4j's parallel runtime feature for several of our search queries.
Note that this feature is not supported for Neo4j Community Edition or for smaller AuraDB instances,
so it is off by default.

## Documentation

- [Guides and API documentation](https://help.getzep.com/graphiti).
Expand All @@ -186,11 +197,11 @@ Graphiti is under active development. We aim to maintain API stability while wor
- [x] Implementing node and edge CRUD operations
- [ ] Improving performance and scalability
- [ ] Achieving good performance with different LLM and embedding models
- [ ] Creating a dedicated embedder interface
- [x] Creating a dedicated embedder interface
- [ ] Supporting custom graph schemas:
- Allow developers to provide their own defined node and edge classes when ingesting episodes
- Enable more flexible knowledge representation tailored to specific use cases
- [ ] Enhancing retrieval capabilities with more robust and configurable options
- [x] Enhancing retrieval capabilities with more robust and configurable options
- [ ] Expanding test coverage to ensure reliability and catch edge cases

## Contributing
Expand Down
8 changes: 4 additions & 4 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
)
from graphiti_core.utils.bulk_utils import (
RawEpisode,
add_nodes_and_edges_bulk,
dedupe_edges_bulk,
dedupe_nodes_bulk,
extract_edge_dates_bulk,
Expand Down Expand Up @@ -451,10 +452,9 @@ async def add_episode_endpoint(episode_data: EpisodeData):
if not self.store_raw_episode_content:
episode.content = ''

await episode.save(self.driver)
await asyncio.gather(*[node.save(self.driver) for node in nodes])
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
await add_nodes_and_edges_bulk(
self.driver, [episode], episodic_edges, nodes, entity_edges
)

# Update any communities
if update_communities:
Expand Down
1 change: 1 addition & 0 deletions graphiti_core/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
load_dotenv()

DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
# Opt-in flag for Neo4j's parallel runtime. The value is parsed explicitly
# because bool() of any non-empty string is True, which would make settings
# such as "false" or "0" enable the feature. Unset or unrecognised values
# leave it disabled.
USE_PARALLEL_RUNTIME = os.getenv('USE_PARALLEL_RUNTIME', '').strip().lower() in (
    '1',
    'true',
    'yes',
    'on',
)


def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
Expand Down
20 changes: 20 additions & 0 deletions graphiti_core/models/edges/edge_db_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN r.uuid AS uuid"""

# Bulk save of MENTIONS relationships. Expects $episodic_edges to be a list of
# maps with uuid, group_id, created_at, source_node_uuid and target_node_uuid.
# For each map it matches the Episodic source and Entity target by uuid, MERGEs
# the relationship keyed on the edge uuid, and replaces the relationship's
# properties with the supplied values. Returns the uuid of each saved edge.
EPISODIC_EDGE_SAVE_BULK = """
UNWIND $episodic_edges AS edge
MATCH (episode:Episodic {uuid: edge.source_node_uuid})
MATCH (node:Entity {uuid: edge.target_node_uuid})
MERGE (episode)-[r:MENTIONS {uuid: edge.uuid}]->(node)
SET r = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
RETURN r.uuid AS uuid
"""

ENTITY_EDGE_SAVE = """
MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid})
Expand All @@ -14,6 +23,17 @@
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
RETURN r.uuid AS uuid"""

# Bulk save of RELATES_TO relationships between Entity nodes. Expects
# $entity_edges to be a list of maps carrying the edge fields referenced below
# (including source_node_uuid, target_node_uuid and fact_embedding). For each
# map it matches both endpoints by uuid, MERGEs the relationship keyed on the
# edge uuid, replaces its properties, and then stores fact_embedding through
# db.create.setRelationshipVectorProperty. Returns the uuid of each saved edge.
ENTITY_EDGE_SAVE_BULK = """
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
RETURN r.uuid AS uuid
"""

COMMUNITY_EDGE_SAVE = """
MATCH (community:Community {uuid: $community_uuid})
MATCH (node:Entity | Community {uuid: $entity_uuid})
Expand Down
17 changes: 17 additions & 0 deletions graphiti_core/models/nodes/node_db_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,29 @@
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid"""

# Bulk save of Episodic nodes. Expects $episodes to be a list of maps with the
# fields referenced below. For each map it MERGEs an Episodic node keyed on
# uuid and replaces the node's properties with the supplied values.
# Returns the uuid of each saved node.
EPISODIC_NODE_SAVE_BULK = """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
source: episode.source, content: episode.content,
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""

ENTITY_NODE_SAVE = """
MERGE (n:Entity {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid"""

# Bulk save of Entity nodes. Expects $nodes to be a list of maps with uuid,
# name, group_id, summary, created_at and name_embedding. For each map it
# MERGEs an Entity node keyed on uuid, replaces the node's properties, and then
# stores name_embedding through db.create.setNodeVectorProperty.
# Returns the uuid of each saved node.
ENTITY_NODE_SAVE_BULK = """
UNWIND $nodes AS node
MERGE (n:Entity {uuid: node.uuid})
SET n = {uuid: node.uuid, name: node.name, group_id: node.group_id, summary: node.summary, created_at: node.created_at}
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
RETURN n.uuid AS uuid
"""

COMMUNITY_NODE_SAVE = """
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
Expand Down
65 changes: 41 additions & 24 deletions graphiti_core/search/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,15 @@

import numpy as np
from neo4j import AsyncDriver, Query
from typing_extensions import LiteralString

from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.helpers import DEFAULT_DATABASE, lucene_sanitize, normalize_l2
from graphiti_core.helpers import (
DEFAULT_DATABASE,
USE_PARALLEL_RUNTIME,
lucene_sanitize,
normalize_l2,
)
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
Expand All @@ -38,7 +44,7 @@
DEFAULT_MIN_SCORE = 0.6
DEFAULT_MMR_LAMBDA = 0.5
MAX_SEARCH_DEPTH = 3
MAX_QUERY_LENGTH = 128
MAX_QUERY_LENGTH = 32


def fulltext_query(query: str, group_ids: list[str] | None = None):
Expand Down Expand Up @@ -187,8 +193,11 @@ async def edge_similarity_search(
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityEdge]:
# vector similarity search over embedded facts
query = Query("""
CYPHER runtime = parallel parallelRuntimeSupport=all
runtime_query: LiteralString = (
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)

query: LiteralString = """
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
Expand All @@ -210,10 +219,10 @@ async def edge_similarity_search(
r.invalid_at AS invalid_at
ORDER BY score DESC
LIMIT $limit
""")
"""

records, _, _ = await driver.execute_query(
query,
runtime_query + query,
search_vector=search_vector,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
Expand Down Expand Up @@ -318,9 +327,13 @@ async def node_similarity_search(
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityNode]:
# vector similarity search over entity names
runtime_query: LiteralString = (
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)

records, _, _ = await driver.execute_query(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all
runtime_query
+ """
MATCH (n:Entity)
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
Expand Down Expand Up @@ -425,23 +438,27 @@ async def community_similarity_search(
min_score=DEFAULT_MIN_SCORE,
) -> list[CommunityNode]:
# vector similarity search over entity names
runtime_query: LiteralString = (
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)

records, _, _ = await driver.execute_query(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (comm:Community)
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
comm.uuid As uuid,
comm.group_id AS group_id,
comm.name AS name,
comm.name_embedding AS name_embedding,
comm.created_at AS created_at,
comm.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
runtime_query
+ """
MATCH (comm:Community)
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
comm.uuid As uuid,
comm.group_id AS group_id,
comm.name AS name,
comm.name_embedding AS name_embedding,
comm.created_at AS created_at,
comm.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
search_vector=search_vector,
group_ids=group_ids,
limit=limit,
Expand Down
39 changes: 38 additions & 1 deletion graphiti_core/utils/bulk_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,20 @@
from datetime import datetime
from math import ceil

from neo4j import AsyncDriver
from neo4j import AsyncDriver, AsyncManagedTransaction
from numpy import dot, sqrt
from pydantic import BaseModel

from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
from graphiti_core.llm_client import LLMClient
from graphiti_core.models.edges.edge_db_queries import (
ENTITY_EDGE_SAVE_BULK,
EPISODIC_EDGE_SAVE_BULK,
)
from graphiti_core.models.nodes.node_db_queries import (
ENTITY_NODE_SAVE_BULK,
EPISODIC_NODE_SAVE_BULK,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
from graphiti_core.utils import retrieve_episodes
Expand Down Expand Up @@ -75,6 +83,35 @@ async def retrieve_previous_episodes_bulk(
return episode_tuples


async def add_nodes_and_edges_bulk(
    driver: AsyncDriver,
    episodic_nodes: list[EpisodicNode],
    episodic_edges: list[EpisodicEdge],
    entity_nodes: list[EntityNode],
    entity_edges: list[EntityEdge],
):
    """Save episodes, entities and their edges through one managed write.

    Opens a session on ``driver`` and passes the four collections, in the
    order the transaction function expects them, to
    ``add_nodes_and_edges_bulk_tx`` via ``session.execute_write``.
    """
    tx_args = (episodic_nodes, episodic_edges, entity_nodes, entity_edges)
    async with driver.session() as db_session:
        await db_session.execute_write(add_nodes_and_edges_bulk_tx, *tx_args)


async def add_nodes_and_edges_bulk_tx(
    tx: AsyncManagedTransaction,
    episodic_nodes: list[EpisodicNode],
    episodic_edges: list[EpisodicEdge],
    entity_nodes: list[EntityNode],
    entity_edges: list[EntityEdge],
):
    """Run the four bulk-save queries inside the given transaction.

    Each model is converted to a plain dict and sent as a list parameter.
    Nodes are written before edges, in this order: episodic nodes, entity
    nodes, episodic edges, entity edges.
    """
    episode_rows = []
    for episodic_node in episodic_nodes:
        row = dict(episodic_node)
        # Send the string form of source's ``.value`` rather than the object
        # itself (presumably an EpisodeType enum member; verify against the model).
        row['source'] = str(row['source'].value)
        episode_rows.append(row)
    await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episode_rows)

    entity_rows = [dict(entity_node) for entity_node in entity_nodes]
    await tx.run(ENTITY_NODE_SAVE_BULK, nodes=entity_rows)

    episodic_edge_rows = [dict(episodic_edge) for episodic_edge in episodic_edges]
    await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=episodic_edge_rows)

    entity_edge_rows = [dict(entity_edge) for entity_edge in entity_edges]
    await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=entity_edge_rows)


async def extract_nodes_and_edges_bulk(
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
version = "0.3.19"
version = "0.3.20"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk <paul@getzep.com>",
Expand Down