From 469fa9f57479f6dcd82dad4adbb077f93c9e4ac0 Mon Sep 17 00:00:00 2001
From: Saifeddine ALOUI
Date: Sun, 22 Dec 2024 00:38:38 +0100
Subject: [PATCH] Added lollms integration with lightrag

Removed a deprecated function from the ollama server
---
 api/README_LOLLMS.md          | 177 +++++++++++++++
 api/lollms_lightrag_server.py | 401 ++++++++++++++++++++++++++++++++++
 api/ollama_lightrag_server.py |   4 +-
 lightrag/llm.py               | 111 ++++++++++
 4 files changed, 691 insertions(+), 2 deletions(-)
 create mode 100644 api/README_LOLLMS.md
 create mode 100644 api/lollms_lightrag_server.py

diff --git a/api/README_LOLLMS.md b/api/README_LOLLMS.md
new file mode 100644
index 00000000..d56ac909
--- /dev/null
+++ b/api/README_LOLLMS.md
@@ -0,0 +1,177 @@
+# LightRAG API Server
+
+A powerful FastAPI-based server for managing and querying documents using LightRAG (Light Retrieval-Augmented Generation). This server provides a REST API interface for document management and intelligent querying using various LLM models through LoLLMS.
+
+## Features
+
+- 🔍 Multiple search modes (naive, local, global, hybrid)
+- 📡 Streaming and non-streaming responses
+- 📝 Document management (insert, batch upload, clear)
+- ⚙️ Highly configurable model parameters
+- 📚 Support for text and file uploads
+- 🔧 RESTful API with automatic documentation
+- 🚀 Built with FastAPI for high performance
+
+## Prerequisites
+
+- Python 3.8+
+- LoLLMS server running locally or remotely
+- Required Python packages:
+  - fastapi
+  - uvicorn
+  - lightrag
+  - pydantic
+  - aiofiles
+  - ascii_colors
+
+## Installation
+
+If you are using Windows, you will need to download and install the Visual C++ Build Tools from [https://visualstudio.microsoft.com/visual-cpp-build-tools/](https://visualstudio.microsoft.com/visual-cpp-build-tools/).
+Make sure you install the VS 2022 C++ x64/x86 Build Tools from the Individual Components tab:
+![image](https://github.com/user-attachments/assets/3723e15b-0a2c-42ed-aebf-e595a9f9c946)
+
+This is mandatory for building some modules.
+
+1. Clone the repository:
+```bash
+git clone https://github.com/ParisNeo/LightRAG.git
+cd LightRAG/api
+```
+
+2. Install dependencies:
+```bash
+pip install -r requirements.txt
+```
+
+3. Make sure LoLLMS is running and accessible.
+
+## Configuration
+
+The server can be configured using command-line arguments:
+
+```bash
+python lollms_lightrag_server.py --help
+```
+
+Available options:
+
+| Parameter | Default | Description |
+|-----------|---------|-------------|
+| --host | 0.0.0.0 | Server host |
+| --port | 9621 | Server port |
+| --model | mistral-nemo:latest | LLM model name |
+| --embedding-model | bge-m3:latest | Embedding model name |
+| --lollms-host | http://localhost:9600 | LoLLMS host URL |
+| --working-dir | ./rag_storage | Working directory for RAG |
+| --max-async | 4 | Maximum async operations |
+| --max-tokens | 32768 | Maximum token size |
+| --embedding-dim | 1024 | Embedding dimensions |
+| --max-embed-tokens | 8192 | Maximum embedding token size |
+| --input-dir | ./inputs | Directory containing input documents |
+| --log-level | INFO | Logging level |
+
+## Quick Start
+
+1. Basic usage with default settings:
+```bash
+python lollms_lightrag_server.py
+```
+
+2. Custom configuration:
+```bash
+python lollms_lightrag_server.py --model llama2:13b --port 8080 --working-dir ./custom_rag
+```
+
+Make sure the required models are installed on your LoLLMS instance:
+```bash
+python lollms_lightrag_server.py --model mistral-nemo:latest --embedding-model bge-m3 --embedding-dim 1024
+```
+
+## API Endpoints
+
+### Query Endpoints
+
+#### POST /query
+Query the RAG system with options for different search modes.
+
+```bash
+curl -X POST "http://localhost:9621/query" \
+  -H "Content-Type: application/json" \
+  -d '{"query": "Your question here", "mode": "hybrid"}'
+```
+
+#### POST /query/stream
+Stream responses from the RAG system.
+
+```bash
+curl -X POST "http://localhost:9621/query/stream" \
+  -H "Content-Type: application/json" \
+  -d '{"query": "Your question here", "mode": "hybrid"}'
+```
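+#### Calling the query endpoints from Python
+
+The same endpoints can be called from Python. The snippet below is only an illustrative sketch: it assumes the server is running on the default port and that the `requests` package is installed (it is not one of the server's listed dependencies).
+
+```python
+import requests
+
+BASE_URL = "http://localhost:9621"  # adjust if you changed --host/--port
+
+# Non-streaming query: the server replies with a JSON body containing "response"
+reply = requests.post(
+    f"{BASE_URL}/query",
+    json={"query": "Your question here", "mode": "hybrid"},
+)
+reply.raise_for_status()
+print(reply.json()["response"])
+
+# Streaming query: read the body chunk by chunk as it arrives
+with requests.post(
+    f"{BASE_URL}/query/stream",
+    json={"query": "Your question here", "mode": "hybrid"},
+    stream=True,
+) as stream:
+    stream.raise_for_status()
+    for chunk in stream.iter_content(chunk_size=None, decode_unicode=True):
+        print(chunk, end="", flush=True)
+```
+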
+### Document Management Endpoints
+
+#### POST /documents/text
+Insert text directly into the RAG system.
+
+```bash
+curl -X POST "http://localhost:9621/documents/text" \
+  -H "Content-Type: application/json" \
+  -d '{"text": "Your text content here", "description": "Optional description"}'
+```
+
+#### POST /documents/file
+Upload a single file to the RAG system.
+
+```bash
+curl -X POST "http://localhost:9621/documents/file" \
+  -F "file=@/path/to/your/document.txt" \
+  -F "description=Optional description"
+```
+
+#### POST /documents/batch
+Upload multiple files at once.
+
+```bash
+curl -X POST "http://localhost:9621/documents/batch" \
+  -F "files=@/path/to/doc1.txt" \
+  -F "files=@/path/to/doc2.txt"
+```
+
+#### DELETE /documents
+Clear all documents from the RAG system.
+
+```bash
+curl -X DELETE "http://localhost:9621/documents"
+```
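+#### Managing documents from Python
+
+A minimal sketch of the same document-management calls from Python, again assuming the default server address and the `requests` package:
+
+```python
+import requests
+
+BASE_URL = "http://localhost:9621"
+
+# Insert raw text (mirrors the /documents/text curl example above)
+resp = requests.post(
+    f"{BASE_URL}/documents/text",
+    json={"text": "Your text content here", "description": "Optional description"},
+)
+resp.raise_for_status()
+print(resp.json())
+
+# Upload a single .txt or .md file as multipart/form-data
+with open("document.txt", "rb") as f:
+    resp = requests.post(
+        f"{BASE_URL}/documents/file",
+        files={"file": ("document.txt", f, "text/plain")},
+        data={"description": "Optional description"},
+    )
+resp.raise_for_status()
+print(resp.json())
+```
+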
+### Utility Endpoints
+
+#### GET /health
+Check server health and configuration.
+
+```bash
+curl "http://localhost:9621/health"
+```
+
+## Development
+
+### Running in Development Mode
+
+```bash
+uvicorn lollms_lightrag_server:app --reload --port 9621
+```
+
+### API Documentation
+
+When the server is running, visit:
+- Swagger UI: http://localhost:9621/docs
+- ReDoc: http://localhost:9621/redoc
+
+## License
+
+This project is licensed under the MIT License - see the LICENSE file for details.
+
+## Acknowledgments
+
+- Built with [FastAPI](https://fastapi.tiangolo.com/)
+- Uses [LightRAG](https://github.com/HKUDS/LightRAG) for document processing
+- Powered by [LoLLMS](https://lollms.ai/) for LLM inference
diff --git a/api/lollms_lightrag_server.py b/api/lollms_lightrag_server.py
new file mode 100644
index 00000000..4babcaa8
--- /dev/null
+++ b/api/lollms_lightrag_server.py
@@ -0,0 +1,401 @@
+from fastapi import FastAPI, HTTPException, File, UploadFile, Form
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel
+import logging
+import argparse
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import lollms_model_complete, lollms_embed
+from lightrag.utils import EmbeddingFunc
+from typing import Optional, List
+from enum import Enum
+from pathlib import Path
+import shutil
+import aiofiles
+from ascii_colors import trace_exception
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="LightRAG FastAPI Server with separate working and input directories"
+    )
+
+    # Server configuration
+    parser.add_argument(
+        "--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)"
+    )
+    parser.add_argument(
+        "--port", type=int, default=9621, help="Server port (default: 9621)"
+    )
+
+    # Directory configuration
+    parser.add_argument(
+        "--working-dir",
+        default="./rag_storage",
+        help="Working directory for RAG storage (default: ./rag_storage)",
+    )
+    parser.add_argument(
+        "--input-dir",
+        default="./inputs",
+        help="Directory containing input documents (default: ./inputs)",
+    )
+
+    # Model configuration
+    parser.add_argument(
+        "--model",
+        default="mistral-nemo:latest",
+        help="LLM model name (default: mistral-nemo:latest)",
+    )
+    parser.add_argument(
+        "--embedding-model",
+        default="bge-m3:latest",
+        help="Embedding model name (default: bge-m3:latest)",
+    )
+    parser.add_argument(
+        "--lollms-host",
+        default="http://localhost:9600",
+        help="LoLLMS host URL (default: http://localhost:9600)",
+    )
+
+    # RAG configuration
+    parser.add_argument(
+        "--max-async", type=int, default=4, help="Maximum async operations (default: 4)"
+    )
+    parser.add_argument(
+        "--max-tokens",
+        type=int,
+        default=32768,
+        help="Maximum token size (default: 32768)",
+    )
+    parser.add_argument(
+        "--embedding-dim",
+        type=int,
+        default=1024,
+        help="Embedding dimensions (default: 1024)",
+    )
+    parser.add_argument(
+        "--max-embed-tokens",
+        type=int,
+        default=8192,
+        help="Maximum embedding token size (default: 8192)",
+    )
+
+    # Logging configuration
+    parser.add_argument(
+        "--log-level",
+        default="INFO",
+        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
+        help="Logging level (default: INFO)",
+    )
+
+    return parser.parse_args()
+
+
+class DocumentManager:
+    """Handles document operations and tracking"""
+
+    def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")):
+        self.input_dir = Path(input_dir)
+        self.supported_extensions = supported_extensions
+        self.indexed_files = set()
+
+        # Create input directory if it doesn't exist
+        self.input_dir.mkdir(parents=True, exist_ok=True)
+
+    def scan_directory(self) -> List[Path]:
+        """Scan input directory for new files"""
+        new_files = []
+        for ext in self.supported_extensions:
+            for file_path in self.input_dir.rglob(f"*{ext}"):
+                if file_path not in self.indexed_files:
+                    new_files.append(file_path)
+        return new_files
+
+    def mark_as_indexed(self, file_path: Path):
+        """Mark a file as indexed"""
+        self.indexed_files.add(file_path)
+
+    def is_supported_file(self, filename: str) -> bool:
+        """Check if file type is supported"""
+        return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
supported""" + return any(filename.lower().endswith(ext) for ext in self.supported_extensions) + + +# Pydantic models +class SearchMode(str, Enum): + naive = "naive" + local = "local" + global_ = "global" + hybrid = "hybrid" + + +class QueryRequest(BaseModel): + query: str + mode: SearchMode = SearchMode.hybrid + stream: bool = False + + +class QueryResponse(BaseModel): + response: str + + +class InsertTextRequest(BaseModel): + text: str + description: Optional[str] = None + + +class InsertResponse(BaseModel): + status: str + message: str + document_count: int + + +def create_app(args): + # Setup logging + logging.basicConfig( + format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) + ) + + # Initialize FastAPI app + app = FastAPI( + title="LightRAG API", + description="API for querying text using LightRAG with separate storage and input directories", + ) + + # Create working directory if it doesn't exist + Path(args.working_dir).mkdir(parents=True, exist_ok=True) + + # Initialize document manager + doc_manager = DocumentManager(args.input_dir) + + # Initialize RAG + rag = LightRAG( + working_dir=args.working_dir, + llm_model_func=lollms_model_complete, + llm_model_name=args.model, + llm_model_max_async=args.max_async, + llm_model_max_token_size=args.max_tokens, + llm_model_kwargs={ + "host": args.lollms_host, + "options": {"num_ctx": args.max_tokens}, + }, + embedding_func=EmbeddingFunc( + embedding_dim=args.embedding_dim, + max_token_size=args.max_embed_tokens, + func=lambda texts: lollms_embed( + texts, embed_model=args.embedding_model, host=args.lollms_host + ), + ), + ) + + @app.on_event("startup") + async def startup_event(): + """Index all files in input directory during startup""" + try: + new_files = doc_manager.scan_directory() + for file_path in new_files: + try: + # Use async file reading + async with aiofiles.open(file_path, "r", encoding="utf-8") as f: + content = await f.read() + # Use the async version of insert directly + await rag.ainsert(content) + doc_manager.mark_as_indexed(file_path) + logging.info(f"Indexed file: {file_path}") + except Exception as e: + trace_exception(e) + logging.error(f"Error indexing file {file_path}: {str(e)}") + + logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}") + + except Exception as e: + logging.error(f"Error during startup indexing: {str(e)}") + + @app.post("/documents/scan") + async def scan_for_new_documents(): + """Manually trigger scanning for new documents""" + try: + new_files = doc_manager.scan_directory() + indexed_count = 0 + + for file_path in new_files: + try: + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + rag.insert(content) + doc_manager.mark_as_indexed(file_path) + indexed_count += 1 + except Exception as e: + logging.error(f"Error indexing file {file_path}: {str(e)}") + + return { + "status": "success", + "indexed_count": indexed_count, + "total_documents": len(doc_manager.indexed_files), + } + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.post("/documents/upload") + async def upload_to_input_dir(file: UploadFile = File(...)): + """Upload a file to the input directory""" + try: + if not doc_manager.is_supported_file(file.filename): + raise HTTPException( + status_code=400, + detail=f"Unsupported file type. 
+
+    @app.post("/query", response_model=QueryResponse)
+    async def query_text(request: QueryRequest):
+        try:
+            response = await rag.aquery(
+                request.query,
+                param=QueryParam(mode=request.mode, stream=request.stream),
+            )
+
+            if request.stream:
+                result = ""
+                async for chunk in response:
+                    result += chunk
+                return QueryResponse(response=result)
+            else:
+                return QueryResponse(response=response)
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
+    @app.post("/query/stream")
+    async def query_text_stream(request: QueryRequest):
+        try:
+            # aquery with stream=True returns an async iterator of text chunks
+            response = await rag.aquery(
+                request.query, param=QueryParam(mode=request.mode, stream=True)
+            )
+
+            async def stream_generator():
+                async for chunk in response:
+                    yield chunk
+
+            # Wrap the generator so FastAPI actually streams it to the client
+            return StreamingResponse(stream_generator(), media_type="text/plain")
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
+    @app.post("/documents/text", response_model=InsertResponse)
+    async def insert_text(request: InsertTextRequest):
+        try:
+            await rag.ainsert(request.text)
+            return InsertResponse(
+                status="success",
+                message="Text successfully inserted",
+                document_count=len(rag),
+            )
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
+    @app.post("/documents/file", response_model=InsertResponse)
+    async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
+        try:
+            content = await file.read()
+
+            if file.filename.endswith((".txt", ".md")):
+                text = content.decode("utf-8")
+                await rag.ainsert(text)
+            else:
+                raise HTTPException(
+                    status_code=400,
+                    detail="Unsupported file type. Only .txt and .md files are supported",
+                )
+
+            return InsertResponse(
+                status="success",
+                message=f"File '{file.filename}' successfully inserted",
+                document_count=len(rag),
+            )
+        except UnicodeDecodeError:
+            raise HTTPException(status_code=400, detail="File encoding not supported")
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
+    @app.post("/documents/batch", response_model=InsertResponse)
+    async def insert_batch(files: List[UploadFile] = File(...)):
+        try:
+            inserted_count = 0
+            failed_files = []
+
+            for file in files:
+                try:
+                    content = await file.read()
+                    if file.filename.endswith((".txt", ".md")):
+                        text = content.decode("utf-8")
+                        await rag.ainsert(text)
+                        inserted_count += 1
+                    else:
+                        failed_files.append(f"{file.filename} (unsupported type)")
+                except Exception as e:
+                    failed_files.append(f"{file.filename} ({str(e)})")
+
+            status_message = f"Successfully inserted {inserted_count} documents"
+            if failed_files:
+                status_message += f". Failed files: {', '.join(failed_files)}"
+
+            return InsertResponse(
+                status="success" if inserted_count > 0 else "partial_success",
+                message=status_message,
+                document_count=len(rag),
+            )
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
+    @app.delete("/documents", response_model=InsertResponse)
+    async def clear_documents():
+        try:
+            rag.text_chunks = []
+            rag.entities_vdb = None
+            rag.relationships_vdb = None
+            return InsertResponse(
+                status="success",
+                message="All documents cleared successfully",
+                document_count=0,
+            )
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
+    @app.get("/health")
+    async def get_status():
+        """Get current system status"""
+        return {
+            "status": "healthy",
+            "working_directory": str(args.working_dir),
+            "input_directory": str(args.input_dir),
+            "indexed_files": len(doc_manager.indexed_files),
+            "configuration": {
+                "model": args.model,
+                "embedding_model": args.embedding_model,
+                "max_tokens": args.max_tokens,
+                "lollms_host": args.lollms_host,
+            },
+        }
+
+    return app
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    import uvicorn
+
+    app = create_app(args)
+    uvicorn.run(app, host=args.host, port=args.port)
diff --git a/api/ollama_lightrag_server.py b/api/ollama_lightrag_server.py
index 850e814f..055532c8 100644
--- a/api/ollama_lightrag_server.py
+++ b/api/ollama_lightrag_server.py
@@ -3,7 +3,7 @@
 import logging
 import argparse
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import ollama_model_complete, ollama_embedding
+from lightrag.llm import ollama_model_complete, ollama_embed
 from lightrag.utils import EmbeddingFunc
 from typing import Optional, List
 from enum import Enum
@@ -179,7 +179,7 @@ def create_app(args):
         embedding_func=EmbeddingFunc(
             embedding_dim=args.embedding_dim,
             max_token_size=args.max_embed_tokens,
-            func=lambda texts: ollama_embedding(
+            func=lambda texts: ollama_embed(
                 texts, embed_model=args.embedding_model, host=args.ollama_host
             ),
         ),
diff --git a/lightrag/llm.py b/lightrag/llm.py
index e89af0d8..190039f1 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -339,6 +339,62 @@ async def inner():
         return response["message"]["content"]
 
 
+async def lollms_model_if_cache(
+    model,
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    base_url="http://localhost:9600",
+    **kwargs,
+) -> Union[str, AsyncIterator[str]]:
+    """Client implementation for lollms generation."""
+
+    stream = True if kwargs.get("stream") else False
+
+    # Extract lollms specific parameters
+    request_data = {
+        "prompt": prompt,
+        "model_name": model,
+        "personality": kwargs.get("personality", -1),
+        "n_predict": kwargs.get("n_predict", None),
+        "stream": stream,
+        "temperature": kwargs.get("temperature", 0.1),
+        "top_k": kwargs.get("top_k", 50),
+        "top_p": kwargs.get("top_p", 0.95),
+        "repeat_penalty": kwargs.get("repeat_penalty", 0.8),
+        "repeat_last_n": kwargs.get("repeat_last_n", 40),
+        "seed": kwargs.get("seed", None),
+        "n_threads": kwargs.get("n_threads", 8),
+    }
+
+    # Prepare the full prompt including history
+    full_prompt = ""
+    if system_prompt:
+        full_prompt += f"{system_prompt}\n"
+    for msg in history_messages:
+        full_prompt += f"{msg['role']}: {msg['content']}\n"
+    full_prompt += prompt
+
+    request_data["prompt"] = full_prompt
+
+    async with aiohttp.ClientSession() as session:
+        if stream:
+
+            async def inner():
+                async with session.post(
+                    f"{base_url}/lollms_generate", json=request_data
+                ) as response:
+                    async for line in response.content:
+                        yield line.decode().strip()
+
+            return inner()
+        else:
+            async with session.post(
+                f"{base_url}/lollms_generate", json=request_data
+            ) as response:
+                return await response.text()
+
+
 @lru_cache(maxsize=1)
 def initialize_lmdeploy_pipeline(
     model,
@@ -597,6 +653,32 @@ async def ollama_model_complete(
     )
 
 
+async def lollms_model_complete(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> Union[str, AsyncIterator[str]]:
+    """Complete function for lollms model generation."""
+
+    # Get model name from config
+    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+
+    # If keyword extraction is needed, we might need to modify the prompt
+    # or add specific parameters for JSON output (if lollms supports it)
+    if keyword_extraction:
+        # Note: You might need to adjust this based on how lollms handles structured output
+        pass
+
+    return await lollms_model_if_cache(
+        model_name,
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+
+
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -1026,6 +1108,35 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
     return data["embeddings"]
 
 
+async def lollms_embed(
+    texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
+) -> np.ndarray:
+    """
+    Generate embeddings for a list of texts using lollms server.
+
+    Args:
+        texts: List of strings to embed
+        embed_model: Model name (not used directly as lollms uses its configured vectorizer)
+        base_url: URL of the lollms server
+        **kwargs: Additional arguments passed to the request
+
+    Returns:
+        np.ndarray: Array of embeddings
+    """
+    async with aiohttp.ClientSession() as session:
+        embeddings = []
+        for text in texts:
+            request_data = {"text": text}
+
+            async with session.post(
+                f"{base_url}/lollms_embed", json=request_data
+            ) as response:
+                result = await response.json()
+                embeddings.append(result["vector"])
+
+        return np.array(embeddings)
+
+
 class Model(BaseModel):
     """
     This is a Pydantic model class named 'Model' that is used to define a custom language model.