Skip to content

Commit

Permalink
support TiDBGraphStorage
Browse files Browse the repository at this point in the history
  • Loading branch information
Weaxs committed Dec 18, 2024
1 parent 874f3b3 commit 344d8f2
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 43 deletions.
5 changes: 3 additions & 2 deletions examples/lightrag_tidb_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
TIDB_PORT = ""
TIDB_USER = ""
TIDB_PASSWORD = ""
TIDB_DATABASE = ""

TIDB_DATABASE = "lightrag"

if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
Expand Down Expand Up @@ -93,6 +92,7 @@ async def main():
),
kv_storage="TiDBKVStorage",
vector_storage="TiDBVectorDBStorage",
graph_storage="TiDBGraphStorage",
)

if rag.llm_response_cache:
Expand All @@ -102,6 +102,7 @@ async def main():
rag.entities_vdb.db = tidb
rag.relationships_vdb.db = tidb
rag.chunks_vdb.db = tidb
rag.chunk_entity_relation_graph.db = tidb

# Extract and Insert into LightRAG storage
with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
Expand Down
284 changes: 243 additions & 41 deletions lightrag/kg/tidb_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sqlalchemy import create_engine, text
from tqdm import tqdm

from lightrag.base import BaseVectorStorage, BaseKVStorage
from lightrag.base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage
from lightrag.utils import logger


Expand Down Expand Up @@ -282,33 +282,180 @@ async def upsert(self, data: dict[str, dict]):
if self.namespace == "entities":
data = []
for item in list_data:
merge_sql = SQL_TEMPLATES["upsert_entity"]
data.append(
{
"id": item["id"],
"name": item["entity_name"],
"content": item["content"],
"content_vector": f"{item["content_vector"].tolist()}",
"workspace": self.db.workspace,
}
)
await self.db.execute(merge_sql, data)
param = {
"id": item["id"],
"name": item["entity_name"],
"content": item["content"],
"content_vector": f"{item["content_vector"].tolist()}",
"workspace": self.db.workspace,
}
# update entity_id if node inserted by graph_storage_instance before
has = await self.db.query(SQL_TEMPLATES["has_entity"], param)
if has["cnt"] != 0:
await self.db.execute(SQL_TEMPLATES["update_entity"], param)
continue

data.append(param)
if data:
merge_sql = SQL_TEMPLATES["insert_entity"]
await self.db.execute(merge_sql, data)

elif self.namespace == "relationships":
data = []
for item in list_data:
merge_sql = SQL_TEMPLATES["upsert_relationship"]
data.append(
{
"id": item["id"],
"source_name": item["src_id"],
"target_name": item["tgt_id"],
"content": item["content"],
"content_vector": f"{item["content_vector"].tolist()}",
"workspace": self.db.workspace,
}
)
await self.db.execute(merge_sql, data)
param = {
"id": item["id"],
"source_name": item["src_id"],
"target_name": item["tgt_id"],
"content": item["content"],
"content_vector": f"{item["content_vector"].tolist()}",
"workspace": self.db.workspace,
}
# update relation_id if node inserted by graph_storage_instance before
has = await self.db.query(SQL_TEMPLATES["has_relationship"], param)
if has["cnt"] != 0:
await self.db.execute(SQL_TEMPLATES["update_relationship"], param)
continue

data.append(param)
if data:
merge_sql = SQL_TEMPLATES["insert_relationship"]
await self.db.execute(merge_sql, data)


@dataclass
class TiDBGraphStorage(BaseGraphStorage):
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]

#################### upsert method ################
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
entity_name = node_id
entity_type = node_data["entity_type"]
description = node_data["description"]
source_id = node_data["source_id"]
logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}")
content = entity_name + description
contents = [content]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
embeddings_list = await asyncio.gather(
*[self.embedding_func(batch) for batch in batches]
)
embeddings = np.concatenate(embeddings_list)
content_vector = embeddings[0]
sql = SQL_TEMPLATES["upsert_node"]
data = {
"workspace": self.db.workspace,
"name": entity_name,
"entity_type": entity_type,
"description": description,
"source_chunk_id": source_id,
"content": content,
"content_vector": f"{content_vector.tolist()}",
}
await self.db.execute(sql, data)

async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
source_name = source_node_id
target_name = target_node_id
weight = edge_data["weight"]
keywords = edge_data["keywords"]
description = edge_data["description"]
source_chunk_id = edge_data["source_id"]
logger.debug(
f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}"
)

