Merge pull request #19 from cloudera/mob/main
Better model error handling
ewilliams-cloudera authored Nov 18, 2024
2 parents fa6a149 + 88ef181 commit a1f3b8a
Showing 12 changed files with 93 additions and 56 deletions.
21 changes: 9 additions & 12 deletions llm-service/app/routers/index/__init__.py
@@ -57,14 +57,21 @@

router = APIRouter(
    prefix="/index",
    tags=["index"],
)
router.include_router(data_source.router)
router.include_router(sessions.router)
router.include_router(amp_update.router)
router.include_router(models.router)


class SuggestQuestionsRequest(BaseModel):
    data_source_id: int
    chat_history: list[RagContext]
    configuration: qdrant.RagPredictConfiguration = qdrant.RagPredictConfiguration()

class RagSuggestedQuestionsResponse(BaseModel):
    suggested_questions: list[str]

class RagIndexDocumentRequest(BaseModel):
    data_source_id: int
    s3_bucket_name: str
@@ -81,7 +88,7 @@ class RagIndexDocumentRequest(BaseModel):
)
@exceptions.propagates
def download_and_index(
request: RagIndexDocumentRequest,
request: RagIndexDocumentRequest,
) -> str:
    with tempfile.TemporaryDirectory() as tmpdirname:
        logger.debug("created temporary directory %s", tmpdirname)
@@ -94,16 +101,6 @@ def download_and_index(
        )
    return http.HTTPStatus.OK.phrase


class SuggestQuestionsRequest(BaseModel):
    data_source_id: int
    chat_history: list[RagContext]
    configuration: qdrant.RagPredictConfiguration = qdrant.RagPredictConfiguration()

class RagSuggestedQuestionsResponse(BaseModel):
    suggested_questions: list[str]


@router.post("/suggest-questions", summary="Suggest questions with context")
@exceptions.propagates
def suggest_questions(
3 changes: 1 addition & 2 deletions llm-service/app/routers/index/amp_update/__init__.py
@@ -36,15 +36,14 @@
# DATA.
# ##############################################################################

import json
import subprocess

from fastapi import APIRouter
from subprocess import CompletedProcess
from .... import exceptions
from ....services.amp_update import check_amp_update_status

router = APIRouter(prefix="/amp-update")
router = APIRouter(prefix="/amp-update" , tags=["AMP Update"])

@router.get("", summary="Returns a boolean for whether AMP needs updating.")
@exceptions.propagates
2 changes: 1 addition & 1 deletion llm-service/app/routers/index/data_source/__init__.py
@@ -34,7 +34,7 @@
from .... import exceptions
from ....services import doc_summaries, qdrant

router = APIRouter(prefix="/data_sources/{data_source_id}")
router = APIRouter(prefix="/data_sources/{data_source_id}", tags=["Data Sources"])


class SummarizeDocumentRequest(BaseModel):
2 changes: 1 addition & 1 deletion llm-service/app/routers/index/models/__init__.py
@@ -49,7 +49,7 @@
    test_llm_model,
)

router = APIRouter(prefix="/models")
router = APIRouter(prefix="/models", tags=["Models"])


@router.get("/llm", summary="Get LLM Inference models.")
2 changes: 1 addition & 1 deletion llm-service/app/routers/index/sessions/__init__.py
@@ -41,7 +41,7 @@
from .... import exceptions
from ....services.chat_store import RagStudioChatMessage, chat_store

router = APIRouter(prefix="/sessions/{session_id}")
router = APIRouter(prefix="/sessions/{session_id}", tags=["Sessions"])

@router.get("/chat-history", summary="Returns an array of chat messages for the provided session.")
@exceptions.propagates
14 changes: 13 additions & 1 deletion llm-service/app/services/caii.py
@@ -109,7 +109,16 @@ def get_embedding_model() -> BaseEmbedding:
def get_caii_llm_models():
    domain = os.environ['CAII_DOMAIN']
    endpoint_name = os.environ['CAII_INFERENCE_ENDPOINT_NAME']
    models = describe_endpoint(domain=domain, endpoint_name=endpoint_name)
    try:
        models = describe_endpoint(domain=domain, endpoint_name=endpoint_name)
    except requests.exceptions.ConnectionError as e:
        print(e)
        raise HTTPException(status_code=421, detail = f"Unable to connect to host {domain}. Please check your CAII_DOMAIN env variable.")
    except HTTPException as e:
        if e.status_code == 404:
            return [{"model_id": endpoint_name}]
        else:
            raise e
    return build_model_response(models)

def get_caii_embedding_models():
@@ -120,6 +129,9 @@ def get_caii_embedding_models():
    endpoint_name = os.environ['CAII_EMBEDDING_ENDPOINT_NAME']
    try:
        models = describe_endpoint(domain=domain, endpoint_name=endpoint_name)
    except requests.exceptions.ConnectionError as e:
        print(e)
        raise HTTPException(status_code=421, detail = f"Unable to connect to host {domain}. Please check your CAII_DOMAIN env variable.")
    except HTTPException as e:
        if e.status_code == 404:
            return [{"model_id": endpoint_name}]
38 changes: 21 additions & 17 deletions ui/src/api/modelsApi.ts
@@ -35,13 +35,15 @@
* BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
* DATA.
******************************************************************************/
import { queryOptions, useQuery } from "@tanstack/react-query";
import { queryOptions, useMutation, useQuery } from "@tanstack/react-query";
import {
  ApiError,
  CustomError,
  getRequest,
  llmServicePath,
  MutationKeys,
  QueryKeys,
  UseMutationType,
} from "src/api/utils.ts";

export interface Model {
@@ -97,14 +99,15 @@ const getModelSource = async (): Promise<ModelSource> => {
  return await getRequest(`${llmServicePath}/index/models/model_source`);
};

export const useTestLlmModel = (model_id: string) => {
  return useQuery({
    queryKey: [QueryKeys.testLlmModel, { model_id }],
    queryFn: async () => {
      return await testLlmModel(model_id);
    },
    enabled: !!model_id,
    retry: false,
export const useTestLlmModel = ({
  onSuccess,
  onError,
}: UseMutationType<string>) => {
  return useMutation({
    mutationKey: [MutationKeys.testLlmModel],
    mutationFn: testLlmModel,
    onError,
    onSuccess,
  });
};

@@ -121,14 +124,15 @@ const testLlmModel = async (model_id: string): Promise<string> => {
  });
};

export const useTestEmbeddingModel = (model_id: string) => {
  return useQuery({
    queryKey: [QueryKeys.testEmbeddingModel, { model_id }],
    queryFn: async () => {
      return await testEmbeddingModel(model_id);
    },
    retry: false,
    enabled: !!model_id,
export const useTestEmbeddingModel = ({
  onSuccess,
  onError,
}: UseMutationType<string>) => {
  return useMutation({
    mutationKey: [MutationKeys.testEmbeddingModel],
    mutationFn: testEmbeddingModel,
    onError,
    onSuccess,
  });
};

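Note on usage: the test hooks above now return a TanStack Query mutation rather than a query, so callers trigger them imperatively instead of toggling an enabled flag. A minimal sketch of a consumer, assuming the useTestLlmModel hook defined above and TanStack Query v5 (mutate, isPending); the TestModelButton component, its modelId prop, and the antd message calls are illustrative and not part of this commit.

// Hypothetical consumer of the new mutation-based hook (not part of this commit).
import { Button, message } from "antd";
import { useTestLlmModel } from "src/api/modelsApi.ts";

const TestModelButton = ({ modelId }: { modelId: string }) => {
  const { mutate, isPending, data } = useTestLlmModel({
    // onSuccess receives the string returned by the test endpoint
    onSuccess: (result) => void message.success(`Model responded: ${String(result)}`),
    onError: (error) => void message.error(error.message),
  });

  return (
    <Button loading={isPending} onClick={() => { mutate(modelId); }}>
      {data ? "Re-test model" : "Test model"}
    </Button>
  );
};

export default TestModelButton;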
4 changes: 2 additions & 2 deletions ui/src/api/utils.ts
@@ -65,6 +65,8 @@ export enum MutationKeys {
"deleteChatHistory" = "deleteChatHistory",
"deleteSession" = "deleteSession",
"updateAmp" = "updateAmp",
"testLlmModel" = "testLlmModel",
"testEmbeddingModel" = "testEmbeddingModel",
}

export enum QueryKeys {
@@ -81,8 +83,6 @@ export enum QueryKeys {
"getLlmModels" = "getLlmModels",
"getEmbeddingModels" = "getEmbeddingModels",
"getModelSource" = "getModelSource",
"testLlmModel" = "testLlmModel",
"testEmbeddingModel" = "testEmbeddingModel",
}

export const commonHeaders = {
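The UseMutationType helper imported into modelsApi.ts lives elsewhere in utils.ts and is not shown in this diff; judging from the call sites above, it is presumably a small callback bundle along these lines (an assumption, not the actual definition):

// Assumed shape of UseMutationType<T> in src/api/utils.ts -- not shown in this
// diff; inferred from how useTestLlmModel and useTestEmbeddingModel use it.
export interface UseMutationType<T> {
  onSuccess?: (data: T) => void;
  onError?: (error: Error) => void;
}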
13 changes: 7 additions & 6 deletions ui/src/pages/Models/EmbeddingModelTable.tsx
@@ -38,26 +38,27 @@

import { Table, TableProps } from "antd";
import { Model, useTestEmbeddingModel } from "src/api/modelsApi.ts";
import { useState } from "react";
import { modelColumns, TestCell } from "pages/Models/ModelTable.tsx";

const EmbeddingModelTestCell = ({ model }: { model: Model }) => {
  const [testModel, setTestModel] = useState("");
  const {
    data: testResult,
    isLoading,
    isPending,
    error,
  } = useTestEmbeddingModel(testModel);
    mutate,
  } = useTestEmbeddingModel({
    onError: () => undefined,
  });

  const handleTestModel = () => {
    setTestModel(model.model_id);
    mutate(model.model_id);
  };

  return (
    <TestCell
      onClick={handleTestModel}
      model={model}
      loading={isLoading}
      loading={isPending}
      error={error}
      testResult={testResult}
    />
15 changes: 10 additions & 5 deletions ui/src/pages/Models/InferenceModelTable.tsx
@@ -38,22 +38,27 @@

import { Table, TableProps } from "antd";
import { Model, useTestLlmModel } from "src/api/modelsApi.ts";
import { useState } from "react";
import { modelColumns, TestCell } from "pages/Models/ModelTable.tsx";

const InferenceModelTestCell = ({ model }: { model: Model }) => {
  const [testModel, setTestModel] = useState("");
  const { data: testResult, isLoading, error } = useTestLlmModel(testModel);
  const {
    data: testResult,
    isPending,
    error,
    mutate,
  } = useTestLlmModel({
    onError: () => undefined,
  });

  const handleTestModel = () => {
    setTestModel(model.model_id);
    mutate(model.model_id);
  };

  return (
    <TestCell
      onClick={handleTestModel}
      model={model}
      loading={isLoading}
      loading={isPending}
      error={error}
      testResult={testResult}
    />
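Both table cells hand their mutation state to TestCell, which lives in pages/Models/ModelTable.tsx and is untouched by this commit. Inferred from the call sites above, its props are presumably close to the following sketch (an assumption, not the actual file):

// Assumed prop shape for TestCell in pages/Models/ModelTable.tsx, inferred from
// the EmbeddingModelTestCell and InferenceModelTestCell call sites above.
// Not part of this commit.
import { Model } from "src/api/modelsApi.ts";

export interface TestCellProps {
  onClick: () => void;
  model: Model;
  loading: boolean;
  error: Error | null;
  testResult?: string;
}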
32 changes: 27 additions & 5 deletions ui/src/pages/Models/ModelPage.tsx
@@ -36,19 +36,41 @@
* DATA.
******************************************************************************/

import { Flex, Typography } from "antd";
import { Alert, Flex, Typography } from "antd";
import EmbeddingModelTable from "pages/Models/EmbeddingModelTable.tsx";
import { useGetEmbeddingModels, useGetLlmModels } from "src/api/modelsApi.ts";
import InferenceModelTable from "pages/Models/InferenceModelTable.tsx";

const ModelPage = () => {
  const { data: embeddingModels, isLoading: areEmbeddingModelsLoading } =
    useGetEmbeddingModels();
  const { data: inferenceModels, isLoading: areInferenceModelsLoading } =
    useGetLlmModels();
  const {
    data: embeddingModels,
    isLoading: areEmbeddingModelsLoading,
    error: embeddingError,
  } = useGetEmbeddingModels();
  const {
    data: inferenceModels,
    isLoading: areInferenceModelsLoading,
    error: inferenceError,
  } = useGetLlmModels();

  return (
    <Flex vertical align="center">
      <div style={{ maxWidth: 800 }}>
        {inferenceError ? (
          <Alert
            style={{ margin: 10 }}
            message={`Inference model error: ${inferenceError.message}`}
            type="error"
          />
        ) : null}
        {embeddingError ? (
          <Alert
            style={{ margin: 10 }}
            message={`Embedding model error: ${embeddingError.message}`}
            type="error"
          />
        ) : null}
      </div>
      <Flex vertical style={{ width: "80%", maxWidth: 1000 }} gap={20}>
        <Typography.Title level={3}>Embedding Models</Typography.Title>
        <EmbeddingModelTable
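A possible follow-up, not part of this commit: the two Alert blocks in ModelPage differ only in their label and error object, so they could be folded into a small helper. A sketch, assuming nothing beyond the antd Alert API already used above:

// Illustrative refactor only -- not part of this commit.
import { Alert } from "antd";

const ModelErrorAlert = ({ label, error }: { label: string; error?: Error | null }) => {
  if (!error) {
    return null;
  }
  return (
    <Alert
      style={{ margin: 10 }}
      message={`${label} error: ${error.message}`}
      type="error"
    />
  );
};

export default ModelErrorAlert;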
3 changes: 0 additions & 3 deletions ui/src/pages/RagChatTab/RagChat.tsx
@@ -42,13 +42,10 @@ import ChatBodyController from pages/RagChatTab/ChatOutput/ChatMessages/ChatBod
import { useContext } from "react";
import { RagChatContext } from "pages/RagChatTab/State/RagChatContext.tsx";
import { RagChatHeader } from "pages/RagChatTab/Header/RagChatHeader.tsx";
import { QueryClient } from "@tanstack/react-query";

const { Footer, Content } = Layout;

const RagChat = () => {
  const queryClient = new QueryClient();
  console.log("RagChat.tsx: RagChat: queryClient: ", queryClient);
  const { dataSourceId, dataSources, activeSession } =
    useContext(RagChatContext);

