Merge pull request #2 from fastenhealth/llamaindex
Add retrieval evaluation pipeline
dgbaenar authored Aug 30, 2024
2 parents 2a0e559 + fe22b9e commit e4b518b
Showing 111 changed files with 709,835 additions and 318 deletions.
4 changes: 0 additions & 4 deletions .flake8

This file was deleted.

2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -8,7 +8,7 @@ jobs:
       - name: Install Python
        uses: actions/setup-python@v5
        with:
-         python-version: "3.10"
+         python-version: "3.9"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
13 changes: 13 additions & 0 deletions .gitignore
@@ -10,6 +10,7 @@ __pycache__/
 .Python
 env/
 venv/
+venv-dev/
 ENV/
 env.bak/
 venv.bak/
@@ -115,3 +116,15 @@ dmypy.json
 .idea/
 *.sublime-project
 *.sublime-workspace
+scripts/rag_evaluation/data/flat_files/
+
+# data
+*/data
+scripts/rag_evaluation/evaluate_generation/generation_speed/data/*.csv
+scripts/rag_evaluation/evaluate_retrieval/data/output/*csv
+
+*.jsonl
+/models
+/data/
+!/evaluation/data/
+evaluation/evaluation_dataset
4 changes: 2 additions & 2 deletions .ruff.toml
@@ -3,5 +3,5 @@
 line-length = 130
 indent-width = 4
 
-# Assume Python 3.10
-target-version = "py310"
+# Assume Python 3.9
+target-version = "py39"
30 changes: 29 additions & 1 deletion app/config/elasticsearch_config.py
@@ -1,6 +1,6 @@
 from elasticsearch import Elasticsearch
 
-from .settings import settings
+from .settings import settings, logger
 
 
 def get_es_client():
@@ -9,3 +9,31 @@ def get_es_client():
         basic_auth=(settings.es_user, settings.es_password),
         max_retries=10,
     )
+
+
+def get_mapping():
+    return {
+        "mappings": {
+            "properties": {
+                "content": {
+                    "type": "text"
+                },
+                "embedding": {
+                    "type": "dense_vector",
+                    "dims": 384
+                },
+                "metadata": {
+                    "type": "object"
+                }
+            }
+        }
+    }
+
+
+def create_index_if_not_exists(index_name):
+    es_client = get_es_client()
+    if not es_client.indices.exists(index=index_name):
+        es_client.indices.create(index=index_name,
+                                 body=get_mapping())
+        logger.info(f"Index '{index_name}' created.")
+    return es_client
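
The "dims": 384 in the mapping matches the output size of the default all-MiniLM-L6-v2 embedding model configured in settings.py, so the two values have to change together. A minimal usage sketch, assuming the package layout shown in this PR (the exists check makes the call safe to repeat on every startup):

    # Sketch: idempotent startup call; "fasten-index" mirrors the INDEX_NAME default.
    from app.config.elasticsearch_config import create_index_if_not_exists

    es_client = create_index_if_not_exists("fasten-index")
    print(es_client.indices.get_mapping(index="fasten-index"))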
28 changes: 22 additions & 6 deletions app/config/settings.py
@@ -8,14 +8,30 @@ class Settings:
     es_password: str = os.getenv("ES_PASSWORD", "changeme")
     embedding_model_name: str = os.getenv("EMBEDDING_MODEL_NAME",
                                           "all-MiniLM-L6-v2")
-    llama_host: str = os.getenv("LLAMA_HOST", "http://localhost:8080")
-    llama_prompt: str = os.getenv("LLAMA_PROMPT",
-                                  "A chat between a curious user and an intelligent, \
-        polite medical assistant. The assistant provides detailed,\
-        helpful answers to the user's medical questions,\
-        including accurate references where applicable.")
+    host: str = os.getenv("LLAMA_HOST", "http://localhost:8090")
+    system_prompt: str = os.getenv("LLAMA_PROMPT",
+                                   ("You are an intelligent, polite medical assistant embedded "
+                                    "within a Retrieval-Augmented Generation (RAG) system. "
+                                    "Your responses are based strictly on information retrieved "
+                                    "from a database, specifically FHIR data chunks. These chunks may not always be clear. "
+                                    "If you do not find relevant information, acknowledge it and do not attempt to fabricate answers. "
+                                    "Provide detailed, helpful, and accurate responses, and include references where applicable. "
+                                    "If information is not available, politely inform the user that you cannot provide an answer."))
     index_name: str = os.getenv("INDEX_NAME", "fasten-index")
     upload_dir: str = os.getenv("UPLOAD_DIR", "./data/")
+    model_prompt: str = ("<|system|>"
+                         "{system_prompt}<|end|>"
+                         "<|user|>"
+                         "Context information is below.\n "
+                         "---------------------\n "
+                         "{context}\n"
+                         "---------------------\n "
+                         "Given the context information (if there is any), "
+                         "this is my message: "
+                         "{message}"
+                         "<|end|>"
+                         "<|assistant|>"
+                         )
 
 
 settings = Settings()
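
The new model_prompt wraps the exchange in Phi-3-style chat markers (<|system|>, <|user|>, <|assistant|>, <|end|>), which suggests the llama server is expected to run a model trained on that format. A small sketch of how the template would be filled, with made-up context and message values:

    # Sketch: render the chat template; the context and message are illustrative.
    from app.config.settings import settings

    prompt = settings.model_prompt.format(
        system_prompt=settings.system_prompt,
        context="Observation: blood pressure 120/80 mmHg, recorded 2024-05-01.",
        message="What was my last blood pressure reading?",
    )
    print(prompt)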
65 changes: 19 additions & 46 deletions app/db/index_documents.py
@@ -1,48 +1,21 @@
-import os
-
-import fitz
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-
-from ..config.settings import settings, logger
-
-
-text_splitter = RecursiveCharacterTextSplitter(
-    separators=["\n\n", "\n", " ", ""],
-    chunk_size=300,
-    chunk_overlap=50,
-    length_function=len,
-)
-
-
-def extract_text_from_pdf(pdf_path):
-    logger.info(f"Extracting text from PDF: {pdf_path}")
-    doc = fitz.open(pdf_path)
-    text = ""
-    for page in doc:
-        text += page.get_text()
-    return text
-
-
-def index_pdf(pdf_path, embedding_model,
-              es_client, index_name=settings.index_name):
-    logger.info(f"Indexing PDF: {pdf_path}")
-    text = extract_text_from_pdf(pdf_path)
-    chunks = text_splitter.create_documents(
-        texts=[text],
-        metadatas=[{"source": os.path.basename(pdf_path)}]
-    )
-    for chunk in chunks:
-        content = chunk.page_content
-        metadata = chunk.metadata
-        embedding = embedding_model.encode(content).tolist()
-        es_client.index(
-            index=index_name,
-            body={
-                "content": content,
+def bulk_load_from_json_flattened_file(data: dict,
+                                       embedding_model,
+                                       index_name):
+    """
+    Function to load in bulk mode a FHIR JSON flattened file
+    """
+    data = data["entry"]
+    for value in data:
+        resource_id = value.get("resource_id")
+        text_chunk = value.get("resource")
+        embedding = embedding_model.encode(text_chunk)
+        yield {
+            "_index": index_name,
+            "_source": {
+                "content": text_chunk,
                 "embedding": embedding,
-                "metadata": metadata
+                "metadata": {
+                    "resource_id": resource_id
+                }
             }
-        )
-        logger.info(f"Chunk indexed: {content[:30]}...")
-    es_client.indices.refresh(index=index_name)
-    logger.info(f"PDF {os.path.basename(pdf_path)} indexed successfully.")
+        }
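
The generator assumes the flattened FHIR file exposes an "entry" list whose items each carry a resource_id and an already-flattened resource string, and it yields one bulk action per entry. A hypothetical payload and the matching helpers.bulk call (es_client and embedding_model are assumed to be initialized as elsewhere in this PR):

    # Hypothetical flattened FHIR payload; the values are illustrative only.
    sample = {
        "entry": [
            {"resource_id": "patient-123",
             "resource": "Patient John Doe, born 1980-01-01, male."},
            {"resource_id": "observation-456",
             "resource": "Blood pressure 120/80 mmHg, recorded 2024-05-01."},
        ]
    }

    # helpers.bulk drains the generator and batches the index requests.
    from elasticsearch import helpers
    helpers.bulk(es_client,
                 bulk_load_from_json_flattened_file(sample, embedding_model,
                                                    "fasten-index"))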
104 changes: 69 additions & 35 deletions app/main.py
@@ -1,13 +1,14 @@
+import json
 import os
 
-from fastapi import FastAPI, UploadFile, File
-from fastapi.responses import JSONResponse
+from elasticsearch import helpers
+from fastapi import FastAPI, HTTPException, UploadFile, File
 
-from .db.index_documents import index_pdf
-from .config.elasticsearch_config import get_es_client
+from .db.index_documents import bulk_load_from_json_flattened_file
+from .config.elasticsearch_config import create_index_if_not_exists
 from .services.search_documents import search_query
 from .services.process_search_output import process_search_output, \
-    stream_llm_response
+    llm_response
 from .models.sentence_transformer import get_sentence_transformer
 from .config.settings import settings, logger
@@ -18,42 +19,75 @@
 
 # Initialize SentenceTransformer model and elastic client
 embedding_model = get_sentence_transformer()
-es_client = get_es_client()
-
-
-def create_index_if_not_exists(index_name):
-    if not es_client.indices.exists(index=index_name):
-        es_client.indices.create(index=index_name)
-        logger.info(f"Index '{index_name}' created.")
-
-
-create_index_if_not_exists(settings.index_name)
-
-
-@app.post("/upload")
-async def upload_file(file: UploadFile = File(...)):
-    file_path = os.path.join(settings.upload_dir, file.filename)
-    with open(file_path, "wb") as buffer:
-        buffer.write(await file.read())
-    index_pdf(file_path, embedding_model, es_client)
-    logger.info(f"File uploaded: {file.filename}")
-    return {"filename": file.filename}
+# Create elasticsearch index
+es_client = create_index_if_not_exists(settings.index_name)
+
+
+@app.post("/bulk_load")
+async def bulk_load(file: UploadFile = File(...)):
+    data = await file.read()
+    # json to dict
+    json_data = json.loads(data)
+    # Bulk load to Elasticsearch
+    try:
+        helpers.bulk(es_client,
+                     bulk_load_from_json_flattened_file(json_data,
+                                                        embedding_model,
+                                                        settings.index_name))
+        logger.info(f"Bulk load completed for file: {file.filename}")
+        return {"status": "success", "filename": file.filename}
+    except Exception as e:
+        logger.error(f"Bulk load failed: {str(e)}")
+        return {"status": "error", "message": str(e)}
+
+
+@app.delete("/delete_all_documents")
+async def delete_all_documents(index_name: str):
+    try:
+        es_client.delete_by_query(index=index_name, body={
+            "query": {
+                "match_all": {}
+            }
+        })
+        logger.info(f"All documents deleted from index '{index_name}'")
+        return {"status": "success",
+                "message": f"All documents deleted from index '{index_name}'"}
+    except Exception as e:
+        logger.error(f"Failed to delete documents: {str(e)}")
+        raise HTTPException(status_code=500,
+                            detail=f"Failed to delete documents: {str(e)}")
 
 
 @app.get("/search")
-async def search_documents(query: str, k: int = 5, threshold: float = 0):
-    results = search_query(query, embedding_model, es_client, k=k,
-                           threshold=threshold)
-    return JSONResponse(content=results)
+async def search_documents(query: str,
+                           k: int = 5,
+                           text_boost: float = 0.25,
+                           embedding_boost: float = 4.0):
+    results = search_query(query,
+                           embedding_model,
+                           es_client, k=k,
+                           text_boost=text_boost,
+                           embedding_boost=embedding_boost)
+    return results
 
 
 @app.get("/generate")
-async def answer_query(query: str, k: int = 5, threshold: float = 0):
-    results = search_query(query, embedding_model, es_client, k=k,
-                           threshold=threshold)
+async def answer_query(query: str,
+                       k: int = 5,
+                       params=None,
+                       stream: bool = False,
+                       text_boost: float = 0.25,
+                       embedding_boost: float = 4.0):
+    results = search_query(query,
+                           embedding_model,
+                           es_client,
+                           k=k,
+                           text_boost=text_boost,
+                           embedding_boost=embedding_boost)
     if not results:
-        concatenated_content = f"No results found for query: {query}"
+        concatenated_content = "There is no context"
     else:
-        concatenated_content = process_search_output(results)
-    return stream_llm_response(concatenated_content, query)
+        concatenated_content, resources_id = process_search_output(results)
+
+    return llm_response(concatenated_content, query, resources_id, stream, params)
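
Together the new endpoints form the ingest-and-ask loop: POST a flattened FHIR JSON file to /bulk_load, then query /search or /generate, where text_boost and embedding_boost weight the keyword and vector sides of the hybrid retrieval. A client sketch using requests; the base URL, port, and file name are assumptions, not part of this PR:

    # Sketch: exercise the new endpoints; base URL and file path are made up.
    import requests

    BASE = "http://localhost:8000"

    with open("flattened_fhir.json", "rb") as f:
        resp = requests.post(f"{BASE}/bulk_load", files={"file": f})
    print(resp.json())  # expected on success: {"status": "success", "filename": ...}

    resp = requests.get(f"{BASE}/generate",
                        params={"query": "What medications am I taking?",
                                "k": 5,
                                "text_boost": 0.25,
                                "embedding_boost": 4.0})
    print(resp.text)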