Skip to content

Commit

Permalink
Fix concurrent request handling in LightRAG API
Browse files Browse the repository at this point in the history
  • Loading branch information
Latta committed Dec 20, 2024
1 parent e5dc186 commit a43d14d
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 0 deletions.
63 changes: 63 additions & 0 deletions api/openai_lightrag_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,53 @@ def create_app(args):
# Get embedding dimensions — presumably probes the configured embedding model
# once and returns its output vector size (TODO confirm against
# get_embedding_dim).  asyncio.run blocks here until the probe completes.
embedding_dim = asyncio.run(get_embedding_dim(args.embedding_model))

async def async_openai_complete(
    prompt, system_prompt=None, history_messages=None, **kwargs
):
    """Async wrapper for OpenAI completion.

    Forwards the prompt (plus optional system prompt and chat history) to
    ``openai_complete_if_cache`` using the model configured in ``args``.

    Args:
        prompt: The user prompt to complete.
        system_prompt: Optional system-level instruction string.
        history_messages: Optional list of prior chat messages.  Defaults to
            an empty list created per call.
        **kwargs: Extra keyword options passed through unchanged.

    Returns:
        The result of ``openai_complete_if_cache``.
    """
    # Fix: the original used ``history_messages=[]`` — a mutable default that
    # is shared across every call.  Use None and materialize a fresh list.
    return await openai_complete_if_cache(
        args.model,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages if history_messages is not None else [],
        **kwargs,
    )

# Setup logging FIRST so any records emitted while LightRAG initializes
# already use the configured format and level (previously basicConfig ran
# after the LightRAG constructor, so init-time logs used default settings).
logging.basicConfig(
    format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
)

# Initialize RAG with OpenAI configuration
global rag
rag = LightRAG(
    working_dir=args.working_dir,
    llm_model_func=async_openai_complete,
    llm_model_name=args.model,
    llm_model_max_token_size=args.max_tokens,
    embedding_func=EmbeddingFunc(
        embedding_dim=embedding_dim,
        max_token_size=args.max_embed_tokens,
        # Bind the configured embedding model name into the embedding callable.
        func=lambda texts: openai_embedding(texts, model=args.embedding_model),
    ),
    graph_storage="Neo4JStorage",  # knowledge graph persisted in Neo4j
    log_level=args.log_level,
)

# Make sure the working directory exists before anything tries to touch it.
Path(args.working_dir).mkdir(parents=True, exist_ok=True)

# FastAPI application that serves the LightRAG endpoints.
app = FastAPI(
    title="LightRAG API",
    description="API for querying text using LightRAG with OpenAI integration",
)

# Tracks/loads input documents from the configured input directory.
doc_manager = DocumentManager(args.input_dir)

# Probe the embedding model once up front to learn its output dimensionality.
embedding_dim = asyncio.run(get_embedding_dim(args.embedding_model))

async def async_openai_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
):
Expand Down Expand Up @@ -273,6 +320,22 @@ async def query_text(request: QueryRequest):
param=QueryParam(mode=request.mode, stream=request.stream),
)

if request.stream:
result = ""
async for chunk in response:
result += chunk
return QueryResponse(response=result)
else:
return QueryResponse(response=response)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
async def query_text(request: QueryRequest):
try:
response = await rag.aquery(
request.query,
param=QueryParam(mode=request.mode, stream=request.stream),
)

if request.stream:
result = ""
async for chunk in response:
Expand Down
2 changes: 2 additions & 0 deletions lightrag/lightrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def _get_storage_class(self) -> Type[BaseGraphStorage]:
}

def insert(self, string_or_strings):
    """Synchronously insert one document or a batch of documents.

    Thin blocking wrapper around :meth:`ainsert`.

    Args:
        string_or_strings: A single document string, or a collection of them.

    Returns:
        Whatever :meth:`ainsert` returns.

    Raises:
        RuntimeError: if called from a thread that already has a running
            event loop (``asyncio.run`` cannot nest) — use ``ainsert``
            directly from async code instead.
    """
    # ``asyncio.run`` creates a fresh event loop per call and tears it down,
    # avoiding reuse of a stale/closed loop across requests.  The old
    # ``always_get_an_event_loop()``/``run_until_complete`` lines that
    # followed the return were unreachable dead code and are removed.
    return asyncio.run(self.ainsert(string_or_strings))

Expand Down Expand Up @@ -511,6 +512,7 @@ async def ainsert_custom_kg(self, custom_kg: dict):
await self._insert_done()

def query(self, query: str, param: "QueryParam" = None):
    """Synchronously run a RAG query and return its answer.

    Blocking wrapper around :meth:`aquery`.

    Args:
        query: The natural-language query string.
        param: Query options.  When omitted, a fresh ``QueryParam()`` is
            created per call — the old ``param=QueryParam()`` default was
            evaluated once at definition time, so a single shared instance
            could leak mutated state between calls.

    Returns:
        Whatever :meth:`aquery` returns.

    Raises:
        RuntimeError: if called while an event loop is already running in
            this thread (``asyncio.run`` cannot nest) — use ``aquery``
            directly from async code instead.
    """
    if param is None:
        param = QueryParam()
    # Fresh event loop per call; the old ``always_get_an_event_loop()``/
    # ``run_until_complete`` lines after the return were unreachable dead
    # code and are removed.
    return asyncio.run(self.aquery(query, param))

Expand Down

0 comments on commit a43d14d

Please sign in to comment.