Skip to content

Commit

Permalink
Logic Optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
jin38324 committed Nov 25, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent bf5815b commit 21f1613
Showing 8 changed files with 185 additions and 136 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -13,4 +13,4 @@ ignore_this.txt
*.ignore.*
.ruff_cache/
gui/
*.log
*.log
109 changes: 58 additions & 51 deletions examples/lightrag_api_oracle_demo..py
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@

from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi import Query
from contextlib import asynccontextmanager
from pydantic import BaseModel
from typing import Optional,Any
from fastapi.responses import JSONResponse
from typing import Optional, Any

import sys
import os


import sys, os
print(os.getcwd())
from pathlib import Path
script_directory = Path(__file__).resolve().parent.parent
sys.path.append(os.path.abspath(script_directory))

import asyncio
import nest_asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np
from datetime import datetime

from lightrag.kg.oracle_impl import OracleDB

print(os.getcwd())
script_directory = Path(__file__).resolve().parent.parent
sys.path.append(os.path.abspath(script_directory))


# Apply nest_asyncio to solve event loop issues
@@ -47,7 +47,8 @@

if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)



async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -77,47 +78,47 @@ async def get_embedding_dim():
embedding_dim = embedding.shape[1]
return embedding_dim


async def init():

# Detect embedding dimension
embedding_dimension = 1024 #await get_embedding_dim()
embedding_dimension = 1024 # await get_embedding_dim()
print(f"Detected embedding dimension: {embedding_dimension}")
# Create Oracle DB connection
# The `config` parameter is the connection configuration of Oracle DB
# More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
# We store data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
# Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud

oracle_db = OracleDB(
config={
"user": "",
"password": "",
"dsn": "",
"config_dir": "path_to_config_dir",
"wallet_location": "path_to_wallet_location",
"wallet_password": "wallet_password",
"workspace": "company",
} # specify which docs you want to store and query
)

oracle_db = OracleDB(config={
"user":"",
"password":"",
"dsn":"",
"config_dir":"path_to_config_dir",
"wallet_location":"path_to_wallet_location",
"wallet_password":"wallet_password",
"workspace":"company"
} # specify which docs you want to store and query
)

# Check if Oracle DB tables exist, if not, tables will be created
await oracle_db.check_tables()
# Initialize LightRAG
# We use Oracle DB as the KV/vector/graph storage
# We use Oracle DB as the KV/vector/graph storage
rag = LightRAG(
enable_llm_cache=False,
working_dir=WORKING_DIR,
chunk_token_size=512,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=512,
func=embedding_func,
),
graph_storage = "OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage"
)
enable_llm_cache=False,
working_dir=WORKING_DIR,
chunk_token_size=512,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=512,
func=embedding_func,
),
graph_storage="OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage",
)

# Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
rag.graph_storage_cls.db = oracle_db
@@ -128,7 +129,7 @@ async def init():


# Extract and Insert into LightRAG storage
#with open("./dickens/book.txt", "r", encoding="utf-8") as f:
# with open("./dickens/book.txt", "r", encoding="utf-8") as f:
# await rag.ainsert(f.read())

# # Perform search in different modes
@@ -147,9 +148,11 @@ class QueryRequest(BaseModel):
only_need_context: bool = False
only_need_prompt: bool = False


class DataRequest(BaseModel):
limit: int = 100


class InsertRequest(BaseModel):
text: str

@@ -164,6 +167,7 @@ class Response(BaseModel):

rag = None


@asynccontextmanager
async def lifespan(app: FastAPI):
global rag
@@ -172,25 +176,28 @@ async def lifespan(app: FastAPI):
yield


app = FastAPI(title="LightRAG API", description="API for RAG operations",lifespan=lifespan)
app = FastAPI(
title="LightRAG API", description="API for RAG operations", lifespan=lifespan
)


