diff --git a/llm-service/app/ai/vector_stores/qdrant.py b/llm-service/app/ai/vector_stores/qdrant.py
index b4007105..150f9747 100644
--- a/llm-service/app/ai/vector_stores/qdrant.py
+++ b/llm-service/app/ai/vector_stores/qdrant.py
@@ -98,9 +98,7 @@ def get_embedding_model(self) -> BaseEmbedding:
         return models.get_embedding_model(self.data_source_metadata.embedding_model)
 
     def size(self) -> Optional[int]:
-        """
-        If the collection does not exist, return -1
-        """
+        """If the collection does not exist, return None."""
         if not self.client.collection_exists(self.table_name):
             return None
         document_count: CountResult = self.client.count(self.table_name)
diff --git a/llm-service/app/routers/index/sessions/__init__.py b/llm-service/app/routers/index/sessions/__init__.py
index 183ca1f2..cc43e723 100644
--- a/llm-service/app/routers/index/sessions/__init__.py
+++ b/llm-service/app/routers/index/sessions/__init__.py
@@ -77,7 +77,7 @@ def delete_chat_history(session_id: int) -> str:
 
 
 class RagStudioChatRequest(BaseModel):
-    data_source_id: int
+    data_source_ids: list[int]
     query: str
     configuration: RagPredictConfiguration
 