content = keywords + source_name + target_name + description
contents = [content]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
embeddings_list = await asyncio.gather(
*[self.embedding_func(batch) for batch in batches]
)
embeddings = np.concatenate(embeddings_list)
content_vector = embeddings[0]
merge_sql = SQL_TEMPLATES["upsert_edge"]
data = {
"workspace": self.db.workspace,
"source_name": source_name,
"target_name": target_name,
"weight": weight,
"keywords": keywords,
"description": description,
"source_chunk_id": source_chunk_id,
"content": content,
"content_vector": f"{content_vector.tolist()}",
}
await self.db.execute(merge_sql, data)

async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()

# Query

async def has_node(self, node_id: str) -> bool:
sql = SQL_TEMPLATES["has_entity"]
param = {"name": node_id, "workspace": self.db.workspace}
has = await self.db.query(sql, param)
return has["cnt"] != 0

async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
sql = SQL_TEMPLATES["has_relationship"]
param = {
"source_name": source_node_id,
"target_name": target_node_id,
"workspace": self.db.workspace,
}
has = await self.db.query(sql, param)
return has["cnt"] != 0

async def node_degree(self, node_id: str) -> int:
sql = SQL_TEMPLATES["node_degree"]
param = {"name": node_id, "workspace": self.db.workspace}
result = await self.db.query(sql, param)
return result["cnt"]

async def edge_degree(self, src_id: str, tgt_id: str) -> int:
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
return degree

async def get_node(self, node_id: str) -> Union[dict, None]:
sql = SQL_TEMPLATES["get_node"]
param = {"name": node_id, "workspace": self.db.workspace}
return await self.db.query(sql, param)

async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
sql = SQL_TEMPLATES["get_edge"]
param = {
"source_name": source_node_id,
"target_name": target_node_id,
"workspace": self.db.workspace,
}
return await self.db.query(sql, param)

async def get_node_edges(
self, source_node_id: str
) -> Union[list[tuple[str, str]], None]:
sql = SQL_TEMPLATES["get_node_edges"]
param = {"source_name": source_node_id, "workspace": self.db.workspace}
res = await self.db.query(sql, param, multirows=True)
if res:
data = [(i["source_name"], i["target_name"]) for i in res]
return data
else:
return []