@app.post("/query", response_model=Response)
async def query_endpoint(request: QueryRequest):
#try:
# loop = asyncio.get_event_loop()
# try:
# loop = asyncio.get_event_loop()
if request.mode == "naive":
top_k = 3
else:
top_k = 60
result = await rag.aquery(
request.query,
param=QueryParam(
mode=request.mode,
only_need_context=request.only_need_context,
only_need_prompt=request.only_need_prompt,
top_k=top_k
),
)
request.query,
param=QueryParam(
mode=request.mode,
only_need_context=request.only_need_context,
only_need_prompt=request.only_need_prompt,
top_k=top_k,
),
)
return Response(status="success", data=result)
# except Exception as e:
# raise HTTPException(status_code=500, detail=str(e))
@@ -199,9 +206,9 @@ async def query_endpoint(request: QueryRequest):
@app.get("/data", response_model=Response)
async def query_all_nodes(type: str = Query("nodes"), limit: int = Query(100)):
if type == "nodes":
result = await rag.chunk_entity_relation_graph.get_all_nodes(limit = limit)
result = await rag.chunk_entity_relation_graph.get_all_nodes(limit=limit)
elif type == "edges":
result = await rag.chunk_entity_relation_graph.get_all_edges(limit = limit)
result = await rag.chunk_entity_relation_graph.get_all_edges(limit=limit)
elif type == "statistics":
result = await rag.chunk_entity_relation_graph.get_statistics()
return Response(status="success", data=result)
@@ -264,4 +271,4 @@ async def health_check():
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'

# 4. Health check:
# curl -X GET "http://127.0.0.1:8020/health"
# curl -X GET "http://127.0.0.1:8020/health"
3 changes: 1 addition & 2 deletions examples/lightrag_oracle_demo.py
Original file line number Diff line number Diff line change
@@ -97,8 +97,7 @@ async def main():
graph_storage="OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage",

addon_params = {"example_number":1, "language":"Simplfied Chinese"},
addon_params={"example_number": 1, "language": "Simplfied Chinese"},
)

# Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
8 changes: 5 additions & 3 deletions lightrag/kg/oracle_impl.py
Original file line number Diff line number Diff line change
@@ -114,7 +114,9 @@ async def check_tables(self):

logger.info("Finished check all tables in Oracle database")

async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
async def query(
self, sql: str, params: dict = None, multirows: bool = False
) -> Union[dict, None]:
async with self.pool.acquire() as connection:
connection.inputtypehandler = self.input_type_handler
connection.outputtypehandler = self.output_type_handler
@@ -256,7 +258,7 @@ async def upsert(self, data: dict[str, dict]):
item["__vector__"],
]
# print(merge_sql)
await self.db.execute(merge_sql, data)
await self.db.execute(merge_sql, values)

if self.namespace == "full_docs":
for k, v in self._data.items():
@@ -266,7 +268,7 @@ async def upsert(self, data: dict[str, dict]):
)
values = [k, self._data[k]["content"], self.db.workspace]
# print(merge_sql)
await self.db.execute(merge_sql, data)
await self.db.execute(merge_sql, values)
return left_data

async def index_done_callback(self):
8 changes: 4 additions & 4 deletions lightrag/llm.py
Original file line number Diff line number Diff line change
@@ -70,8 +70,8 @@ async def openai_complete_if_cache(
model=model, messages=messages, **kwargs
)
content = response.choices[0].message.content
if r'\u' in content:
content = content.encode('utf-8').decode('unicode_escape')
if r"\u" in content:
content = content.encode("utf-8").decode("unicode_escape")
print(content)
if hashing_kv is not None:
await hashing_kv.upsert(
@@ -542,7 +542,7 @@ async def openai_embedding(
texts: list[str],
model: str = "text-embedding-3-small",
base_url: str = None,
api_key: str = None
api_key: str = None,
) -> np.ndarray:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
@@ -551,7 +551,7 @@ async def openai_embedding(
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
model=model, input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])

Loading

0 comments on commit 21f1613

Please sign in to comment.