Skip to content

Commit

Permalink
GremlinStorage: fixes and patch to support other Gremlin compatible B…
Browse files Browse the repository at this point in the history
…D. Tested on ArcadeDB with Gremlin plugin. The main change is using "entity_name" vertex property instead of the label as a node_id since different implementations have different restrictions on label names.
  • Loading branch information
alllexx88 committed Dec 23, 2024
1 parent bfacfb9 commit 288aefb
Showing 1 changed file with 102 additions and 123 deletions.
225 changes: 102 additions & 123 deletions lightrag/kg/gremlin_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import inspect
import json
import os
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union

Expand All @@ -27,9 +26,6 @@ class GremlinStorage(BaseGraphStorage):
def load_nx_graph(file_name):
print("no preloading of graph with Gremlin in production")

# Will use this to make sure single quotes are properly escaped
escape_rx = re.compile(r"(^|[^\\])((\\\\)*\\)\\'")

def __init__(self, namespace, global_config, embedding_func):
super().__init__(
namespace=namespace,
Expand All @@ -51,12 +47,8 @@ def __init__(self, namespace, global_config, embedding_func):

# All vertices will have graph={GRAPH} property, so that we can
# have several logical graphs for one source
GRAPH = GremlinStorage.escape_rx.sub(
r"\1\2'",
os.environ["GREMLIN_GRAPH"].replace("'", r"\'"),
)
GRAPH = GremlinStorage._to_value_map(os.environ["GREMLIN_GRAPH"])

self.traverse_source_name = SOURCE
self.graph_name = GRAPH

self._driver = client.Client(
Expand Down Expand Up @@ -87,7 +79,7 @@ async def index_done_callback(self):

@staticmethod
def _to_value_map(value: Any) -> str:
"""Dump Python dict as Gremlin valueMap"""
"""Dump supported Python object as Gremlin valueMap"""
json_str = json.dumps(value, ensure_ascii=False, sort_keys=False)
parsed_str = json_str.replace("'", r"\'")

Expand Down Expand Up @@ -122,17 +114,16 @@ def _convert_properties(properties: Dict[str, Any]) -> str:
"""Create chained .property() commands from properties dict"""
props = []
for k, v in properties.items():
prop_name = GremlinStorage.escape_rx.sub(r"\1\2'", k.replace("'", r"\'"))
props.append(f".property('{prop_name}', {GremlinStorage._to_value_map(v)})")
prop_name = GremlinStorage._to_value_map(k)
props.append(f".property({prop_name}, {GremlinStorage._to_value_map(v)})")
return "".join(props)

@staticmethod
def _fix_label(label: str) -> str:
"""Strip double quotes and make sure single quotes are escaped"""
label = label.strip('"').replace("'", r"\'")
label = GremlinStorage.escape_rx.sub(r"\1\2'", label)
def _fix_name(name: str) -> str:
"""Strip double quotes and format as a proper field name"""
name = GremlinStorage._to_value_map(name.strip('"').replace(r"\'", "'"))

return label
return name

async def _query(self, query: str) -> List[Dict[str, Any]]:
"""
Expand All @@ -146,66 +137,69 @@ async def _query(self, query: str) -> List[Dict[str, Any]]:
"""

result = list(await asyncio.wrap_future(self._driver.submit_async(query)))
if result:
result = result[0]

return result

async def has_node(self, node_id: str) -> bool:
entity_name_label = GremlinStorage._fix_label(node_id)
entity_name = GremlinStorage._fix_name(node_id)

query = f"""
{self.traverse_source_name}
.V().has('graph', '{self.graph_name}')
.hasLabel('{entity_name_label}')
query = f"""g
.V().has('graph', {self.graph_name}')
.has('entity_name', {entity_name})
.limit(1)
.hasNext()
.project('has_node')
.by()
"""
result = await self._query(query)
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query,
result[0][0],
result["has_node"],
)

return result[0][0]
return result["has_node"]

async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
entity_name_label_source = GremlinStorage._fix_label(source_node_id)
entity_name_label_target = GremlinStorage._fix_label(target_node_id)
entity_name_source = GremlinStorage._fix_name(source_node_id)
entity_name_target = GremlinStorage._fix_name(target_node_id)

query = f"""
{self.traverse_source_name}
.V().has('graph', '{self.graph_name}')
.hasLabel('{entity_name_label_source}')
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name_source})
.bothE()
.otherV().has('graph', '{self.graph_name}')
.hasLabel('{entity_name_label_target}')
.otherV().has('graph', {self.graph_name})
.has('entity_name', {entity_name_target})
.limit(1)
.hasNext()
.project('has_edge')
.by()
"""
result = await self._query(query)
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query,
result[0][0],
result["has_edge"],
)

return result[0][0]
return result["has_edge"]

async def get_node(self, node_id: str) -> Union[dict, None]:
entity_name_label = GremlinStorage._fix_label(node_id)
query = f"""
{self.traverse_source_name}
.V().has('graph', '{self.graph_name}')
.hasLabel('{entity_name_label}')
entity_name = GremlinStorage._fix_name(node_id)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name})
.limit(1)
.project('properties')
.by(elementMap())
"""
result = await self._query(query)
if result:
node = result[0][0]
node = result[0]
node_dict = node["properties"]
logger.debug(
"{%s}: query: {%s}, result: {%s}",
Expand All @@ -216,19 +210,18 @@ async def get_node(self, node_id: str) -> Union[dict, None]:
return node_dict

async def node_degree(self, node_id: str) -> int:
entity_name_label = GremlinStorage._fix_label(node_id)
query = f"""
{self.traverse_source_name}
.V().has('graph', '{self.graph_name}')
.hasLabel('{entity_name_label}')
entity_name = GremlinStorage._fix_name(node_id)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name})
.outE()
.inV().has('graph', '{self.graph_name}')
.inV().has('graph', {self.graph_name})
.count()
.project('total_edge_count')
.by()
"""
result = await self._query(query)
edge_count = result[0][0]["total_edge_count"]
edge_count = result[0]["total_edge_count"]

logger.debug(
"{%s}:query:{%s}:result:{%s}",
Expand Down Expand Up @@ -259,31 +252,30 @@ async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
"""
Find all edges between nodes of two given labels
Find all edges between nodes of two given names
Args:
source_node_label (str): Label of the source nodes
target_node_label (str): Label of the target nodes
source_node_id (str): Name of the source nodes
target_node_id (str): Name of the target nodes
Returns:
dict|None: Dict of found edge properties, or None of not found
dict|None: Dict of found edge properties, or None if not found
"""
entity_name_label_source = GremlinStorage._fix_label(source_node_id)
entity_name_label_target = GremlinStorage._fix_label(target_node_id)
query = f"""
{self.traverse_source_name}
.V().has('graph', '{self.graph_name}')
.hasLabel('{entity_name_label_source}')
entity_name_source = GremlinStorage._fix_name(source_node_id)
entity_name_target = GremlinStorage._fix_name(target_node_id)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name_source})
.outE()
.inV().has('graph', '{self.graph_name}')
.hasLabel('{entity_name_label_target}')
.inV().has('graph', {self.graph_name})
.has('entity_name', {entity_name_target})
.limit(1)
.project('edge_properties')
.by(__.bothE().elementMap())
"""
result = await self._query(query)
if result:
edge_properties = result[0][0]["edge_properties"]
edge_properties = result[0]["edge_properties"]
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
Expand All @@ -294,45 +286,31 @@ async def get_edge(

async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
"""
Retrieves all edges (relationships) for a particular node identified by its label.
Retrieves all edges (relationships) for a particular node identified by its name.
:return: List of tuples containing edge sources and targets
"""
node_label = GremlinStorage._fix_label(source_node_id)
query1 = f"""
{self.traverse_source_name}
.V().has('graph', '{self.graph_name}')
.hasLabel('{node_label}')
.out().has('graph', '{self.graph_name}')
.project('connected_label')
.by(__.label())
"""
query2 = f"""
{self.traverse_source_name}
.V().has('graph', '{self.graph_name}')
.as('connected')
.out().has('graph', '{self.graph_name}')
.hasLabel('{node_label}')
.project('connected_label')
.by(__.select('connected').label())
node_name = GremlinStorage._fix_name(source_node_id)
query = f"""g
.E()
.filter(
__.or(
__.outV().has('graph', {self.graph_name})
.has('entity_name', {node_name}),
__.inV().has('graph', {self.graph_name})
.has('entity_name', {node_name})
)
)
.project('source_name', 'target_name')
.by(__.outV().values('entity_name'))
.by(__.inV().values('entity_name'))
"""
result1, result2 = await asyncio.gather(
self._query(query1), self._query(query2)
)
edges1 = (
[(node_label, res["connected_label"]) for res in result1[0]]
if result1
else []
)
edges2 = (
[(res["connected_label"], node_label) for res in result2[0]]
if result2
else []
)
result = await self._query(query)
edges = [(res["source_name"], res["target_name"]) for res in result]