N_T = {
Expand Down Expand Up @@ -362,30 +509,37 @@ async def upsert(self, data: dict[str, dict]):
"ddl": """
CREATE TABLE LIGHTRAG_GRAPH_NODES (
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
`entity_id` VARCHAR(256) NOT NULL,
`entity_id` VARCHAR(256),
`workspace` varchar(1024),
`name` VARCHAR(2048),
`entity_type` VARCHAR(1024),
`description` LONGTEXT,
`source_chunk_id` VARCHAR(256),
`content` LONGTEXT,
`content_vector` VECTOR,
`createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
`updatetime` DATETIME DEFAULT NULL,
UNIQUE KEY (`entity_id`)
KEY (`entity_id`)
);
"""
},
"LIGHTRAG_GRAPH_EDGES": {
"ddl": """
CREATE TABLE LIGHTRAG_GRAPH_EDGES (
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
`relation_id` VARCHAR(256) NOT NULL,
`relation_id` VARCHAR(256),
`workspace` varchar(1024),
`source_name` VARCHAR(2048),
`target_name` VARCHAR(2048),
`weight` DECIMAL,
`keywords` TEXT,
`description` LONGTEXT,
`source_chunk_id` varchar(256),
`content` LONGTEXT,
`content_vector` VECTOR,
`createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
`updatetime` DATETIME DEFAULT NULL,
UNIQUE KEY (`relation_id`)
KEY (`relation_id`)
);
"""
},
Expand Down Expand Up @@ -416,39 +570,87 @@ async def upsert(self, data: dict[str, dict]):
INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace)
VALUES (:id, :content, :workspace)
ON DUPLICATE KEY UPDATE content = VALUES(content), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
""",
""",
"upsert_chunk": """
INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index, full_doc_id, content_vector, workspace)
VALUES (:id, :content, :tokens, :chunk_order_index, :full_doc_id, :content_vector, :workspace)
ON DUPLICATE KEY UPDATE
content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index),
full_doc_id = VALUES(full_doc_id), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
""",
""",
# SQL for VectorStorage
"entities": """SELECT n.name as entity_name FROM
(SELECT entity_id as id, name, VEC_COSINE_DISTANCE(content_vector,:embedding_string) as distance
FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace) n
WHERE n.distance>:better_than_threshold ORDER BY n.distance DESC LIMIT :top_k""",
WHERE n.distance>:better_than_threshold ORDER BY n.distance DESC LIMIT :top_k
""",
"relationships": """SELECT e.source_name as src_id, e.target_name as tgt_id FROM
(SELECT source_name, target_name, VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace) e
WHERE e.distance>:better_than_threshold ORDER BY e.distance DESC LIMIT :top_k""",
WHERE e.distance>:better_than_threshold ORDER BY e.distance DESC LIMIT :top_k
""",
"chunks": """SELECT c.id FROM
(SELECT chunk_id as id,VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace) c
WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k""",
"upsert_entity": """
WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k
""",
"has_entity": """
SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace
""",
"has_relationship": """
SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND target_name = :target_name AND workspace = :workspace
""",
"update_entity": """
UPDATE LIGHTRAG_GRAPH_NODES SET
entity_id = :id, content = :content, content_vector = :content_vector, updatetime = CURRENT_TIMESTAMP
WHERE workspace = :workspace AND name = :name
""",
"update_relationship": """
UPDATE LIGHTRAG_GRAPH_EDGES SET
relation_id = :id, content = :content, content_vector = :content_vector, updatetime = CURRENT_TIMESTAMP
WHERE workspace = :workspace AND source_name = :source_name AND target_name = :target_name
""",
"insert_entity": """
INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace)
VALUES(:id, :name, :content, :content_vector, :workspace)
ON DUPLICATE KEY UPDATE
name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector),
workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
""",
"upsert_relationship": """
""",
"insert_relationship": """
INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace)
VALUES(:id, :source_name, :target_name, :content, :content_vector, :workspace)
""",
# SQL for GraphStorage
"get_node": """
SELECT entity_id AS id, workspace, name, entity_type, description, source_chunk_id AS source_id, content, content_vector
FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace
""",
"get_edge": """
SELECT relation_id AS id, workspace, source_name, target_name, weight, keywords, description, source_chunk_id AS source_id, content, content_vector
FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND target_name = :target_name AND workspace = :workspace
""",
"get_node_edges": """
SELECT relation_id AS id, workspace, source_name, target_name, weight, keywords, description, source_chunk_id, content, content_vector
FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND workspace = :workspace
""",
"node_degree": """
SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace AND :name IN (source_name, target_name)
""",
"upsert_node": """
INSERT INTO LIGHTRAG_GRAPH_NODES(name, content, content_vector, workspace, source_chunk_id, entity_type, description)
VALUES(:name, :content, :content_vector, :workspace, :source_chunk_id, :entity_type, :description)
ON DUPLICATE KEY UPDATE
name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector),
workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP,
source_chunk_id = VALUES(source_chunk_id), entity_type = VALUES(entity_type), description = VALUES(description)
""",
"upsert_edge": """
INSERT INTO LIGHTRAG_GRAPH_EDGES(source_name, target_name, content, content_vector,
workspace, weight, keywords, description, source_chunk_id)
VALUES(:source_name, :target_name, :content, :content_vector,
:workspace, :weight, :keywords, :description, :source_chunk_id)
ON DUPLICATE KEY UPDATE
source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content),
content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
""",
content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP,
weight = VALUES(weight), keywords = VALUES(keywords), description = VALUES(description),
source_chunk_id = VALUES(source_chunk_id)
""",
}
2 changes: 2 additions & 0 deletions lightrag/lightrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def import_class(*args, **kwargs):
ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
TiDBGraphStorage = lazy_external_import(".kg.tidb_impl", "TiDBGraphStorage")
AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage")


Expand Down Expand Up @@ -282,6 +283,7 @@ def _get_storage_class(self) -> Type[BaseGraphStorage]:
"Neo4JStorage": Neo4JStorage,
"OracleGraphStorage": OracleGraphStorage,
"AGEStorage": AGEStorage,
"TiDBGraphStorage": TiDBGraphStorage,
# "ArangoDBStorage": ArangoDBStorage
}

Expand Down

0 comments on commit 344d8f2

Please sign in to comment.