@@ -85,19 +85,19 @@ class RagStudioChatRequest(BaseModel):
 @router.post("/chat", summary="Chat with your documents in the requested datasource")
 @exceptions.propagates
 def chat(
-        session_id: int,
-        request: RagStudioChatRequest,
+    session_id: int,
+    request: RagStudioChatRequest,
 ) -> RagStudioChatMessage:
     if request.configuration.exclude_knowledge_base:
         return llm_talk(session_id, request)
     return v2_chat(
-        session_id, request.data_source_id, request.query, request.configuration
+        session_id, request.data_source_ids, request.query, request.configuration
     )
 
 
 def llm_talk(
-        session_id: int,
-        request: RagStudioChatRequest,
+    session_id: int,
+    request: RagStudioChatRequest,
 ) -> RagStudioChatMessage:
     chat_response = llm_completion.completion(
         session_id, request.query, request.configuration
@@ -117,7 +117,7 @@ def llm_talk(
 
 
 class SuggestQuestionsRequest(BaseModel):
-    data_source_id: int
+    data_source_ids: list[int]
     configuration: RagPredictConfiguration = RagPredictConfiguration()
 
 
@@ -128,13 +128,18 @@ class RagSuggestedQuestionsResponse(BaseModel):
 @router.post("/suggest-questions", summary="Suggest questions with context")
 @exceptions.propagates
 def suggest_questions(
-        session_id: int,
-        request: SuggestQuestionsRequest,
+    session_id: int,
+    request: SuggestQuestionsRequest,
 ) -> RagSuggestedQuestionsResponse:
-    data_source_size = QdrantVectorStore.for_chunks(request.data_source_id).size()
-    if data_source_size is None:
+
+    if len(request.data_source_ids) != 1:
+        raise HTTPException(status_code=400, detail="Only one datasource is supported for question suggestion.")
+
+    total_data_sources_size: int = sum(
+        map(lambda ds_id: QdrantVectorStore.for_chunks(ds_id).size() or 0, request.data_source_ids))
+    if total_data_sources_size == 0:
         raise HTTPException(status_code=404, detail="Knowledge base not found.")
     suggested_questions = generate_suggested_questions(
-        request.configuration, request.data_source_id, data_source_size, session_id
+        request.configuration, request.data_source_ids, total_data_sources_size, session_id
     )
     return RagSuggestedQuestionsResponse(suggested_questions=suggested_questions)
diff --git a/llm-service/app/services/caii/CaiiEmbeddingModel.py b/llm-service/app/services/caii/CaiiEmbeddingModel.py
index 7ce04b8a..bd8d1529 100644
--- a/llm-service/app/services/caii/CaiiEmbeddingModel.py
+++ b/llm-service/app/services/caii/CaiiEmbeddingModel.py
@@ -38,7 +38,7 @@
 import http.client as http_client
 import json
 import os
-from typing import Any
+from typing import Any, List
 
 from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding
 from pydantic import Field
@@ -95,27 +95,25 @@ def make_embedding_request(self, body: str) -> Any:
         structured_response = json.loads(json_response)
         return structured_response
 
-## TODO: get this working. At the moment, the shape of the data in the response isn't what the code is expecting
-
-    # def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
-    #     if len(texts) == 1:
-    #         return [self._get_text_embedding(texts[0])]
-    #
-    #     print(f"Getting embeddings for {len(texts)} texts")
-    #     model = self.endpoint.endpointmetadata.model_name
-    #     body = json.dumps(
-    #         {
-    #             "input": texts,
-    #             "input_type": "passage",
-    #             "truncate": "END",
-    #             "model": model,
-    #         }
-    #     )
-    #     structured_response = self.make_embedding_request(body)
-    #     embeddings = structured_response["data"][0]["embedding"]
-    #     print(f"Got embeddings for {len(embeddings)} texts")
-    #     assert isinstance(embeddings, list)
-    #     assert all(isinstance(x, list) for x in embeddings)
-    #     assert all(all(isinstance(y, float) for y in x) for x in embeddings)
-    #
-    #     return embeddings
+
+    def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
+        if len(texts) == 1:
+            return [self._get_text_embedding(texts[0])]
+
+        model = self.endpoint.endpointmetadata.model_name
+        body = json.dumps(
+            {
+                "input": texts,
+                "input_type": "passage",
+                "truncate": "END",
+                "model": model,
+            }
+        )
+        structured_response = self.make_embedding_request(body)
+
+        embeddings = list(map(lambda data: data["embedding"], structured_response["data"]))
+        assert isinstance(embeddings, list)
+        assert all(isinstance(x, list) for x in embeddings)
+        assert all(all(isinstance(y, float) for y in x) for x in embeddings)
+
+        return embeddings
diff --git a/llm-service/app/services/chat.py b/llm-service/app/services/chat.py
index 48257313..85ee6fcd 100644
--- a/llm-service/app/services/chat.py
+++ b/llm-service/app/services/chat.py
@@ -41,6 +41,7 @@
 from collections.abc import Iterator
 from typing import List
 
+from fastapi import HTTPException
 from llama_index.core.base.llms.types import MessageRole
 from llama_index.core.chat_engine.types import AgentChatResponse
 
@@ -58,11 +58,16 @@
 
 def v2_chat(
     session_id: int,
-    data_source_id: int,
+    data_source_ids: list[int],
     query: str,
     configuration: RagPredictConfiguration,
 ) -> RagStudioChatMessage:
     response_id = str(uuid.uuid4())
+
+    if len(data_source_ids) != 1:
+        raise HTTPException(status_code=400, detail="Only one datasource is supported for chat.")
+
+    data_source_id: int = data_source_ids[0]
     if QdrantVectorStore.for_chunks(data_source_id).size() == 0:
         return RagStudioChatMessage(
             id=response_id,
@@ -137,10 +143,11 @@ def format_source_nodes(response: AgentChatResponse) -> List[RagPredictSourceNod
 
 def generate_suggested_questions(
     configuration: RagPredictConfiguration,
-    data_source_id: int,
+    data_source_ids: list[int],
     data_source_size: int,
     session_id: int,
 ) -> List[str]:
+    data_source_id = data_source_ids[0]
     chat_history = retrieve_chat_history(session_id)
     if data_source_size == 0:
         suggested_questions = []
diff --git a/llm-service/app/services/evaluators.py b/llm-service/app/services/evaluators.py
index f91e820b..2d3888b5 100644
--- a/llm-service/app/services/evaluators.py
+++ b/llm-service/app/services/evaluators.py
@@ -51,10 +51,10 @@ def evaluate_response(
 
     relevancy_evaluator = RelevancyEvaluator(llm=evaluator_llm)
     relevance = relevancy_evaluator.evaluate_response(
-        query=query, response=Response(response=chat_response.response)
+        query=query, response=Response(response=chat_response.response, source_nodes=chat_response.source_nodes, metadata=chat_response.metadata)
     )
     faithfulness_evaluator = FaithfulnessEvaluator(llm=evaluator_llm)
     faithfulness = faithfulness_evaluator.evaluate_response(
-        query=query, response=Response(response=chat_response.response)
+        query=query, response=Response(response=chat_response.response, source_nodes=chat_response.source_nodes, metadata=chat_response.metadata)
     )
     return relevance.score or 0, faithfulness.score or 0
diff --git a/llm-service/app/tests/services/test_chat.py b/llm-service/app/tests/services/test_chat.py
index b50a8998..4a24f221 100644
--- a/llm-service/app/tests/services/test_chat.py
+++ b/llm-service/app/tests/services/test_chat.py
@@ -85,7 +85,7 @@ class TestProcessResponse:
     @example(response="Empty Response")
     def test_process_response(self, response: str) -> None:
         """Verify process_response() cleans and filters an LLM's suggested questions."""
-        processed_response: str = process_response(response)
+        processed_response: list[str] = process_response(response)
 
         assert len(processed_response) <= 5
         for suggested_question in processed_response:
diff --git a/ui/src/api/chatApi.ts b/ui/src/api/chatApi.ts
index 74e17f70..5567e509 100644
--- a/ui/src/api/chatApi.ts
+++ b/ui/src/api/chatApi.ts
@@ -73,7 +73,7 @@ export interface QueryConfiguration {
 
 export interface ChatMutationRequest {
   query: string;
-  data_source_id: string;
+  data_source_ids: number[];
   session_id: string;
   configuration: QueryConfiguration;
 }
@@ -176,7 +176,7 @@ export const useChatMutation = ({
         (cachedData) => replacePlaceholderInChatHistory(data, cachedData),
       );
       await queryClient.invalidateQueries({
-        queryKey: suggestedQuestionKey(variables.data_source_id),
+        queryKey: suggestedQuestionKey(variables.data_source_ids),
       });
       onSuccess?.(data);
     },
diff --git a/ui/src/api/ragQueryApi.ts b/ui/src/api/ragQueryApi.ts
index e546a638..c03c9b27 100644
--- a/ui/src/api/ragQueryApi.ts
+++ b/ui/src/api/ragQueryApi.ts
@@ -47,7 +47,7 @@ import {
 } from "src/api/utils.ts";
 
 export interface SuggestQuestionsRequest {
-  data_source_id: string;
+  data_source_ids: number[];
   configuration: QueryConfiguration;
   session_id: string;
 }
@@ -57,12 +57,12 @@ export interface SuggestQuestionsResponse {
 }
 
 export const suggestedQuestionKey = (
-  data_source_id: SuggestQuestionsRequest["data_source_id"],
+  data_source_ids: SuggestQuestionsRequest["data_source_ids"],
 ) => {
   return [
     QueryKeys.suggestQuestionsQuery,
     {
-      data_source_id,
+      data_source_ids,
     },
   ];
 };
@@ -71,9 +71,9 @@ export const useSuggestQuestions = (request: SuggestQuestionsRequest) => {
   return useQuery({
     // Note: We only want to invalidate the query when the data_source_id changes, not when chat history changes
     // eslint-disable-next-line @tanstack/query/exhaustive-deps
-    queryKey: suggestedQuestionKey(request.data_source_id),
+    queryKey: suggestedQuestionKey(request.data_source_ids),
     queryFn: () => suggestQuestionsQuery(request),
-    enabled: Boolean(request.data_source_id),
+    enabled: Boolean(request.data_source_ids.length),
     gcTime: 0,
   });
 };
diff --git a/ui/src/pages/RagChatTab/ChatOutput/Placeholders/SuggestedQuestionsCards.tsx b/ui/src/pages/RagChatTab/ChatOutput/Placeholders/SuggestedQuestionsCards.tsx
index c7eb302b..546798f0 100644
--- a/ui/src/pages/RagChatTab/ChatOutput/Placeholders/SuggestedQuestionsCards.tsx
+++ b/ui/src/pages/RagChatTab/ChatOutput/Placeholders/SuggestedQuestionsCards.tsx
@@ -49,14 +49,13 @@ const SuggestedQuestionsCards = () => {
     activeSession,
     excludeKnowledgeBaseState: [excludeKnowledgeBase],
   } = useContext(RagChatContext);
-  const dataSourceId = activeSession?.dataSourceIds[0];
   const sessionId = activeSession?.id.toString();
   const {
     data,
     isPending: suggestedQuestionsIsPending,
     isFetching: suggestedQuestionsIsFetching,
   } = useSuggestQuestions({
-    data_source_id: dataSourceId?.toString() ?? "",
+    data_source_ids: activeSession?.dataSourceIds ?? [],
     configuration: createQueryConfiguration(
       excludeKnowledgeBase,
       activeSession,
@@ -75,15 +74,15 @@ const SuggestedQuestionsCards = () => {
 
   const handleAskSample = (suggestedQuestion: string) => {
     if (
-      dataSourceId &&
-      dataSourceId > 0 &&
+      activeSession &&
+      activeSession.dataSourceIds.length > 0 &&
       suggestedQuestion.length > 0 &&
       sessionId
     ) {
       setCurrentQuestion(suggestedQuestion);
       chatMutation({
         query: suggestedQuestion,
-        data_source_id: dataSourceId.toString(),
+        data_source_ids: activeSession.dataSourceIds,
         session_id: sessionId,
         configuration: createQueryConfiguration(
           excludeKnowledgeBase,
diff --git a/ui/src/pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx b/ui/src/pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx
index 682a7d49..6187dd42 100644
--- a/ui/src/pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx
+++ b/ui/src/pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx
@@ -58,7 +58,6 @@ const RagChatQueryInput = () => {
     dataSourcesQuery: { dataSourcesStatus },
     activeSession,
   } = useContext(RagChatContext);
-  const dataSourceId = activeSession?.dataSourceIds[0];
   const [userInput, setUserInput] = useState("");
   const { sessionId } = useParams({ strict: false });
 
@@ -71,7 +70,7 @@ const RagChatQueryInput = () => {
     isPending: sampleQuestionsIsPending,
     isFetching: sampleQuestionsIsFetching,
   } = useSuggestQuestions({
-    data_source_id: dataSourceId?.toString() ?? "",
+    data_source_ids: activeSession?.dataSourceIds ?? [],
     configuration,
     session_id: sessionId ?? "",
   });
@@ -87,11 +86,16 @@ const RagChatQueryInput = () => {
   });
 
   const handleChat = (userInput: string) => {
-    if (dataSourceId && dataSourceId > 0 && userInput.length > 0 && sessionId) {
+    if (
+      activeSession &&
+      activeSession.dataSourceIds.length > 0 &&
+      userInput.length > 0 &&
+      sessionId
+    ) {
       setCurrentQuestion(userInput);
       chatMutation.mutate({
         query: userInput,
-        data_source_id: dataSourceId.toString(),
+        data_source_ids: activeSession.dataSourceIds,
         session_id: sessionId,
         configuration: createQueryConfiguration(
           excludeKnowledgeBase,