Openai summaries #7
Merged (4 commits) on Sep 10, 2024
27 changes: 27 additions & 0 deletions app/__init__.py
@@ -0,0 +1,27 @@
from fastapi import FastAPI
from app.config.elasticsearch_config import create_index_if_not_exists
from app.models.sentence_transformer import get_sentence_transformer
from app.config.settings import settings


embedding_model = get_sentence_transformer()
es_client = create_index_if_not_exists(settings.elasticsearch.index_name)


def create_app():
app = FastAPI()

from app.routes.database_endpoints import router as database_router
from app.routes.llm_endpoints import router as llm_router
from app.routes.openai_endpoints import router as openai_router
from app.routes.evaluation_endpoints import router as evaluation_router

app.include_router(database_router, prefix="/database")
app.include_router(llm_router, prefix="/generation")
app.include_router(openai_router, prefix="/openai")
app.include_router(evaluation_router, prefix="/evaluation")

return app


app = create_app()
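Reviewer note: for anyone who wants to try this branch locally, here is a minimal sketch of serving the app, assuming uvicorn is installed. The entry-point file name is hypothetical and not part of this PR.

# run.py (hypothetical): serve the FastAPI app created in app/__init__.py
import uvicorn

if __name__ == "__main__":
    # "app:app" refers to the module-level `app = create_app()` above.
    uvicorn.run("app:app", host="0.0.0.0", port=8000)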
21 changes: 6 additions & 15 deletions app/config/elasticsearch_config.py
@@ -1,13 +1,12 @@
from elasticsearch import Elasticsearch

from config.settings import settings, logger
from app.config.settings import settings, logger


def get_es_client():
return Elasticsearch(
hosts=[settings.elasticsearch.host],
basic_auth=(settings.elasticsearch.user,
settings.elasticsearch.password),
basic_auth=(settings.elasticsearch.user, settings.elasticsearch.password),
max_retries=10,
)

@@ -16,16 +15,9 @@ def get_mapping():
return {
"mappings": {
"properties": {
"content": {
"type": "text"
},
"embedding": {
"type": "dense_vector",
"dims": 384
},
"metadata": {
"type": "object"
}
"content": {"type": "text"},
"embedding": {"type": "dense_vector", "dims": 384},
"metadata": {"type": "object"},
}
}
}
@@ -34,7 +26,6 @@ def get_mapping():
def create_index_if_not_exists(index_name):
es_client = get_es_client()
if not es_client.indices.exists(index=index_name):
es_client.indices.create(index=index_name,
body=get_mapping())
es_client.indices.create(index=index_name, body=get_mapping())
logger.info(f"Index '{index_name}' created.")
return es_client
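Note that the dense_vector field in this mapping must match the 384-dimensional output of all-MiniLM-L6-v2. As a rough illustration of how a query could score against it (the actual search logic lives in app/services/search_documents.py; this sketch and its variable names are assumptions):

# Illustrative cosine-similarity query against the `embedding` field;
# assumes the module-level es_client and embedding_model from app/__init__.py.
query_vector = embedding_model.encode("patient blood pressure").tolist()  # 384 dims
body = {
    "size": 5,
    "query": {
        "script_score": {
            "query": {"match_all": {}},
            "script": {
                "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                "params": {"query_vector": query_vector},
            },
        }
    },
}
results = es_client.search(index=settings.elasticsearch.index_name, body=body)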
app/config/prompts/conversation_model_prompt_Phi-3.5-instruct.txt
@@ -8,10 +8,10 @@ Provide detailed, helpful, and accurate responses, and include references where
If information is not available, politely inform the user that you cannot provide an answer.
<|end|>
<|user|>
Context information is below.\n
---------------------\n
{context}\n
---------------------\n
Context information is below.
---------------------
{context}
---------------------
Given the context information (if there is any),
this is my message: {query}
<|assistant|>
@@ -10,10 +10,10 @@ If information is not available, politely inform the user that you cannot provide an answer.

<|eot_id|><|start_header_id|>user<|end_header_id|>

Context information is below.\n
---------------------\n
{context}\n
---------------------\n
Context information is below.
---------------------
{context}
---------------------
Given the context information (if there is any),
this is my message: {query}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
app/config/prompts/summaries_model_prompt_Phi-3.5-instruct.txt
@@ -6,8 +6,8 @@ attributes. Prioritize essential data for efficient storage
and retrieval, and omit any unnecessary details.
<|end|>
<|user|>
This is the resource:\n
---------------------\n
{query}\n
---------------------\n
This is the resource:
---------------------
Contributor commented:

Is there a point in having these lines when you later replace all line breaks (\n)?

Contributor (author) replied:

I will remove them, I thought I had removed all of them.

{query}
---------------------
<|assistant|>
@@ -8,8 +8,8 @@ and retrieval, and omit any unnecessary details.

<|eot_id|><|start_header_id|>user<|end_header_id|>

This is the resource:\n
---------------------\n
{query}\n
---------------------\n
This is the resource:
---------------------
{query}
---------------------
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
5 changes: 5 additions & 0 deletions app/config/prompts/summaries_openai_system_prompt.txt
@@ -0,0 +1,5 @@
You will receive a single FHIR resource. Summarize the key information
from the resource in a clear, concise paragraph of plain text,
ideally up to 800 characters. The output should be human-readable and
understandable, not in JSON or other structured formats. Focus on the most
relevant attributes and omit unnecessary details.
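For context, this system prompt drives the OpenAI summaries flow. A hedged sketch of how it might be used with the OpenAI chat API (the model name, client setup, and sample resource are assumptions, not taken from this PR):

import json

from openai import OpenAI

from app.config.settings import settings

client = OpenAI()  # reads OPENAI_API_KEY from the environment
fhir_resource = {"resourceType": "Patient", "id": "example"}  # placeholder resource

response = client.chat.completions.create(
    model="gpt-4o-mini",  # assumed model, not specified in this PR
    messages=[
        {"role": "system", "content": settings.model.summaries_openai_system_prompt},
        {"role": "user", "content": json.dumps(fhir_resource)},
    ],
)
summary = response.choices[0].message.content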
19 changes: 13 additions & 6 deletions app/config/settings.py
@@ -15,29 +15,36 @@ def __init__(self):
# Base dir
base_dir = os.path.dirname(os.path.abspath(__file__))
# Embedding model
self.embedding_model_name = os.getenv(
"EMBEDDING_MODEL_NAME", "all-MiniLM-L6-v2")
self.embedding_model_name = os.getenv("EMBEDDING_MODEL_NAME", "all-MiniLM-L6-v2")
# LLM host
self.llm_host = os.getenv("LLAMA_HOST", "http://localhost:8090")
# Conversation prompts
self.conversation_model_prompt = self.load_prompt(
os.path.join(base_dir, "prompts/conversation_model_prompt_Phi-3.5-instruct.txt"))
os.path.join(base_dir, "prompts/conversation_model_prompt_Phi-3.5-instruct.txt")
)
# Summaries prompts
self.summaries_model_prompt = self.load_prompt(
os.path.join(base_dir, "prompts/summaries_model_prompt_Phi-3.5-instruct.txt"))
os.path.join(base_dir, "prompts/summaries_model_prompt_Phi-3.5-instruct.txt")
)
# Summaries openai system prompt
self.summaries_openai_system_prompt = self.load_prompt(
os.path.join(base_dir, "prompts/summaries_openai_system_prompt.txt")
)

def load_prompt(self, file_path: str) -> str:
with open(file_path, 'r') as file:
return file.read().strip()
with open(file_path, "r") as file:
return file.read().strip().replace("\n", " ")


class Settings:
def __init__(self):
self.elasticsearch = ElasticsearchSettings()
self.model = ModelsSettings()


settings = Settings()

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logging.getLogger("elastic_transport").setLevel(logging.WARNING)
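Worth noting for the review thread above: after this change, load_prompt collapses every newline in a prompt file to a space, which is why the literal \n escapes removed from the templates were redundant. A small illustration (file contents shown in the comment are hypothetical):

# Given a prompt file containing:
#   Context information is below.
#   ---------------------
# load_prompt now returns the single line:
#   "Context information is below. ---------------------"
from app.config.settings import settings

print(settings.model.summaries_openai_system_prompt)  # one space-joined line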
23 changes: 5 additions & 18 deletions app/db/index_documents.py
@@ -1,7 +1,4 @@
def bulk_load_fhir_data(data: list[dict],
text_key: str,
embedding_model,
index_name):
def bulk_load_fhir_data(data: list[dict], text_key: str, embedding_model, index_name):
"""
Yield bulk-indexing actions for FHIR data.
"""
@@ -10,11 +7,8 @@ def bulk_load_fhir_data(data: list[dict], text_key: str, embedding_model, index_name):
resource_type = value.get("resource_type")
resource = value.get(text_key)
embedding = embedding_model.encode(resource)

metadata = {
"resource_id": resource_id,
"resource_type": resource_type
}

metadata = {"resource_id": resource_id, "resource_type": resource_type}

if "tokens_evaluated" in value:
metadata["tokens_evaluated"] = value["tokens_evaluated"]
@@ -24,12 +18,5 @@ def bulk_load_fhir_data(data: list[dict], text_key: str, embedding_model, index_name):
metadata["prompt_ms"] = value["prompt_ms"]
if "predicted_ms" in value:
metadata["predicted_ms"] = value["predicted_ms"]

yield {
"_index": index_name,
"_source": {
"content": resource,
"embedding": embedding,
"metadata": metadata
}
}

yield {"_index": index_name, "_source": {"content": resource, "embedding": embedding, "metadata": metadata}}
Empty file added app/evaluation/__init__.py
Empty file added app/evaluation/retrieval/__init__.py
108 changes: 108 additions & 0 deletions app/evaluation/retrieval/retrieval_metrics.py
@@ -0,0 +1,108 @@
import json
import random

from tqdm import tqdm

from app.services.search_documents import search_query


def evaluate_resources_summaries_retrieval(
    es_client,
    embedding_model,
    resource_chunk_counts: dict,
    qa_references: list[dict],
    search_text_boost: int = 1,
    search_embedding_boost: int = 1,
    k: int = 5,
) -> dict:
# Initialize counters and sums for metrics
total_questions = 0
total_contexts_found = 0
position_sum = 0
reciprocal_rank_sum = 0
precision_sum = 0
recall_sum = 0

# Iterate over the OpenAI responses
for response in tqdm(qa_references, total=len(qa_references), desc="Calculating retrieval metrics"):
# Get content and id of openai responses
reference_resource_id = response["custom_id"]
content = response["response"]["body"]["choices"][0]["message"]["content"]

        questions_and_answers = json.loads(content)["questions_and_answers"]

if len(questions_and_answers) > 0:
# Sample one random question per resource_id to evaluate
questions_and_answers = [random.choice(questions_and_answers)]

for qa in questions_and_answers:
if isinstance(qa, dict) and "question" in qa:
question = qa["question"]
total_questions += 1

# Query question
                search_results = search_query(
                    question,
                    embedding_model,
                    es_client,
                    k=k,
                    text_boost=search_text_boost,
                    embedding_boost=search_embedding_boost,
                )

# Evaluate if any returned chunk belongs to the correct resource_id
found = False
rank = 0
retrieved_relevant_chunks = 0

# Get the total number of relevant chunks for this resource_id
relevant_chunks = resource_chunk_counts[reference_resource_id]

                if search_results == {"detail": "Not Found"}:
                    search_results = {}
                else:
                    for i, result in enumerate(search_results):
                        if result["metadata"]["resource_id"] == reference_resource_id:
                            if not found:
                                total_contexts_found += 1
                                rank = i + 1
                                reciprocal_rank_sum += 1 / rank
                                found = True
                            retrieved_relevant_chunks += 1

# Calculate precision and recall for this specific question
                precision = retrieved_relevant_chunks / len(search_results) if len(search_results) > 0 else 0
recall = retrieved_relevant_chunks / relevant_chunks if relevant_chunks > 0 else 0

precision_sum += precision
recall_sum += recall

if found:
position_sum += rank

# Calculate final metrics
    retrieval_accuracy = round(total_contexts_found / total_questions, 3) if total_questions > 0 else 0
    average_position = round(position_sum / total_contexts_found, 3) if total_contexts_found > 0 else 0
    mrr = round(reciprocal_rank_sum / total_questions, 3) if total_questions > 0 else 0
    average_precision = round(precision_sum / total_questions, 3) if total_questions > 0 else 0
    average_recall = round(recall_sum / total_questions, 3) if total_questions > 0 else 0

return {
# The percentage of questions for which the system successfully retrieved at least one relevant chunk.
"Retrieval Accuracy": retrieval_accuracy,
"Average Position": average_position,
"MRR": mrr,
# Precision = Number of relevant chunks returned / Total number of chunks returned
"Average Precision": average_precision,
# Recall = Number of relevant chunks returned / Total number of relevant chunks that exist
"Average Recall": average_recall,
# Others
"Total Questions": total_questions,
"Total contexts found": total_contexts_found,
"Total positions sum": position_sum,
}
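A sketch of how this evaluator might be invoked end to end (the file name, chunk counts, and boost values below are assumptions for illustration):

import json

from app import embedding_model, es_client
from app.evaluation.retrieval.retrieval_metrics import evaluate_resources_summaries_retrieval

# OpenAI batch-style responses, one JSON object per line (path is hypothetical).
with open("qa_references.jsonl") as f:
    qa_references = [json.loads(line) for line in f]

metrics = evaluate_resources_summaries_retrieval(
    es_client=es_client,
    embedding_model=embedding_model,
    resource_chunk_counts={"patient-1": 3},  # chunks indexed per resource_id
    qa_references=qa_references,
    search_text_boost=1,
    search_embedding_boost=2,
    k=5,
)
print(json.dumps(metrics, indent=2))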