Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docker-compose.dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ services:
- "8443:8443"
volumes:
- ./nilai-api/:/app/nilai-api/
- ./packages/:/app/packages/
redis:
ports:
- "6379:6379"
Expand Down
4 changes: 1 addition & 3 deletions nilai-api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@ dependencies = [
"fastapi[standard]>=0.115.5",
"gunicorn>=23.0.0",
"nilai-common",
"numpy<2.0.0",
"python-dotenv>=1.0.1",
"sqlalchemy>=2.0.36",
"uvicorn>=0.32.1",
"httpx>=0.27.2",
"nilrag>=0.1.10",
"nilql>=0.0.0a12",
"nilrag>=0.1.11",
"openai>=1.59.9",
"pg8000>=1.31.2",
"prometheus_fastapi_instrumentator>=7.0.2",
Expand Down
103 changes: 15 additions & 88 deletions nilai-api/src/nilai_api/handlers/nilrag.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
import logging
import numpy as np

import nilql
import nilrag

from nilai_common import ChatRequest, Message
from fastapi import HTTPException, status
from nilrag.util import (
decrypt_float_list,
encrypt_float_list,
group_shares_by_id,
)
from sentence_transformers import SentenceTransformer
from typing import Union, Any, Dict, List
from typing import Union

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -40,17 +33,15 @@ def generate_embeddings_huggingface(
async def handle_nilrag(req: ChatRequest):
"""
Endpoint to process a client query.
1. Initialization: Secret share keys and NilDB instance.
2. Secret share query and send to NilDB.
3. Ask NilDB to compute the differences.
4. Compute distances and sort.
5. Ask NilDB to return top k chunks.
6. Append top results to LLM query
1. Get inputs from request.
2. Execute nilRAG using nilrag library.
3. & 4. Format and append top results to LLM query
"""
try:
logger.debug("Rag is starting.")
# Step 1: Initialization
# Get NilDB instance from request

# Step 1: Get inputs
# Get nilDB instances
if not req.nilrag or "nodes" not in req.nilrag:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
Expand All @@ -70,18 +61,8 @@ async def handle_nilrag(req: ChatRequest):
)
nilDB = nilrag.NilDB(nodes)

# Initialize secret keys
num_parties = len(nilDB.nodes)
additive_key = nilql.ClusterKey.generate(
{"nodes": [{}] * num_parties}, {"sum": True}
)
xor_key = nilql.ClusterKey.generate(
{"nodes": [{}] * num_parties}, {"store": True}
)

# Step 2: Secret share query
logger.debug("Secret sharing query and sending to NilDB...")
# 2.1 Extract the user query
# Get user query
logger.debug("Extracting user query")
query = None
for message in req.messages:
if message.role == "user":
Expand All @@ -91,73 +72,19 @@ async def handle_nilrag(req: ChatRequest):
if query is None:
raise HTTPException(status_code=400, detail="No user query found")

# 2.2 Generate query embeddings: one string query is assumed.
query_embedding = generate_embeddings_huggingface([query])[0]
nilql_query_embedding = encrypt_float_list(additive_key, query_embedding)

# Step 3: Ask NilDB to compute the differences
logger.debug("Requesting computation from NilDB...")
difference_shares: List[List[Dict[str, Any]]] = await nilDB.diff_query_execute(
nilql_query_embedding
)

# Step 4: Compute distances and sort
logger.debug("Compute distances and sort...")
# 4.1 Group difference shares by ID
difference_shares_by_id = group_shares_by_id(
difference_shares, # type: ignore
lambda share: share["difference"],
)
# 4.2 Transpose the lists for each _id
difference_shares_by_id = {
id: list(map(list, zip(*differences)))
for id, differences in difference_shares_by_id.items()
}
# 4.3 Decrypt and compute distances
reconstructed = [
{
"_id": id,
"distances": np.linalg.norm(
decrypt_float_list(additive_key, difference_shares)
),
}
for id, difference_shares in difference_shares_by_id.items()
]
# 4.4 Sort id list based on the corresponding distances
sorted_ids = sorted(reconstructed, key=lambda x: x["distances"])

# Step 5: Query the top k
logger.debug("Query top k chunks...")
top_k = req.nilrag.get("num_chunks", 2)
if not isinstance(top_k, int):
raise HTTPException(
status_code=400,
detail="num_chunks must be an integer as it represents the number of chunks to be retrieved.",
)
top_k_ids = [item["_id"] for item in sorted_ids[:top_k]]

# 5.1 Query top k
chunk_shares = await nilDB.chunk_query_execute(top_k_ids)

# 5.2 Group chunk shares by ID
chunk_shares_by_id = group_shares_by_id(
chunk_shares, # type: ignore
lambda share: share["chunk"],
)
# Get number of chunks to include
num_chunks = req.nilrag.get("num_chunks", 2)

# 5.3 Decrypt chunks
top_results = [
{"_id": id, "distances": nilql.decrypt(xor_key, chunk_shares)}
for id, chunk_shares in chunk_shares_by_id.items()
]
# Step 2: Execute nilRAG
top_results = await nilDB.top_num_chunks_execute(query, num_chunks)

# Step 6: Format top results
# Step 3: Format top results
formatted_results = "\n".join(
f"- {str(result['distances'])}" for result in top_results
)
relevant_context = f"\n\nRelevant Context:\n{formatted_results}"

# Step 7: Update system message
# Step 4: Update system message
for message in req.messages:
if message.role == "system":
if message.content is None:
Expand Down
12 changes: 4 additions & 8 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.