Skip to content

Commit

Permalink
fix pre commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jin38324 committed Nov 12, 2024
1 parent 9079074 commit 33caba3
Show file tree
Hide file tree
Showing 7 changed files with 346 additions and 311 deletions.
79 changes: 41 additions & 38 deletions examples/lightrag_api_oracle_demo..py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@

from fastapi import FastAPI, HTTPException, File, UploadFile
from contextlib import asynccontextmanager
from pydantic import BaseModel
from typing import Optional

import sys, os
import sys
import os
from pathlib import Path

import asyncio
Expand All @@ -13,7 +13,6 @@
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np
from datetime import datetime

from lightrag.kg.oracle_impl import OracleDB

Expand All @@ -24,8 +23,6 @@
sys.path.append(os.path.abspath(script_directory))




# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()

Expand All @@ -51,6 +48,7 @@
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)


async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
Expand Down Expand Up @@ -80,8 +78,8 @@ async def get_embedding_dim():
embedding_dim = embedding.shape[1]
return embedding_dim


async def init():

# Detect embedding dimension
embedding_dimension = await get_embedding_dim()
print(f"Detected embedding dimension: {embedding_dimension}")
Expand All @@ -91,36 +89,36 @@ async def init():
# We store data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
# Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud

oracle_db = OracleDB(
config={
"user": "",
"password": "",
"dsn": "",
"config_dir": "",
"wallet_location": "",
"wallet_password": "",
"workspace": "",
} # specify which docs you want to store and query
)

oracle_db = OracleDB(config={
"user":"",
"password":"",
"dsn":"",
"config_dir":"",
"wallet_location":"",
"wallet_password":"",
"workspace":""
} # specify which docs you want to store and query
)

# Check if Oracle DB tables exist, if not, tables will be created
await oracle_db.check_tables()
# Initialize LightRAG
# We use Oracle DB as the KV/vector/graph storage
# We use Oracle DB as the KV/vector/graph storage
rag = LightRAG(
enable_llm_cache=False,
working_dir=WORKING_DIR,
chunk_token_size=512,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=512,
func=embedding_func,
),
graph_storage = "OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage"
)
enable_llm_cache=False,
working_dir=WORKING_DIR,
chunk_token_size=512,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=512,
func=embedding_func,
),
graph_storage="OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage",
)

# Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
rag.graph_storage_cls.db = oracle_db
Expand All @@ -129,6 +127,7 @@ async def init():

return rag


# Data models


Expand All @@ -152,6 +151,7 @@ class Response(BaseModel):

rag = None # defined as a global object


@asynccontextmanager
async def lifespan(app: FastAPI):
global rag
Expand All @@ -160,18 +160,21 @@ async def lifespan(app: FastAPI):
yield


app = FastAPI(title="LightRAG API", description="API for RAG operations",lifespan=lifespan)
app = FastAPI(
title="LightRAG API", description="API for RAG operations", lifespan=lifespan
)


@app.post("/query", response_model=Response)
async def query_endpoint(request: QueryRequest):
try:
# loop = asyncio.get_event_loop()
result = await rag.aquery(
request.query,
param=QueryParam(
mode=request.mode, only_need_context=request.only_need_context
),
)
request.query,
param=QueryParam(
mode=request.mode, only_need_context=request.only_need_context
),
)
return Response(status="success", data=result)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
Expand Down Expand Up @@ -234,4 +237,4 @@ async def health_check():
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'

# 4. Health check:
# curl -X GET "http://127.0.0.1:8020/health"
# curl -X GET "http://127.0.0.1:8020/health"
49 changes: 27 additions & 22 deletions examples/lightrag_oracle_demo.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import sys, os
import sys
import os
from pathlib import Path
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np
from datetime import datetime
from lightrag.kg.oracle_impl import OracleDB

print(os.getcwd())
Expand All @@ -25,6 +25,7 @@
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)


async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
Expand Down Expand Up @@ -66,22 +67,21 @@ async def main():
# More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
# We store data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
# Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
oracle_db = OracleDB(config={
"user":"username",
"password":"xxxxxxxxx",
"dsn":"xxxxxxx_medium",
"config_dir":"dir/path/to/oracle/config",
"wallet_location":"dir/path/to/oracle/wallet",
"wallet_password":"xxxxxxxxx",
"workspace":"company" # specify which docs you want to store and query
oracle_db = OracleDB(
config={
"user": "username",
"password": "xxxxxxxxx",
"dsn": "xxxxxxx_medium",
"config_dir": "dir/path/to/oracle/config",
"wallet_location": "dir/path/to/oracle/wallet",
"wallet_password": "xxxxxxxxx",
"workspace": "company", # specify which docs you want to store and query
}
)

)

# Check if Oracle DB tables exist, if not, tables will be created
await oracle_db.check_tables()


# Initialize LightRAG
# We use Oracle DB as the KV/vector/graph storage
rag = LightRAG(
Expand All @@ -93,10 +93,10 @@ async def main():
embedding_dim=embedding_dimension,
max_token_size=512,
func=embedding_func,
),
graph_storage = "OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage"
),
graph_storage="OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage",
)

# Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
Expand All @@ -106,18 +106,23 @@ async def main():

# Extract and Insert into LightRAG storage
with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
await rag.ainsert(f.read())
await rag.ainsert(f.read())

# Perform search in different modes
modes = ["naive", "local", "global", "hybrid"]
for mode in modes:
print("="*20, mode, "="*20)
print(await rag.aquery("What are the top themes in this story?", param=QueryParam(mode=mode)))
print("-"*100, "\n")
print("=" * 20, mode, "=" * 20)
print(
await rag.aquery(
"What are the top themes in this story?",
param=QueryParam(mode=mode),
)
)
print("-" * 100, "\n")

except Exception as e:
print(f"An error occurred: {e}")


if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())
2 changes: 2 additions & 0 deletions lightrag/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ async def upsert(self, data: dict[str, dict]):
@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
embedding_func: EmbeddingFunc

async def all_keys(self) -> list[str]:
raise NotImplementedError

Expand All @@ -85,6 +86,7 @@ async def drop(self):
@dataclass
class BaseGraphStorage(StorageNameSpace):
embedding_func: EmbeddingFunc = None

async def has_node(self, node_id: str) -> bool:
raise NotImplementedError

Expand Down
Loading

0 comments on commit 33caba3

Please sign in to comment.