support TiDB: add TiDBKVStorage, TiDBVectorDBStorage #452
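This PR adds TiDB as a LightRAG storage backend: TiDBKVStorage for key-value data and TiDBVectorDBStorage for vector data. The demo script below shows how to wire both into LightRAG.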

Merged (2 commits) on Dec 12, 2024
examples/lightrag_tidb_demo.py (new file, +127 lines)
import asyncio
import os

import numpy as np

from lightrag import LightRAG, QueryParam
from lightrag.kg.tidb_impl import TiDB
from lightrag.llm import siliconcloud_embedding, openai_complete_if_cache
from lightrag.utils import EmbeddingFunc

WORKING_DIR = "./dickens"

# We use the SiliconCloud API for the LLM and embeddings, and TiDB for storage
# More docs here https://docs.siliconflow.cn/introduction
BASE_URL = "https://api.siliconflow.cn/v1/"
APIKEY = ""
CHATMODEL = ""
EMBEDMODEL = ""
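# (Hypothetical examples, not part of this PR: SiliconCloud model ids look
# like CHATMODEL = "Qwen/Qwen2.5-7B-Instruct" and EMBEDMODEL = "BAAI/bge-m3".)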

TIDB_HOST = ""
TIDB_PORT = ""
TIDB_USER = ""
TIDB_PASSWORD = ""
TIDB_DATABASE = ""
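# (Note: TiDB listens on port 4000 by default; for TiDB Cloud, use the
# host/port/user from your cluster's connection details.)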


if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    return await openai_complete_if_cache(
        CHATMODEL,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=APIKEY,
        base_url=BASE_URL,
        **kwargs,
    )


async def embedding_func(texts: list[str]) -> np.ndarray:
    return await siliconcloud_embedding(
        texts,
        model=EMBEDMODEL,
        api_key=APIKEY,
    )


async def get_embedding_dim():
    test_text = ["This is a test sentence."]
    embedding = await embedding_func(test_text)
    embedding_dim = embedding.shape[1]
    return embedding_dim
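# (The embedding API is probed once up front because the vector table is
# created with a fixed-size vector column, which must match the model's output.)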


async def main():
    try:
        # Detect embedding dimension
        embedding_dimension = await get_embedding_dim()
        print(f"Detected embedding dimension: {embedding_dimension}")

        # Create the TiDB connection
        tidb = TiDB(
            config={
                "host": TIDB_HOST,
                "port": TIDB_PORT,
                "user": TIDB_USER,
                "password": TIDB_PASSWORD,
                "database": TIDB_DATABASE,
                "workspace": "company",  # specify which docs you want to store and query
            }
        )

        # Check whether the TiDB tables exist; if not, they will be created
        await tidb.check_tables()

        # Initialize LightRAG
        # We use TiDB as both the KV store and the vector store
        # You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
        rag = LightRAG(
            enable_llm_cache=False,
            working_dir=WORKING_DIR,
            chunk_token_size=512,
            llm_model_func=llm_model_func,
            embedding_func=EmbeddingFunc(
                embedding_dim=embedding_dimension,
                max_token_size=512,
                func=embedding_func,
            ),
            kv_storage="TiDBKVStorage",
            vector_storage="TiDBVectorDBStorage",
        )

        # The TiDB storages resolve their connection through a `.db` attribute,
        # so point each storage instance at the shared TiDB client created above
        if rag.llm_response_cache:
            rag.llm_response_cache.db = tidb
        rag.full_docs.db = tidb
        rag.text_chunks.db = tidb
        rag.entities_vdb.db = tidb
        rag.relationships_vdb.db = tidb
        rag.chunks_vdb.db = tidb

        # Extract and insert into LightRAG storage
        with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
            await rag.ainsert(f.read())

        # Perform search in different modes
        modes = ["naive", "local", "global", "hybrid"]
        for mode in modes:
            print("=" * 20, mode, "=" * 20)
            print(
                await rag.aquery(
                    "What are the top themes in this story?",
                    param=QueryParam(mode=mode),
                )
            )
            print("-" * 100, "\n")

    except Exception as e:
        print(f"An error occurred: {e}")


if __name__ == "__main__":
    asyncio.run(main())
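Usage note (not part of the diff): rather than hardcoding credentials, the placeholder constants at the top of the script can be read from the environment. A minimal sketch, assuming you export correspondingly named variables; the variable names here are illustrative assumptions, not LightRAG conventions:

import os

# Hypothetical environment variable names; adjust to your own setup
APIKEY = os.environ.get("SILICONFLOW_API_KEY", "")
CHATMODEL = os.environ.get("SILICONFLOW_CHAT_MODEL", "")
EMBEDMODEL = os.environ.get("SILICONFLOW_EMBED_MODEL", "")
TIDB_HOST = os.environ.get("TIDB_HOST", "localhost")
TIDB_PORT = os.environ.get("TIDB_PORT", "4000")  # TiDB's default port
TIDB_USER = os.environ.get("TIDB_USER", "root")
TIDB_PASSWORD = os.environ.get("TIDB_PASSWORD", "")
TIDB_DATABASE = os.environ.get("TIDB_DATABASE", "test")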