Skip to content

Commit

Permalink
add fulltext search limit (#215)
Browse files Browse the repository at this point in the history
* add fulltext search limit

* format

* update

* update

* update tests

* remove unused imports

* format

* mypy
  • Loading branch information
prasmussen15 authored Nov 14, 2024
1 parent a8a73ec commit 281fe07
Show file tree
Hide file tree
Showing 8 changed files with 301 additions and 363 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ poetry add graphiti-core
```python
from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType
from datetime import datetime
from datetime import datetime, timezone

# Initialize Graphiti
graphiti = Graphiti("bolt://localhost:7687", "neo4j", "password")
Expand All @@ -128,7 +128,7 @@ for i, episode in enumerate(episodes):
episode_body=episode,
source=EpisodeType.text,
source_description="podcast",
reference_time=datetime.now()
reference_time=datetime.now(timezone.utc)
)

# Search the graph
Expand Down
552 changes: 277 additions & 275 deletions examples/ecommerce/runner.ipynb

Large diffs are not rendered by default.

15 changes: 8 additions & 7 deletions examples/langgraph-agent/agent.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"import sys\n",
"import uuid\n",
"from contextlib import suppress\n",
"from datetime import datetime\n",
"from datetime import datetime, timezone\n",
"from pathlib import Path\n",
"from typing import Annotated\n",
"\n",
Expand Down Expand Up @@ -191,7 +191,7 @@
" content=str({k: v for k, v in product.items() if k != 'images'}),\n",
" source_description='ManyBirds products',\n",
" source=EpisodeType.json,\n",
" reference_time=datetime.now(),\n",
" reference_time=datetime.now(timezone.utc),\n",
" )\n",
" for i, product in enumerate(products)\n",
" ]\n",
Expand All @@ -217,23 +217,25 @@
"metadata": {},
"outputs": [],
"source": [
"from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF\n",
"\n",
"user_name = 'jess'\n",
"\n",
"await client.add_episode(\n",
" name='User Creation',\n",
" episode_body=(f'{user_name} is interested in buying a pair of shoes'),\n",
" source=EpisodeType.text,\n",
" reference_time=datetime.now(),\n",
" reference_time=datetime.now(timezone.utc),\n",
" source_description='SalesBot',\n",
")\n",
"\n",
"# let's get Jess's node uuid\n",
"nl = await client.get_nodes_by_query(user_name)\n",
"nl = await client._search(user_name, NODE_HYBRID_SEARCH_RRF)\n",
"\n",
"user_node_uuid = nl[0].uuid\n",
"\n",
"# and the ManyBirds node uuid\n",
"nl = await client.get_nodes_by_query('ManyBirds')\n",
"nl = await client._search('ManyBirds', NODE_HYBRID_SEARCH_RRF)\n",
"manybirds_node_uuid = nl[0].uuid"
]
},
Expand Down Expand Up @@ -390,7 +392,7 @@
" name='Chatbot Response',\n",
" episode_body=f\"{state['user_name']}: {state['messages'][-1]}\\nSalesBot: {response.content}\",\n",
" source=EpisodeType.message,\n",
" reference_time=datetime.now(),\n",
" reference_time=datetime.now(timezone.utc),\n",
" source_description='Chatbot',\n",
" )\n",
" )\n",
Expand Down Expand Up @@ -443,7 +445,6 @@
"graph_builder.add_conditional_edges('agent', should_continue, {'continue': 'tools', 'end': END})\n",
"graph_builder.add_edge('tools', 'agent')\n",
"\n",
"\n",
"graph = graph_builder.compile(checkpointer=memory)"
]
},
Expand Down
74 changes: 6 additions & 68 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
from graphiti_core.search.search_config_recipes import (
EDGE_HYBRID_SEARCH_NODE_DISTANCE,
EDGE_HYBRID_SEARCH_RRF,
NODE_HYBRID_SEARCH_NODE_DISTANCE,
NODE_HYBRID_SEARCH_RRF,
)
from graphiti_core.search.search_utils import (
RELEVANT_SCHEMA_LIMIT,
Expand Down Expand Up @@ -318,7 +316,7 @@ async def add_episode_endpoint(episode_data: EpisodeData):
now = datetime.now(timezone.utc)

previous_episodes = await self.retrieve_episodes(
reference_time, last_n=3, group_ids=[group_id]
reference_time, last_n=RELEVANT_SCHEMA_LIMIT, group_ids=[group_id]
)
episode = EpisodicNode(
name=name,
Expand All @@ -343,13 +341,14 @@ async def add_episode_endpoint(episode_data: EpisodeData):
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
)

# Resolve extracted nodes with nodes already in the graph and extract facts
# Find relevant nodes already in the graph
existing_nodes_lists: list[list[EntityNode]] = list(
await asyncio.gather(
*[get_relevant_nodes([node], self.driver) for node in extracted_nodes]
*[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes]
)
)

# Resolve extracted nodes with nodes already in the graph and extract facts
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
Expand Down Expand Up @@ -693,67 +692,6 @@ async def _search(
bfs_origin_node_uuids,
)

async def get_nodes_by_query(
self,
query: str,
center_node_uuid: str | None = None,
group_ids: list[str] | None = None,
limit: int = DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]:
"""
Retrieve nodes from the graph database based on a text query.
This method performs a hybrid search using both text-based and
embedding-based approaches to find relevant nodes.
Parameters
----------
query : str
The text query to search for in the graph
center_node_uuid: str, optional
Facts will be reranked based on proximity to this node.
group_ids : list[str | None] | None, optional
The graph partitions to return data from.
limit : int | None, optional
The maximum number of results to return per search method.
If None, a default limit will be applied.
Returns
-------
list[EntityNode]
A list of EntityNode objects that match the search criteria.
Notes
-----
This method uses the following steps:
1. Generates an embedding for the input query using the LLM client's embedder.
2. Calls the hybrid_node_search function with both the text query and its embedding.
3. The hybrid search combines fulltext search and vector similarity search
to find the most relevant nodes.
The method leverages the LLM client's embedding capabilities to enhance
the search with semantic similarity matching. The 'limit' parameter is applied
to each individual search method before results are combined and deduplicated.
If not specified, a default limit (defined in the search functions) will be used.
"""
search_config = (
NODE_HYBRID_SEARCH_RRF if center_node_uuid is None else NODE_HYBRID_SEARCH_NODE_DISTANCE
)
search_config.limit = limit

nodes = (
await search(
self.driver,
self.embedder,
self.cross_encoder,
query,
group_ids,
search_config,
center_node_uuid,
)
).nodes
return nodes

async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)

Expand Down Expand Up @@ -781,8 +719,8 @@ async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_no
self.llm_client,
[source_node, target_node],
[
await get_relevant_nodes([source_node], self.driver),
await get_relevant_nodes([target_node], self.driver),
await get_relevant_nodes(self.driver, [source_node]),
await get_relevant_nodes(self.driver, [target_node]),
],
)

Expand Down
9 changes: 5 additions & 4 deletions graphiti_core/search/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ async def edge_fulltext_search(
return []

cypher_query = Query("""
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query)
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query, {limit: $limit})
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]->(m:Entity)
WHERE ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
Expand Down Expand Up @@ -296,7 +296,7 @@ async def node_fulltext_search(

records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
YIELD node AS n, score
RETURN
n.uuid AS uuid,
Expand Down Expand Up @@ -407,7 +407,7 @@ async def community_fulltext_search(

records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("community_name", $query)
CALL db.index.fulltext.queryNodes("community_name", $query, {limit: $limit})
YIELD node AS comm, score
RETURN
comm.uuid AS uuid,
Expand Down Expand Up @@ -539,8 +539,8 @@ async def hybrid_node_search(


async def get_relevant_nodes(
nodes: list[EntityNode],
driver: AsyncDriver,
nodes: list[EntityNode],
) -> list[EntityNode]:
"""
Retrieve relevant nodes based on the provided list of EntityNodes.
Expand Down Expand Up @@ -573,6 +573,7 @@ async def get_relevant_nodes(
driver,
[node.group_id for node in nodes],
)

return relevant_nodes


Expand Down
2 changes: 1 addition & 1 deletion graphiti_core/utils/bulk_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ async def dedupe_nodes_bulk(

existing_nodes_chunks: list[list[EntityNode]] = list(
await asyncio.gather(
*[get_relevant_nodes(node_chunk, driver) for node_chunk in node_chunks]
*[get_relevant_nodes(driver, node_chunk) for node_chunk in node_chunks]
)
)

Expand Down
2 changes: 1 addition & 1 deletion graphiti_core/utils/maintenance/community_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ async def update_community(
community.name = new_name

if is_new:
community_edge = (build_community_edges([entity], community, datetime.now()))[0]
community_edge = (build_community_edges([entity], community, datetime.now(timezone.utc)))[0]
await community_edge.save(driver)

await community.generate_name_embedding(embedder)
Expand Down
6 changes: 1 addition & 5 deletions tests/test_graphiti_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,10 @@ async def test_graphiti_init():

await graphiti.add_triplet(alice_node, entity_edge, bob_node)

episodes = await graphiti.retrieve_episodes(datetime.now(timezone.utc), group_ids=None)
episode_uuids = [episode.uuid for episode in episodes]

results = await graphiti._search(
"Emily: I can't log in",
COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
bfs_origin_node_uuids=episode_uuids,
group_ids=None,
group_ids=['test'],
)
pretty_results = {
'edges': [edge.fact for edge in results.edges],
Expand Down

0 comments on commit 281fe07

Please sign in to comment.