Mob/main (#79)
* "fix the batch embedding response processing for CAII"

* update fe to data source ids

* "now we're thinking with panic"

* fix bug with s3 path when the prefix is not provided

* "add in exception handing for more than one kb"

* Fix docstring for qdrant.size()

* fix evals

* fix shadowing

---------

Co-authored-by: jwatson <jkwatson@gmail.com>
Co-authored-by: Michael Liu <mliu@cloudera.com>
3 people authored Dec 13, 2024
1 parent fd22008 commit b952288
Showing 10 changed files with 72 additions and 61 deletions.
4 changes: 1 addition & 3 deletions llm-service/app/ai/vector_stores/qdrant.py
@@ -98,9 +98,7 @@ def get_embedding_model(self) -> BaseEmbedding:
return models.get_embedding_model(self.data_source_metadata.embedding_model)

def size(self) -> Optional[int]:
"""
If the collection does not exist, return -1
"""
"""If the collection does not exist, return None."""
if not self.client.collection_exists(self.table_name):
return None
document_count: CountResult = self.client.count(self.table_name)
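The distinction the corrected docstring captures matters downstream: size() returns None when the collection was never created, and 0 when it exists but holds no chunks. A minimal sketch of a caller honoring that contract (the helper name below is illustrative, not from the repo):

from typing import Optional

def effective_size(size: Optional[int]) -> int:
    # Treat a missing collection (None) the same as an empty one when totalling sizes.
    return size or 0

# e.g. one data source whose collection was never created, one with 12 chunks
assert effective_size(None) == 0
assert effective_size(12) == 12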
29 changes: 17 additions & 12 deletions llm-service/app/routers/index/sessions/__init__.py
@@ -77,27 +77,27 @@ def delete_chat_history(session_id: int) -> str:


class RagStudioChatRequest(BaseModel):
data_source_id: int
data_source_ids: list[int]
query: str
configuration: RagPredictConfiguration


@router.post("/chat", summary="Chat with your documents in the requested datasource")
@exceptions.propagates
def chat(
session_id: int,
request: RagStudioChatRequest,
session_id: int,
request: RagStudioChatRequest,
) -> RagStudioChatMessage:
if request.configuration.exclude_knowledge_base:
return llm_talk(session_id, request)
return v2_chat(
session_id, request.data_source_id, request.query, request.configuration
session_id, request.data_source_ids, request.query, request.configuration
)


def llm_talk(
session_id: int,
request: RagStudioChatRequest,
session_id: int,
request: RagStudioChatRequest,
) -> RagStudioChatMessage:
chat_response = llm_completion.completion(
session_id, request.query, request.configuration
@@ -117,7 +117,7 @@ def llm_talk(


class SuggestQuestionsRequest(BaseModel):
data_source_id: int
data_source_ids: list[int]
configuration: RagPredictConfiguration = RagPredictConfiguration()


@@ -128,13 +128,18 @@ class RagSuggestedQuestionsResponse(BaseModel):
@router.post("/suggest-questions", summary="Suggest questions with context")
@exceptions.propagates
def suggest_questions(
session_id: int,
request: SuggestQuestionsRequest,
session_id: int,
request: SuggestQuestionsRequest,
) -> RagSuggestedQuestionsResponse:
data_source_size = QdrantVectorStore.for_chunks(request.data_source_id).size()
if data_source_size is None:

if len(request.data_source_ids) != 1:
raise HTTPException(status_code=400, detail="Only one datasource is supported for question suggestion.")

total_data_sources_size: int = sum(
map(lambda ds_id: QdrantVectorStore.for_chunks(ds_id).size() or 0, request.data_source_ids))
if total_data_sources_size == 0:
raise HTTPException(status_code=404, detail="Knowledge base not found.")
suggested_questions = generate_suggested_questions(
request.configuration, request.data_source_id, data_source_size, session_id
request.configuration, request.data_source_ids, total_data_sources_size, session_id
)
return RagSuggestedQuestionsResponse(suggested_questions=suggested_questions)
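For clients of the /chat and /suggest-questions endpoints, the practical change is that the request body now carries a list of data source ids, even though the backend currently rejects anything other than exactly one. A rough sketch of the new request shape; the RagPredictConfiguration stand-in below is simplified, not the real model:

from pydantic import BaseModel

class RagPredictConfiguration(BaseModel):
    # Simplified stand-in for the real configuration model.
    exclude_knowledge_base: bool = False

class RagStudioChatRequest(BaseModel):
    data_source_ids: list[int]  # previously: data_source_id: int
    query: str
    configuration: RagPredictConfiguration

request = RagStudioChatRequest(
    data_source_ids=[42],  # exactly one id is accepted by chat and question suggestion
    query="What changed in the embedding pipeline?",
    configuration=RagPredictConfiguration(),
)
print(request.model_dump())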
48 changes: 23 additions & 25 deletions llm-service/app/services/caii/CaiiEmbeddingModel.py
@@ -38,7 +38,7 @@
import http.client as http_client
import json
import os
from typing import Any
from typing import Any, List

from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding
from pydantic import Field
@@ -95,27 +95,25 @@ def make_embedding_request(self, body: str) -> Any:
structured_response = json.loads(json_response)
return structured_response

## TODO: get this working. At the moment, the shape of the data in the response isn't what the code is expecting

# def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
# if len(texts) == 1:
# return [self._get_text_embedding(texts[0])]
#
# print(f"Getting embeddings for {len(texts)} texts")
# model = self.endpoint.endpointmetadata.model_name
# body = json.dumps(
# {
# "input": texts,
# "input_type": "passage",
# "truncate": "END",
# "model": model,
# }
# )
# structured_response = self.make_embedding_request(body)
# embeddings = structured_response["data"][0]["embedding"]
# print(f"Got embeddings for {len(embeddings)} texts")
# assert isinstance(embeddings, list)
# assert all(isinstance(x, list) for x in embeddings)
# assert all(all(isinstance(y, float) for y in x) for x in embeddings)
#
# return embeddings

def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
if len(texts) == 1:
return [self._get_text_embedding(texts[0])]

model = self.endpoint.endpointmetadata.model_name
body = json.dumps(
{
"input": texts,
"input_type": "passage",
"truncate": "END",
"model": model,
}
)
structured_response = self.make_embedding_request(body)

embeddings = list(map(lambda data: data["embedding"], structured_response["data"]))
assert isinstance(embeddings, list)
assert all(isinstance(x, list) for x in embeddings)
assert all(all(isinstance(y, float) for y in x) for x in embeddings)

return embeddings
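The earlier commented-out version read only structured_response["data"][0]["embedding"], a single vector, which is why batch requests broke. The new code assumes an OpenAI-style batch response with one entry per input text. Roughly (the response values here are made up):

# Illustrative response in the shape the new _get_text_embeddings expects.
structured_response = {
    "data": [
        {"embedding": [0.12, -0.03, 0.88]},   # vector for texts[0]
        {"embedding": [0.05, 0.41, -0.27]},   # vector for texts[1]
    ]
}

# One embedding per input text, not just the first entry.
embeddings = [item["embedding"] for item in structured_response["data"]]
assert len(embeddings) == 2
assert all(isinstance(v, float) for vec in embeddings for v in vec)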
11 changes: 9 additions & 2 deletions llm-service/app/services/chat.py
@@ -41,6 +41,7 @@
from collections.abc import Iterator
from typing import List

from fastapi import HTTPException
from llama_index.core.base.llms.types import MessageRole
from llama_index.core.chat_engine.types import AgentChatResponse

@@ -58,11 +59,16 @@

def v2_chat(
session_id: int,
data_source_id: int,
data_source_ids: list[int],
query: str,
configuration: RagPredictConfiguration,
) -> RagStudioChatMessage:
response_id = str(uuid.uuid4())

if len(data_source_ids) != 1:
raise HTTPException(status_code=400, detail="Only one datasource is supported for chat.")

data_source_id: int = data_source_ids[0]
if QdrantVectorStore.for_chunks(data_source_id).size() == 0:
return RagStudioChatMessage(
id=response_id,
@@ -137,10 +143,11 @@ def format_source_nodes(response: AgentChatResponse) -> List[RagPredictSourceNod

def generate_suggested_questions(
configuration: RagPredictConfiguration,
data_source_id: int,
data_source_ids: list[int],
data_source_size: int,
session_id: int,
) -> List[str]:
data_source_id = data_source_ids[0]
chat_history = retrieve_chat_history(session_id)
if data_source_size == 0:
suggested_questions = []
4 changes: 2 additions & 2 deletions llm-service/app/services/evaluators.py
@@ -51,10 +51,10 @@ def evaluate_response(

relevancy_evaluator = RelevancyEvaluator(llm=evaluator_llm)
relevance = relevancy_evaluator.evaluate_response(
query=query, response=Response(response=chat_response.response)
query=query, response=Response(response=chat_response.response, source_nodes=chat_response.source_nodes, metadata=chat_response.metadata)
)
faithfulness_evaluator = FaithfulnessEvaluator(llm=evaluator_llm)
faithfulness = faithfulness_evaluator.evaluate_response(
query=query, response=Response(response=chat_response.response)
query=query, response=Response(response=chat_response.response, source_nodes=chat_response.source_nodes, metadata=chat_response.metadata)
)
return relevance.score or 0, faithfulness.score or 0
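The "fix evals" change matters because the llama-index relevancy and faithfulness evaluators judge the answer against the retrieved context attached to the Response they are handed; a Response built from the answer text alone gives them nothing to ground on. A sketch of the difference, with import paths assumed for a recent llama-index release and invented node text:

from llama_index.core.base.response.schema import Response
from llama_index.core.schema import NodeWithScore, TextNode

retrieved = [
    NodeWithScore(node=TextNode(text="Batch embeddings return one vector per input."), score=0.9)
]

# Before the fix: evaluators receive no source nodes, so there is no context to check against.
bare = Response(response="Each input gets its own vector.")

# After the fix: the retriever's nodes and metadata travel with the response being evaluated.
grounded = Response(
    response="Each input gets its own vector.",
    source_nodes=retrieved,
    metadata={},
)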
2 changes: 1 addition & 1 deletion llm-service/app/tests/services/test_chat.py
@@ -85,7 +85,7 @@ class TestProcessResponse:
@example(response="Empty Response")
def test_process_response(self, response: str) -> None:
"""Verify process_response() cleans and filters an LLM's suggested questions."""
processed_response: str = process_response(response)
processed_response: list[str] = process_response(response)
assert len(processed_response) <= 5

for suggested_question in processed_response:
4 changes: 2 additions & 2 deletions ui/src/api/chatApi.ts
@@ -73,7 +73,7 @@ export interface QueryConfiguration {

export interface ChatMutationRequest {
query: string;
data_source_id: string;
data_source_ids: number[];
session_id: string;
configuration: QueryConfiguration;
}
@@ -176,7 +176,7 @@ export const useChatMutation = ({
(cachedData) => replacePlaceholderInChatHistory(data, cachedData),
);
await queryClient.invalidateQueries({
queryKey: suggestedQuestionKey(variables.data_source_id),
queryKey: suggestedQuestionKey(variables.data_source_ids),
});
onSuccess?.(data);
},
10 changes: 5 additions & 5 deletions ui/src/api/ragQueryApi.ts
@@ -47,7 +47,7 @@ import {
} from "src/api/utils.ts";

export interface SuggestQuestionsRequest {
data_source_id: string;
data_source_ids: number[];
configuration: QueryConfiguration;
session_id: string;
}
@@ -57,12 +57,12 @@ export interface SuggestQuestionsResponse {
}

export const suggestedQuestionKey = (
data_source_id: SuggestQuestionsRequest["data_source_id"],
data_source_ids: SuggestQuestionsRequest["data_source_ids"],
) => {
return [
QueryKeys.suggestQuestionsQuery,
{
data_source_id,
data_source_ids,
},
];
};
@@ -71,9 +71,9 @@ export const useSuggestQuestions = (request: SuggestQuestionsRequest) => {
return useQuery({
// Note: We only want to invalidate the query when the data_source_id changes, not when chat history changes
// eslint-disable-next-line @tanstack/query/exhaustive-deps
queryKey: suggestedQuestionKey(request.data_source_id),
queryKey: suggestedQuestionKey(request.data_source_ids),
queryFn: () => suggestQuestionsQuery(request),
enabled: Boolean(request.data_source_id),
enabled: Boolean(request.data_source_ids.length),
gcTime: 0,
});
};
@@ -49,14 +49,13 @@ const SuggestedQuestionsCards = () => {
activeSession,
excludeKnowledgeBaseState: [excludeKnowledgeBase],
} = useContext(RagChatContext);
const dataSourceId = activeSession?.dataSourceIds[0];
const sessionId = activeSession?.id.toString();
const {
data,
isPending: suggestedQuestionsIsPending,
isFetching: suggestedQuestionsIsFetching,
} = useSuggestQuestions({
data_source_id: dataSourceId?.toString() ?? "",
data_source_ids: activeSession?.dataSourceIds ?? [],
configuration: createQueryConfiguration(
excludeKnowledgeBase,
activeSession,
@@ -75,15 +74,15 @@

const handleAskSample = (suggestedQuestion: string) => {
if (
dataSourceId &&
dataSourceId > 0 &&
activeSession &&
activeSession.dataSourceIds.length > 0 &&
suggestedQuestion.length > 0 &&
sessionId
) {
setCurrentQuestion(suggestedQuestion);
chatMutation({
query: suggestedQuestion,
data_source_id: dataSourceId.toString(),
data_source_ids: activeSession.dataSourceIds,
session_id: sessionId,
configuration: createQueryConfiguration(
excludeKnowledgeBase,
12 changes: 8 additions & 4 deletions ui/src/pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx
@@ -58,7 +58,6 @@ const RagChatQueryInput = () => {
dataSourcesQuery: { dataSourcesStatus },
activeSession,
} = useContext(RagChatContext);
const dataSourceId = activeSession?.dataSourceIds[0];
const [userInput, setUserInput] = useState("");
const { sessionId } = useParams({ strict: false });

Expand All @@ -71,7 +70,7 @@ const RagChatQueryInput = () => {
isPending: sampleQuestionsIsPending,
isFetching: sampleQuestionsIsFetching,
} = useSuggestQuestions({
data_source_id: dataSourceId?.toString() ?? "",
data_source_ids: activeSession?.dataSourceIds ?? [],
configuration,
session_id: sessionId ?? "",
});
@@ -87,11 +86,16 @@
});

const handleChat = (userInput: string) => {
if (dataSourceId && dataSourceId > 0 && userInput.length > 0 && sessionId) {
if (
activeSession &&
activeSession.dataSourceIds.length > 0 &&
userInput.length > 0 &&
sessionId
) {
setCurrentQuestion(userInput);
chatMutation.mutate({
query: userInput,
data_source_id: dataSourceId.toString(),
data_source_ids: activeSession.dataSourceIds,
session_id: sessionId,
configuration: createQueryConfiguration(
excludeKnowledgeBase,