return edges1 + edges2
return edges

@retry(
stop=stop_after_attempt(3),
stop=stop_after_attempt(10),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((GremlinServerError,)),
)
Expand All @@ -341,73 +319,74 @@ async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
Upsert a node in the Gremlin graph.
Args:
node_id: The unique identifier for the node (used as label)
node_id: The unique identifier for the node (used as name)
node_data: Dictionary of node properties
"""
label = GremlinStorage._fix_label(node_id)
name = GremlinStorage._fix_name(node_id)
properties = GremlinStorage._convert_properties(node_data)

query = f"""
{self.traverse_source_name}
.V().has('graph', '{self.graph_name}')
.hasLabel('{label}').fold()
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {name})
.fold()
.coalesce(
unfold(),
addV('{label}'))
.property('graph', '{self.graph_name}')
__.unfold(),
__.addV('ENTITY')
.property('graph', {self.graph_name})
.property('entity_name', {name})
)
{properties}
"""

try:
await self._query(query)
logger.debug(
"Upserted node with label '{%s}' and properties: {%s}",
label,
"Upserted node with name {%s} and properties: {%s}",
name,
properties,
)
except Exception as e:
logger.error("Error during upsert: {%s}", e)
raise

@retry(
stop=stop_after_attempt(3),
stop=stop_after_attempt(10),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((GremlinServerError,)),
)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
):
"""
Upsert an edge and its properties between two nodes identified by their labels.
Upsert an edge and its properties between two nodes identified by their names.
Args:
source_node_id (str): Label of the source node (used as identifier)
target_node_id (str): Label of the target node (used as identifier)
source_node_id (str): Name of the source node (used as identifier)
target_node_id (str): Name of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge
"""
source_node_label = GremlinStorage._fix_label(source_node_id)
target_node_label = GremlinStorage._fix_label(target_node_id)
source_node_name = GremlinStorage._fix_name(source_node_id)
target_node_name = GremlinStorage._fix_name(target_node_id)
edge_properties = GremlinStorage._convert_properties(edge_data)

query = f"""
{self.traverse_source_name}
.V().has('graph', '{self.graph_name}')
.hasLabel('{source_node_label}').as('source')
.V().has('graph', '{self.graph_name}')
.hasLabel('{target_node_label}').as('target')
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {source_node_name}).as('source')
.V().has('graph', {self.graph_name})
.has('entity_name', {target_node_name}).as('target')
.coalesce(
select('source').outE('DIRECTED').where(inV().as('target')),
select('source').addE('DIRECTED').to(select('target'))
)
.property('graph', '{self.graph_name}')
__.select('source').outE('DIRECTED').where(__.inV().as('target')),
__.select('source').addE('DIRECTED').to(__.select('target'))
)
.property('graph', {self.graph_name})
{edge_properties}
"""
try:
await self._query(query)
logger.debug(
"Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
source_node_label,
target_node_label,
"Upserted edge from {%s} to {%s} with properties: {%s}",
source_node_name,
target_node_name,
edge_properties,
)
except Exception as e:
Expand Down

0 comments on commit 288aefb

Please sign in to comment.