diff --git a/examples/lightrag_tidb_demo.py b/examples/lightrag_tidb_demo.py index fd73a354..b8e4d35c 100644 --- a/examples/lightrag_tidb_demo.py +++ b/examples/lightrag_tidb_demo.py @@ -21,8 +21,7 @@ TIDB_PORT = "" TIDB_USER = "" TIDB_PASSWORD = "" -TIDB_DATABASE = "" - +TIDB_DATABASE = "lightrag" if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) @@ -93,6 +92,7 @@ async def main(): ), kv_storage="TiDBKVStorage", vector_storage="TiDBVectorDBStorage", + graph_storage="TiDBGraphStorage", ) if rag.llm_response_cache: @@ -102,6 +102,7 @@ async def main(): rag.entities_vdb.db = tidb rag.relationships_vdb.db = tidb rag.chunks_vdb.db = tidb + rag.chunk_entity_relation_graph.db = tidb # Extract and Insert into LightRAG storage with open("./dickens/demo.txt", "r", encoding="utf-8") as f: diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 66b49fe3..2cf698e1 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -7,7 +7,7 @@ from sqlalchemy import create_engine, text from tqdm import tqdm -from lightrag.base import BaseVectorStorage, BaseKVStorage +from lightrag.base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage from lightrag.utils import logger @@ -282,33 +282,180 @@ async def upsert(self, data: dict[str, dict]): if self.namespace == "entities": data = [] for item in list_data: - merge_sql = SQL_TEMPLATES["upsert_entity"] - data.append( - { - "id": item["id"], - "name": item["entity_name"], - "content": item["content"], - "content_vector": f"{item["content_vector"].tolist()}", - "workspace": self.db.workspace, - } - ) - await self.db.execute(merge_sql, data) + param = { + "id": item["id"], + "name": item["entity_name"], + "content": item["content"], + "content_vector": f"{item["content_vector"].tolist()}", + "workspace": self.db.workspace, + } + # update entity_id if node inserted by graph_storage_instance before + has = await self.db.query(SQL_TEMPLATES["has_entity"], param) + if has["cnt"] != 0: + await self.db.execute(SQL_TEMPLATES["update_entity"], param) + continue + + data.append(param) + if data: + merge_sql = SQL_TEMPLATES["insert_entity"] + await self.db.execute(merge_sql, data) elif self.namespace == "relationships": data = [] for item in list_data: - merge_sql = SQL_TEMPLATES["upsert_relationship"] - data.append( - { - "id": item["id"], - "source_name": item["src_id"], - "target_name": item["tgt_id"], - "content": item["content"], - "content_vector": f"{item["content_vector"].tolist()}", - "workspace": self.db.workspace, - } - ) - await self.db.execute(merge_sql, data) + param = { + "id": item["id"], + "source_name": item["src_id"], + "target_name": item["tgt_id"], + "content": item["content"], + "content_vector": f"{item["content_vector"].tolist()}", + "workspace": self.db.workspace, + } + # update relation_id if node inserted by graph_storage_instance before + has = await self.db.query(SQL_TEMPLATES["has_relationship"], param) + if has["cnt"] != 0: + await self.db.execute(SQL_TEMPLATES["update_relationship"], param) + continue + + data.append(param) + if data: + merge_sql = SQL_TEMPLATES["insert_relationship"] + await self.db.execute(merge_sql, data) + + +@dataclass +class TiDBGraphStorage(BaseGraphStorage): + def __post_init__(self): + self._max_batch_size = self.global_config["embedding_batch_num"] + + #################### upsert method ################ + async def upsert_node(self, node_id: str, node_data: dict[str, str]): + entity_name = node_id + entity_type = node_data["entity_type"] + description = node_data["description"] + source_id = node_data["source_id"] + logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}") + content = entity_name + description + contents = [content] + batches = [ + contents[i : i + self._max_batch_size] + for i in range(0, len(contents), self._max_batch_size) + ] + embeddings_list = await asyncio.gather( + *[self.embedding_func(batch) for batch in batches] + ) + embeddings = np.concatenate(embeddings_list) + content_vector = embeddings[0] + sql = SQL_TEMPLATES["upsert_node"] + data = { + "workspace": self.db.workspace, + "name": entity_name, + "entity_type": entity_type, + "description": description, + "source_chunk_id": source_id, + "content": content, + "content_vector": f"{content_vector.tolist()}", + } + await self.db.execute(sql, data) + + async def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ): + source_name = source_node_id + target_name = target_node_id + weight = edge_data["weight"] + keywords = edge_data["keywords"] + description = edge_data["description"] + source_chunk_id = edge_data["source_id"] + logger.debug( + f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}" + ) + + content = keywords + source_name + target_name + description + contents = [content] + batches = [ + contents[i : i + self._max_batch_size] + for i in range(0, len(contents), self._max_batch_size) + ] + embeddings_list = await asyncio.gather( + *[self.embedding_func(batch) for batch in batches] + ) + embeddings = np.concatenate(embeddings_list) + content_vector = embeddings[0] + merge_sql = SQL_TEMPLATES["upsert_edge"] + data = { + "workspace": self.db.workspace, + "source_name": source_name, + "target_name": target_name, + "weight": weight, + "keywords": keywords, + "description": description, + "source_chunk_id": source_chunk_id, + "content": content, + "content_vector": f"{content_vector.tolist()}", + } + await self.db.execute(merge_sql, data) + + async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: + if algorithm not in self._node_embed_algorithms: + raise ValueError(f"Node embedding algorithm {algorithm} not supported") + return await self._node_embed_algorithms[algorithm]() + + # Query + + async def has_node(self, node_id: str) -> bool: + sql = SQL_TEMPLATES["has_entity"] + param = {"name": node_id, "workspace": self.db.workspace} + has = await self.db.query(sql, param) + return has["cnt"] != 0 + + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + sql = SQL_TEMPLATES["has_relationship"] + param = { + "source_name": source_node_id, + "target_name": target_node_id, + "workspace": self.db.workspace, + } + has = await self.db.query(sql, param) + return has["cnt"] != 0 + + async def node_degree(self, node_id: str) -> int: + sql = SQL_TEMPLATES["node_degree"] + param = {"name": node_id, "workspace": self.db.workspace} + result = await self.db.query(sql, param) + return result["cnt"] + + async def edge_degree(self, src_id: str, tgt_id: str) -> int: + degree = await self.node_degree(src_id) + await self.node_degree(tgt_id) + return degree + + async def get_node(self, node_id: str) -> Union[dict, None]: + sql = SQL_TEMPLATES["get_node"] + param = {"name": node_id, "workspace": self.db.workspace} + return await self.db.query(sql, param) + + async def get_edge( + self, source_node_id: str, target_node_id: str + ) -> Union[dict, None]: + sql = SQL_TEMPLATES["get_edge"] + param = { + "source_name": source_node_id, + "target_name": target_node_id, + "workspace": self.db.workspace, + } + return await self.db.query(sql, param) + + async def get_node_edges( + self, source_node_id: str + ) -> Union[list[tuple[str, str]], None]: + sql = SQL_TEMPLATES["get_node_edges"] + param = {"source_name": source_node_id, "workspace": self.db.workspace} + res = await self.db.query(sql, param, multirows=True) + if res: + data = [(i["source_name"], i["target_name"]) for i in res] + return data + else: + return [] N_T = { @@ -362,14 +509,17 @@ async def upsert(self, data: dict[str, dict]): "ddl": """ CREATE TABLE LIGHTRAG_GRAPH_NODES ( `id` BIGINT PRIMARY KEY AUTO_RANDOM, - `entity_id` VARCHAR(256) NOT NULL, + `entity_id` VARCHAR(256), `workspace` varchar(1024), `name` VARCHAR(2048), + `entity_type` VARCHAR(1024), + `description` LONGTEXT, + `source_chunk_id` VARCHAR(256), `content` LONGTEXT, `content_vector` VECTOR, `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP, `updatetime` DATETIME DEFAULT NULL, - UNIQUE KEY (`entity_id`) + KEY (`entity_id`) ); """ }, @@ -377,15 +527,19 @@ async def upsert(self, data: dict[str, dict]): "ddl": """ CREATE TABLE LIGHTRAG_GRAPH_EDGES ( `id` BIGINT PRIMARY KEY AUTO_RANDOM, - `relation_id` VARCHAR(256) NOT NULL, + `relation_id` VARCHAR(256), `workspace` varchar(1024), `source_name` VARCHAR(2048), `target_name` VARCHAR(2048), + `weight` DECIMAL, + `keywords` TEXT, + `description` LONGTEXT, + `source_chunk_id` varchar(256), `content` LONGTEXT, `content_vector` VECTOR, `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP, `updatetime` DATETIME DEFAULT NULL, - UNIQUE KEY (`relation_id`) + KEY (`relation_id`) ); """ }, @@ -416,39 +570,87 @@ async def upsert(self, data: dict[str, dict]): INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace) VALUES (:id, :content, :workspace) ON DUPLICATE KEY UPDATE content = VALUES(content), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP - """, + """, "upsert_chunk": """ INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index, full_doc_id, content_vector, workspace) VALUES (:id, :content, :tokens, :chunk_order_index, :full_doc_id, :content_vector, :workspace) ON DUPLICATE KEY UPDATE content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index), full_doc_id = VALUES(full_doc_id), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP - """, + """, # SQL for VectorStorage "entities": """SELECT n.name as entity_name FROM (SELECT entity_id as id, name, VEC_COSINE_DISTANCE(content_vector,:embedding_string) as distance FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace) n - WHERE n.distance>:better_than_threshold ORDER BY n.distance DESC LIMIT :top_k""", + WHERE n.distance>:better_than_threshold ORDER BY n.distance DESC LIMIT :top_k + """, "relationships": """SELECT e.source_name as src_id, e.target_name as tgt_id FROM (SELECT source_name, target_name, VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace) e - WHERE e.distance>:better_than_threshold ORDER BY e.distance DESC LIMIT :top_k""", + WHERE e.distance>:better_than_threshold ORDER BY e.distance DESC LIMIT :top_k + """, "chunks": """SELECT c.id FROM (SELECT chunk_id as id,VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace) c - WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k""", - "upsert_entity": """ + WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k + """, + "has_entity": """ + SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace + """, + "has_relationship": """ + SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND target_name = :target_name AND workspace = :workspace + """, + "update_entity": """ + UPDATE LIGHTRAG_GRAPH_NODES SET + entity_id = :id, content = :content, content_vector = :content_vector, updatetime = CURRENT_TIMESTAMP + WHERE workspace = :workspace AND name = :name + """, + "update_relationship": """ + UPDATE LIGHTRAG_GRAPH_EDGES SET + relation_id = :id, content = :content, content_vector = :content_vector, updatetime = CURRENT_TIMESTAMP + WHERE workspace = :workspace AND source_name = :source_name AND target_name = :target_name + """, + "insert_entity": """ INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace) VALUES(:id, :name, :content, :content_vector, :workspace) - ON DUPLICATE KEY UPDATE - name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector), - workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP - """, - "upsert_relationship": """ + """, + "insert_relationship": """ INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace) VALUES(:id, :source_name, :target_name, :content, :content_vector, :workspace) + """, + # SQL for GraphStorage + "get_node": """ + SELECT entity_id AS id, workspace, name, entity_type, description, source_chunk_id AS source_id, content, content_vector + FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace + """, + "get_edge": """ + SELECT relation_id AS id, workspace, source_name, target_name, weight, keywords, description, source_chunk_id AS source_id, content, content_vector + FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND target_name = :target_name AND workspace = :workspace + """, + "get_node_edges": """ + SELECT relation_id AS id, workspace, source_name, target_name, weight, keywords, description, source_chunk_id, content, content_vector + FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND workspace = :workspace + """, + "node_degree": """ + SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace AND :name IN (source_name, target_name) + """, + "upsert_node": """ + INSERT INTO LIGHTRAG_GRAPH_NODES(name, content, content_vector, workspace, source_chunk_id, entity_type, description) + VALUES(:name, :content, :content_vector, :workspace, :source_chunk_id, :entity_type, :description) + ON DUPLICATE KEY UPDATE + name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector), + workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP, + source_chunk_id = VALUES(source_chunk_id), entity_type = VALUES(entity_type), description = VALUES(description) + """, + "upsert_edge": """ + INSERT INTO LIGHTRAG_GRAPH_EDGES(source_name, target_name, content, content_vector, + workspace, weight, keywords, description, source_chunk_id) + VALUES(:source_name, :target_name, :content, :content_vector, + :workspace, :weight, :keywords, :description, :source_chunk_id) ON DUPLICATE KEY UPDATE source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content), - content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP - """, + content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP, + weight = VALUES(weight), keywords = VALUES(keywords), description = VALUES(description), + source_chunk_id = VALUES(source_chunk_id) + """, } diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 69820d9a..2661d4c6 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -79,6 +79,7 @@ def import_class(*args, **kwargs): ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage") TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage") TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage") +TiDBGraphStorage = lazy_external_import(".kg.tidb_impl", "TiDBGraphStorage") AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage") @@ -282,6 +283,7 @@ def _get_storage_class(self) -> Type[BaseGraphStorage]: "Neo4JStorage": Neo4JStorage, "OracleGraphStorage": OracleGraphStorage, "AGEStorage": AGEStorage, + "TiDBGraphStorage": TiDBGraphStorage, # "ArangoDBStorage": ArangoDBStorage }