diff --git a/examples/lightrag_tidb_demo.py b/examples/lightrag_tidb_demo.py
new file mode 100644
index 00000000..fd73a354
--- /dev/null
+++ b/examples/lightrag_tidb_demo.py
@@ -0,0 +1,127 @@
+import asyncio
+import os
+
+import numpy as np
+
+from lightrag import LightRAG, QueryParam
+from lightrag.kg.tidb_impl import TiDB
+from lightrag.llm import siliconcloud_embedding, openai_complete_if_cache
+from lightrag.utils import EmbeddingFunc
+
+WORKING_DIR = "./dickens"
+
+# We use the SiliconCloud API to call the LLM
+# More docs here: https://docs.siliconflow.cn/introduction
+BASE_URL = "https://api.siliconflow.cn/v1/"
+APIKEY = ""
+CHATMODEL = ""
+EMBEDMODEL = ""
+
+TIDB_HOST = ""
+TIDB_PORT = ""
+TIDB_USER = ""
+TIDB_PASSWORD = ""
+TIDB_DATABASE = ""
+
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+
+async def llm_model_func(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> str:
+    return await openai_complete_if_cache(
+        CHATMODEL,
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        api_key=APIKEY,
+        base_url=BASE_URL,
+        **kwargs,
+    )
+
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+    return await siliconcloud_embedding(
+        texts,
+        # model=EMBEDMODEL,
+        api_key=APIKEY,
+    )
+
+
+async def get_embedding_dim():
+    test_text = ["This is a test sentence."]
+    embedding = await embedding_func(test_text)
+    embedding_dim = embedding.shape[1]
+    return embedding_dim
+
+
+async def main():
+    try:
+        # Detect embedding dimension
+        embedding_dimension = await get_embedding_dim()
+        print(f"Detected embedding dimension: {embedding_dimension}")
+
+        # Create the TiDB connection
+        tidb = TiDB(
+            config={
+                "host": TIDB_HOST,
+                "port": TIDB_PORT,
+                "user": TIDB_USER,
+                "password": TIDB_PASSWORD,
+                "database": TIDB_DATABASE,
+                "workspace": "company",  # specify which docs you want to store and query
+            }
+        )
+
+        # Check whether the TiDB tables exist; any missing tables will be created
+        await tidb.check_tables()
+
+        # Initialize LightRAG
+        # We use TiDB as both the KV and the vector storage
+        # You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
+        rag = LightRAG(
+            enable_llm_cache=False,
+            working_dir=WORKING_DIR,
+            chunk_token_size=512,
+            llm_model_func=llm_model_func,
+            embedding_func=EmbeddingFunc(
+                embedding_dim=embedding_dimension,
+                max_token_size=512,
+                func=embedding_func,
+            ),
+            kv_storage="TiDBKVStorage",
+            vector_storage="TiDBVectorDBStorage",
+        )
+
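+        # Hand the shared TiDB connection to every TiDB-backed storage instance;
+        # the demo wires this up manually after LightRAG is constructed.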
+        if rag.llm_response_cache:
+            rag.llm_response_cache.db = tidb
+        rag.full_docs.db = tidb
+        rag.text_chunks.db = tidb
+        rag.entities_vdb.db = tidb
+        rag.relationships_vdb.db = tidb
+        rag.chunks_vdb.db = tidb
+
+        # Extract and insert into LightRAG storage
+        with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
+            await rag.ainsert(f.read())
+
+        # Perform search in different modes
+        modes = ["naive", "local", "global", "hybrid"]
+        for mode in modes:
+            print("=" * 20, mode, "=" * 20)
+            print(
+                await rag.aquery(
+                    "What are the top themes in this story?",
+                    param=QueryParam(mode=mode),
+                )
+            )
+            print("-" * 100, "\n")
+
+    except Exception as e:
+        print(f"An error occurred: {e}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py
new file mode 100644
index 00000000..66b49fe3
--- /dev/null
+++ b/lightrag/kg/tidb_impl.py
@@ -0,0 +1,454 @@
+import asyncio
+import os
+from dataclasses import dataclass
+from typing import Union
+
+import numpy as np
+from sqlalchemy import create_engine, text
+from tqdm import tqdm
+
+from lightrag.base import BaseVectorStorage, BaseKVStorage
+from lightrag.utils import logger
+
+
+class TiDB(object):
+    def __init__(self, config, **kwargs):
+        self.host = config.get("host", None)
+        self.port = config.get("port", None)
+        self.user = config.get("user", None)
+        self.password = config.get("password", None)
+        self.database = config.get("database", None)
+        self.workspace = config.get("workspace", None)
+        connection_string = (
+            f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
+            f"?ssl_verify_cert=true&ssl_verify_identity=true"
+        )
+
+        try:
+            self.engine = create_engine(connection_string)
+            logger.info(f"Connected to TiDB database at {self.database}")
+        except Exception as e:
+            logger.error(f"Failed to connect to TiDB database at {self.database}")
+            logger.error(f"TiDB database error: {e}")
+            raise
+
+    async def check_tables(self):
+        """Probe each required table and create any that are missing."""
+        for k, v in TABLES.items():
+            try:
+                await self.query(f"SELECT 1 FROM {k}")
+            except Exception as e:
+                logger.error(f"Failed to check table {k} in TiDB database")
+                logger.error(f"TiDB database error: {e}")
+                try:
+                    await self.execute(v["ddl"])
+                    logger.info(f"Created table {k} in TiDB database")
+                except Exception as e:
+                    logger.error(f"Failed to create table {k} in TiDB database")
+                    logger.error(f"TiDB database error: {e}")
+
+    async def query(
+        self, sql: str, params: dict = None, multirows: bool = False
+    ) -> Union[dict, None]:
+        if params is None:
+            params = {"workspace": self.workspace}
+        else:
+            params.update({"workspace": self.workspace})
+        with self.engine.connect() as conn, conn.begin():
+            try:
+                result = conn.execute(text(sql), params)
+            except Exception as e:
+                logger.error(f"TiDB database error: {e}")
+                print(sql)
+                print(params)
+                raise
+            if multirows:
+                rows = result.all()
+                if rows:
+                    data = [dict(zip(result.keys(), row)) for row in rows]
+                else:
+                    data = []
+            else:
+                row = result.first()
+                if row:
+                    data = dict(zip(result.keys(), row))
+                else:
+                    data = None
+            return data
+
+    async def execute(self, sql: str, data: list | dict = None):
+        try:
+            with self.engine.connect() as conn, conn.begin():
+                if data is None:
+                    conn.execute(text(sql))
+                else:
+                    conn.execute(text(sql), parameters=data)
+        except Exception as e:
+            logger.error(f"TiDB database error: {e}")
+            print(sql)
+            print(data)
+            raise
+
+
+@dataclass
+class TiDBKVStorage(BaseKVStorage):
+    # the caller is expected to assign a TiDB instance to self.db (see the demo)
+    def __post_init__(self):
+        self._data = {}
+        self._max_batch_size = self.global_config["embedding_batch_num"]
+
+    ################ QUERY METHODS ################
+
+    async def get_by_id(self, id: str) -> Union[dict, None]:
+        """Fetch a single record by id for the current namespace."""
+        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
+        params = {"id": id}
+        res = await self.db.query(SQL, params)
+        return res if res else None
+
+    # Query by ids
+    async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
+        """Fetch multiple records by ids for the current namespace."""
+        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
+            ids=",".join([f"'{id}'" for id in ids])
+        )
+        res = await self.db.query(SQL, multirows=True)
+        return res if res else None
+
+    async def filter_keys(self, keys: list[str]) -> set[str]:
+        """Return the subset of keys that are not stored yet."""
+        SQL = SQL_TEMPLATES["filter_keys"].format(
+            table_name=N_T[self.namespace],
+            id_field=N_ID[self.namespace],
+            ids=",".join([f"'{id}'" for id in keys]),
+        )
+        res = await self.db.query(SQL, multirows=True)
+        exist_keys = [key["id"] for key in res] if res else []
+        return set([s for s in keys if s not in exist_keys])
+
+    ################ INSERT full_doc AND chunks ################
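+    # For "text_chunks", contents are embedded in batches of `embedding_batch_num`
+    # and upserted together with their vectors; "full_docs" rows are written as-is.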
+    async def upsert(self, data: dict[str, dict]):
+        left_data = {k: v for k, v in data.items() if k not in self._data}
+        self._data.update(left_data)
+        if self.namespace == "text_chunks":
+            list_data = [
+                {
+                    "__id__": k,
+                    **{k1: v1 for k1, v1 in v.items()},
+                }
+                for k, v in data.items()
+            ]
+            contents = [v["content"] for v in data.values()]
+            batches = [
+                contents[i : i + self._max_batch_size]
+                for i in range(0, len(contents), self._max_batch_size)
+            ]
+            embeddings_list = await asyncio.gather(
+                *[self.embedding_func(batch) for batch in batches]
+            )
+            embeddings = np.concatenate(embeddings_list)
+            for i, d in enumerate(list_data):
+                d["__vector__"] = embeddings[i]
+
+            merge_sql = SQL_TEMPLATES["upsert_chunk"]
+            data = []
+            for item in list_data:
+                data.append(
+                    {
+                        "id": item["__id__"],
+                        "content": item["content"],
+                        "tokens": item["tokens"],
+                        "chunk_order_index": item["chunk_order_index"],
+                        "full_doc_id": item["full_doc_id"],
+                        "content_vector": f"{item['__vector__'].tolist()}",
+                        "workspace": self.db.workspace,
+                    }
+                )
+            await self.db.execute(merge_sql, data)
+
+        if self.namespace == "full_docs":
+            merge_sql = SQL_TEMPLATES["upsert_doc_full"]
+            data = []
+            for k, v in self._data.items():
+                data.append(
+                    {
+                        "id": k,
+                        "content": v["content"],
+                        "workspace": self.db.workspace,
+                    }
+                )
+            await self.db.execute(merge_sql, data)
+        return left_data
+
+    async def index_done_callback(self):
+        if self.namespace in ["full_docs", "text_chunks"]:
+            logger.info("Full doc and chunk data have been saved into TiDB!")
+
+
+@dataclass
+class TiDBVectorDBStorage(BaseVectorStorage):
+    cosine_better_than_threshold: float = 0.2
+
+    def __post_init__(self):
+        self._client_file_name = os.path.join(
+            self.global_config["working_dir"], f"vdb_{self.namespace}.json"
+        )
+        self._max_batch_size = self.global_config["embedding_batch_num"]
+        self.cosine_better_than_threshold = self.global_config.get(
+            "cosine_better_than_threshold", self.cosine_better_than_threshold
+        )
+
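+    # NOTE: TiDB's VEC_COSINE_DISTANCE returns a distance (smaller = more similar);
+    # the search templates in SQL_TEMPLATES convert it to a similarity so that only
+    # rows scoring above `cosine_better_than_threshold` are returned.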
+    async def query(self, query: str, top_k: int) -> list[dict]:
+        """Search the TiDB vector store for the top_k entries closest to the query."""
+        embeddings = await self.embedding_func([query])
+        embedding = embeddings[0]
+
+        embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
+
+        params = {
+            "embedding_string": embedding_string,
+            "top_k": top_k,
+            "better_than_threshold": self.cosine_better_than_threshold,
+        }
+
+        results = await self.db.query(
+            SQL_TEMPLATES[self.namespace], params=params, multirows=True
+        )
+        if not results:
+            return []
+        return results
+
+    ###### INSERT entities AND relationships ######
+    async def upsert(self, data: dict[str, dict]):
+        if not len(data):
+            logger.warning("You are inserting empty data into the vector DB")
+            return []
+        if self.namespace == "chunks":
+            # chunks are already upserted with their vectors in TiDBKVStorage
+            return []
+        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
+
+        list_data = [
+            {
+                "id": k,
+                **{k1: v1 for k1, v1 in v.items()},
+            }
+            for k, v in data.items()
+        ]
+        contents = [v["content"] for v in data.values()]
+        batches = [
+            contents[i : i + self._max_batch_size]
+            for i in range(0, len(contents), self._max_batch_size)
+        ]
+        embedding_tasks = [self.embedding_func(batch) for batch in batches]
+        embeddings_list = []
+        for f in tqdm(
+            asyncio.as_completed(embedding_tasks),
+            total=len(embedding_tasks),
+            desc="Generating embeddings",
+            unit="batch",
+        ):
+            embeddings = await f
+            embeddings_list.append(embeddings)
+        embeddings = np.concatenate(embeddings_list)
+        for i, d in enumerate(list_data):
+            d["content_vector"] = embeddings[i]
+
+        if self.namespace == "entities":
+            merge_sql = SQL_TEMPLATES["upsert_entity"]
+            data = []
+            for item in list_data:
+                data.append(
+                    {
+                        "id": item["id"],
+                        "name": item["entity_name"],
+                        "content": item["content"],
+                        "content_vector": f"{item['content_vector'].tolist()}",
+                        "workspace": self.db.workspace,
+                    }
+                )
+            await self.db.execute(merge_sql, data)
+
+        elif self.namespace == "relationships":
+            merge_sql = SQL_TEMPLATES["upsert_relationship"]
+            data = []
+            for item in list_data:
+                data.append(
+                    {
+                        "id": item["id"],
+                        "source_name": item["src_id"],
+                        "target_name": item["tgt_id"],
+                        "content": item["content"],
+                        "content_vector": f"{item['content_vector'].tolist()}",
+                        "workspace": self.db.workspace,
+                    }
+                )
+            await self.db.execute(merge_sql, data)
+
+
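+# Namespace -> table name / id column, used by the generic "filter_keys" template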
+N_T = {
+    "full_docs": "LIGHTRAG_DOC_FULL",
+    "text_chunks": "LIGHTRAG_DOC_CHUNKS",
+    "chunks": "LIGHTRAG_DOC_CHUNKS",
+    "entities": "LIGHTRAG_GRAPH_NODES",
+    "relationships": "LIGHTRAG_GRAPH_EDGES",
+}
+N_ID = {
+    "full_docs": "doc_id",
+    "text_chunks": "chunk_id",
+    "chunks": "chunk_id",
+    "entities": "entity_id",
+    "relationships": "relation_id",
+}
+
+TABLES = {
+    "LIGHTRAG_DOC_FULL": {
+        "ddl": """
+        CREATE TABLE LIGHTRAG_DOC_FULL (
+            `id` BIGINT PRIMARY KEY AUTO_RANDOM,
+            `doc_id` VARCHAR(256) NOT NULL,
+            `workspace` VARCHAR(1024),
+            `content` LONGTEXT,
+            `meta` JSON,
+            `createtime` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+            `updatetime` TIMESTAMP DEFAULT NULL,
+            UNIQUE KEY (`doc_id`)
+        );
+        """
+    },
+    "LIGHTRAG_DOC_CHUNKS": {
+        "ddl": """
+        CREATE TABLE LIGHTRAG_DOC_CHUNKS (
+            `id` BIGINT PRIMARY KEY AUTO_RANDOM,
+            `chunk_id` VARCHAR(256) NOT NULL,
+            `full_doc_id` VARCHAR(256) NOT NULL,
+            `workspace` VARCHAR(1024),
+            `chunk_order_index` INT,
+            `tokens` INT,
+            `content` LONGTEXT,
+            `content_vector` VECTOR,
+            `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
+            `updatetime` DATETIME DEFAULT NULL,
+            UNIQUE KEY (`chunk_id`)
+        );
+        """
+    },
+    "LIGHTRAG_GRAPH_NODES": {
+        "ddl": """
+        CREATE TABLE LIGHTRAG_GRAPH_NODES (
+            `id` BIGINT PRIMARY KEY AUTO_RANDOM,
+            `entity_id` VARCHAR(256) NOT NULL,
+            `workspace` VARCHAR(1024),
+            `name` VARCHAR(2048),
+            `content` LONGTEXT,
+            `content_vector` VECTOR,
+            `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
+            `updatetime` DATETIME DEFAULT NULL,
+            UNIQUE KEY (`entity_id`)
+        );
+        """
+    },
+    "LIGHTRAG_GRAPH_EDGES": {
+        "ddl": """
+        CREATE TABLE LIGHTRAG_GRAPH_EDGES (
+            `id` BIGINT PRIMARY KEY AUTO_RANDOM,
+            `relation_id` VARCHAR(256) NOT NULL,
+            `workspace` VARCHAR(1024),
+            `source_name` VARCHAR(2048),
+            `target_name` VARCHAR(2048),
+            `content` LONGTEXT,
+            `content_vector` VECTOR,
+            `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
+            `updatetime` DATETIME DEFAULT NULL,
+            UNIQUE KEY (`relation_id`)
+        );
+        """
+    },
+    "LIGHTRAG_LLM_CACHE": {
+        "ddl": """
+        CREATE TABLE LIGHTRAG_LLM_CACHE (
+            `id` BIGINT PRIMARY KEY AUTO_INCREMENT,
+            `send` TEXT,
+            `return` TEXT,
+            `model` VARCHAR(1024),
+            `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
+            `updatetime` DATETIME DEFAULT NULL
+        );
+        """
+    },
+}
+
+
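+# All statements bind named parameters (:name) through SQLAlchemy's text();
+# upserts rely on INSERT ... ON DUPLICATE KEY UPDATE against each table's UNIQUE KEY.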
""", +} diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 36576368..5a337a08 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -77,6 +77,8 @@ def import_class(*args, **kwargs): MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge") MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage") ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage") +TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage") +TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage") def always_get_an_event_loop() -> asyncio.AbstractEventLoop: @@ -260,11 +262,13 @@ def _get_storage_class(self) -> Type[BaseGraphStorage]: "JsonKVStorage": JsonKVStorage, "OracleKVStorage": OracleKVStorage, "MongoKVStorage": MongoKVStorage, + "TiDBKVStorage": TiDBKVStorage, # vector storage "NanoVectorDBStorage": NanoVectorDBStorage, "OracleVectorDBStorage": OracleVectorDBStorage, "MilvusVectorDBStorge": MilvusVectorDBStorge, "ChromaVectorDBStorage": ChromaVectorDBStorage, + "TiDBVectorDBStorage": TiDBVectorDBStorage, # graph storage "NetworkXStorage": NetworkXStorage, "Neo4JStorage": Neo4JStorage, diff --git a/requirements.txt b/requirements.txt index ad96fe7d..3cc48028 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,9 +13,12 @@ openai oracledb pymilvus pymongo +pymysql pyvis -tenacity # lmdeploy[all] +sqlalchemy +tenacity + # LLM packages tiktoken