diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index 3cad6db0..d3c64617 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -2,7 +2,6 @@ import inspect import json import os -import re from dataclasses import dataclass from typing import Any, Dict, List, Tuple, Union @@ -27,9 +26,6 @@ class GremlinStorage(BaseGraphStorage): def load_nx_graph(file_name): print("no preloading of graph with Gremlin in production") - # Will use this to make sure single quotes are properly escaped - escape_rx = re.compile(r"(^|[^\\])((\\\\)*\\)\\'") - def __init__(self, namespace, global_config, embedding_func): super().__init__( namespace=namespace, @@ -51,12 +47,8 @@ def __init__(self, namespace, global_config, embedding_func): # All vertices will have graph={GRAPH} property, so that we can # have several logical graphs for one source - GRAPH = GremlinStorage.escape_rx.sub( - r"\1\2'", - os.environ["GREMLIN_GRAPH"].replace("'", r"\'"), - ) + GRAPH = GremlinStorage._to_value_map(os.environ["GREMLIN_GRAPH"]) - self.traverse_source_name = SOURCE self.graph_name = GRAPH self._driver = client.Client( @@ -87,7 +79,7 @@ async def index_done_callback(self): @staticmethod def _to_value_map(value: Any) -> str: - """Dump Python dict as Gremlin valueMap""" + """Dump supported Python object as Gremlin valueMap""" json_str = json.dumps(value, ensure_ascii=False, sort_keys=False) parsed_str = json_str.replace("'", r"\'") @@ -122,17 +114,16 @@ def _convert_properties(properties: Dict[str, Any]) -> str: """Create chained .property() commands from properties dict""" props = [] for k, v in properties.items(): - prop_name = GremlinStorage.escape_rx.sub(r"\1\2'", k.replace("'", r"\'")) - props.append(f".property('{prop_name}', {GremlinStorage._to_value_map(v)})") + prop_name = GremlinStorage._to_value_map(k) + props.append(f".property({prop_name}, {GremlinStorage._to_value_map(v)})") return "".join(props) @staticmethod - def _fix_label(label: str) -> str: - """Strip double quotes and make sure single quotes are escaped""" - label = label.strip('"').replace("'", r"\'") - label = GremlinStorage.escape_rx.sub(r"\1\2'", label) + def _fix_name(name: str) -> str: + """Strip double quotes and format as a proper field name""" + name = GremlinStorage._to_value_map(name.strip('"').replace(r"\'", "'")) - return label + return name async def _query(self, query: str) -> List[Dict[str, Any]]: """ @@ -146,66 +137,69 @@ async def _query(self, query: str) -> List[Dict[str, Any]]: """ result = list(await asyncio.wrap_future(self._driver.submit_async(query))) + if result: + result = result[0] return result async def has_node(self, node_id: str) -> bool: - entity_name_label = GremlinStorage._fix_label(node_id) + entity_name = GremlinStorage._fix_name(node_id) - query = f""" - {self.traverse_source_name} - .V().has('graph', '{self.graph_name}') - .hasLabel('{entity_name_label}') + query = f"""g + .V().has('graph', {self.graph_name}) + .has('entity_name', {entity_name}) .limit(1) - .hasNext() + .count() + .project('has_node') + .by(__.choose(__.is(gt(0)), constant(true), constant(false))) """ result = await self._query(query) logger.debug( "{%s}:query:{%s}:result:{%s}", inspect.currentframe().f_code.co_name, query, - result[0][0], + result[0]["has_node"], ) - return result[0][0] + return result[0]["has_node"] async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - entity_name_label_source = GremlinStorage._fix_label(source_node_id) - entity_name_label_target = GremlinStorage._fix_label(target_node_id) - - query = f""" - {self.traverse_source_name} - .V().has('graph', '{self.graph_name}') - .hasLabel('{entity_name_label_source}') - .bothE() - .otherV().has('graph', '{self.graph_name}') - .hasLabel('{entity_name_label_target}') + entity_name_source = GremlinStorage._fix_name(source_node_id) + entity_name_target = GremlinStorage._fix_name(target_node_id) + + query = f"""g + .V().has('graph', {self.graph_name}) + .has('entity_name', {entity_name_source}) + .outE() + .inV().has('graph', {self.graph_name}) + .has('entity_name', {entity_name_target}) .limit(1) - .hasNext() + .count() + .project('has_edge') + .by(__.choose(__.is(gt(0)), constant(true), constant(false))) """ result = await self._query(query) logger.debug( "{%s}:query:{%s}:result:{%s}", inspect.currentframe().f_code.co_name, query, - result[0][0], + result[0]["has_edge"], ) - return result[0][0] + return result[0]["has_edge"] async def get_node(self, node_id: str) -> Union[dict, None]: - entity_name_label = GremlinStorage._fix_label(node_id) - query = f""" - {self.traverse_source_name} - .V().has('graph', '{self.graph_name}') - .hasLabel('{entity_name_label}') + entity_name = GremlinStorage._fix_name(node_id) + query = f"""g + .V().has('graph', {self.graph_name}) + .has('entity_name', {entity_name}) .limit(1) .project('properties') .by(elementMap()) """ result = await self._query(query) if result: - node = result[0][0] + node = result[0] node_dict = node["properties"] logger.debug( "{%s}: query: {%s}, result: {%s}", @@ -216,19 +210,18 @@ async def get_node(self, node_id: str) -> Union[dict, None]: return node_dict async def node_degree(self, node_id: str) -> int: - entity_name_label = GremlinStorage._fix_label(node_id) - query = f""" - {self.traverse_source_name} - .V().has('graph', '{self.graph_name}') - .hasLabel('{entity_name_label}') + entity_name = GremlinStorage._fix_name(node_id) + query = f"""g + .V().has('graph', {self.graph_name}) + .has('entity_name', {entity_name}) .outE() - .inV().has('graph', '{self.graph_name}') + .inV().has('graph', {self.graph_name}) .count() .project('total_edge_count') .by() """ result = await self._query(query) - edge_count = result[0][0]["total_edge_count"] + edge_count = result[0]["total_edge_count"] logger.debug( "{%s}:query:{%s}:result:{%s}", @@ -259,31 +252,30 @@ async def get_edge( self, source_node_id: str, target_node_id: str ) -> Union[dict, None]: """ - Find all edges between nodes of two given labels + Find all edges between nodes of two given names Args: - source_node_label (str): Label of the source nodes - target_node_label (str): Label of the target nodes + source_node_id (str): Name of the source nodes + target_node_id (str): Name of the target nodes Returns: - dict|None: Dict of found edge properties, or None of not found + dict|None: Dict of found edge properties, or None if not found """ - entity_name_label_source = GremlinStorage._fix_label(source_node_id) - entity_name_label_target = GremlinStorage._fix_label(target_node_id) - query = f""" - {self.traverse_source_name} - .V().has('graph', '{self.graph_name}') - .hasLabel('{entity_name_label_source}') + entity_name_source = GremlinStorage._fix_name(source_node_id) + entity_name_target = GremlinStorage._fix_name(target_node_id) + query = f"""g + .V().has('graph', {self.graph_name}) + .has('entity_name', {entity_name_source}) .outE() - .inV().has('graph', '{self.graph_name}') - .hasLabel('{entity_name_label_target}') + .inV().has('graph', {self.graph_name}) + .has('entity_name', {entity_name_target}) .limit(1) .project('edge_properties') .by(__.bothE().elementMap()) """ result = await self._query(query) if result: - edge_properties = result[0][0]["edge_properties"] + edge_properties = result[0]["edge_properties"] logger.debug( "{%s}:query:{%s}:result:{%s}", inspect.currentframe().f_code.co_name, @@ -294,45 +286,31 @@ async def get_edge( async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: """ - Retrieves all edges (relationships) for a particular node identified by its label. + Retrieves all edges (relationships) for a particular node identified by its name. :return: List of tuples containing edge sources and targets """ - node_label = GremlinStorage._fix_label(source_node_id) - query1 = f""" - {self.traverse_source_name} - .V().has('graph', '{self.graph_name}') - .hasLabel('{node_label}') - .out().has('graph', '{self.graph_name}') - .project('connected_label') - .by(__.label()) - """ - query2 = f""" - {self.traverse_source_name} - .V().has('graph', '{self.graph_name}') - .as('connected') - .out().has('graph', '{self.graph_name}') - .hasLabel('{node_label}') - .project('connected_label') - .by(__.select('connected').label()) + node_name = GremlinStorage._fix_name(source_node_id) + query = f"""g + .E() + .filter( + __.or( + __.outV().has('graph', {self.graph_name}) + .has('entity_name', {node_name}), + __.inV().has('graph', {self.graph_name}) + .has('entity_name', {node_name}) + ) + ) + .project('source_name', 'target_name') + .by(__.outV().values('entity_name')) + .by(__.inV().values('entity_name')) """ - result1, result2 = await asyncio.gather( - self._query(query1), self._query(query2) - ) - edges1 = ( - [(node_label, res["connected_label"]) for res in result1[0]] - if result1 - else [] - ) - edges2 = ( - [(res["connected_label"], node_label) for res in result2[0]] - if result2 - else [] - ) + result = await self._query(query) + edges = [(res["source_name"], res["target_name"]) for res in result] - return edges1 + edges2 + return edges @retry( - stop=stop_after_attempt(3), + stop=stop_after_attempt(10), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((GremlinServerError,)), ) @@ -341,28 +319,30 @@ async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): Upsert a node in the Gremlin graph. Args: - node_id: The unique identifier for the node (used as label) + node_id: The unique identifier for the node (used as name) node_data: Dictionary of node properties """ - label = GremlinStorage._fix_label(node_id) + name = GremlinStorage._fix_name(node_id) properties = GremlinStorage._convert_properties(node_data) - query = f""" - {self.traverse_source_name} - .V().has('graph', '{self.graph_name}') - .hasLabel('{label}').fold() + query = f"""g + .V().has('graph', {self.graph_name}) + .has('entity_name', {name}) + .fold() .coalesce( - unfold(), - addV('{label}')) - .property('graph', '{self.graph_name}') + __.unfold(), + __.addV('ENTITY') + .property('graph', {self.graph_name}) + .property('entity_name', {name}) + ) {properties} """ try: await self._query(query) logger.debug( - "Upserted node with label '{%s}' and properties: {%s}", - label, + "Upserted node with name {%s} and properties: {%s}", + name, properties, ) except Exception as e: @@ -370,7 +350,7 @@ async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): raise @retry( - stop=stop_after_attempt(3), + stop=stop_after_attempt(10), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((GremlinServerError,)), ) @@ -378,36 +358,35 @@ async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] ): """ - Upsert an edge and its properties between two nodes identified by their labels. + Upsert an edge and its properties between two nodes identified by their names. Args: - source_node_id (str): Label of the source node (used as identifier) - target_node_id (str): Label of the target node (used as identifier) + source_node_id (str): Name of the source node (used as identifier) + target_node_id (str): Name of the target node (used as identifier) edge_data (dict): Dictionary of properties to set on the edge """ - source_node_label = GremlinStorage._fix_label(source_node_id) - target_node_label = GremlinStorage._fix_label(target_node_id) + source_node_name = GremlinStorage._fix_name(source_node_id) + target_node_name = GremlinStorage._fix_name(target_node_id) edge_properties = GremlinStorage._convert_properties(edge_data) - query = f""" - {self.traverse_source_name} - .V().has('graph', '{self.graph_name}') - .hasLabel('{source_node_label}').as('source') - .V().has('graph', '{self.graph_name}') - .hasLabel('{target_node_label}').as('target') + query = f"""g + .V().has('graph', {self.graph_name}) + .has('entity_name', {source_node_name}).as('source') + .V().has('graph', {self.graph_name}) + .has('entity_name', {target_node_name}).as('target') .coalesce( - select('source').outE('DIRECTED').where(inV().as('target')), - select('source').addE('DIRECTED').to(select('target')) - ) - .property('graph', '{self.graph_name}') + __.select('source').outE('DIRECTED').where(__.inV().as('target')), + __.select('source').addE('DIRECTED').to(__.select('target')) + ) + .property('graph', {self.graph_name}) {edge_properties} """ try: await self._query(query) logger.debug( - "Upserted edge from '{%s}' to '{%s}' with properties: {%s}", - source_node_label, - target_node_label, + "Upserted edge from {%s} to {%s} with properties: {%s}", + source_node_name, + target_node_name, edge_properties, ) except Exception as e: