Add ruff and mypy checks to CI #28

Merged · 3 commits · Nov 21, 2024

10 changes: 10 additions & 0 deletions .github/workflows/pr_build.yml
@@ -70,6 +70,16 @@ jobs:
pdm install
working-directory: llm-service

- name: Run ruff
run: |
pdm run ruff check app
working-directory: llm-service

- name: Run mypy
run: |
pdm run mypy app
working-directory: llm-service

- name: Test with pytest
run: |
pdm run pytest -sxvvra
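For contributors who want to run the same gates locally before pushing, a minimal sketch (assuming `pdm` and the project's dev dependencies are installed, and that it is run from the repository root, mirroring the workflow's `working-directory`):

```python
import subprocess

# Run the same checks CI runs; check=True raises CalledProcessError
# on a nonzero exit, so failures stop the script just as they fail the job.
for cmd in (
    ["pdm", "run", "ruff", "check", "app"],
    ["pdm", "run", "mypy", "app"],
):
    subprocess.run(cmd, cwd="llm-service", check=True)
```
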
38 changes: 0 additions & 38 deletions llm-service/__init__.py

This file was deleted.

9 changes: 4 additions & 5 deletions llm-service/app/exceptions.py
@@ -43,7 +43,7 @@
import inspect
import logging
from collections.abc import Callable, Iterator
from typing import ParamSpec, TypeVar
from typing import Awaitable, ParamSpec, TypeVar, Union

import requests
from fastapi import HTTPException
@@ -88,7 +88,7 @@ def _exception_propagation() -> Iterator[None]:
) from e


def propagates(f: Callable[P, T]) -> Callable[P, T]:
def propagates(f: Callable[P, T]) -> Union[Callable[P, T], Callable[P, Awaitable[T]]]:
"""
Function decorator for catching and propagating exceptions back to a client.

@@ -119,14 +119,13 @@ def banana():
"""

if inspect.iscoroutinefunction(f):

# for coroutines, the wrapper must be declared async,
# and the wrapped function's result must be awaited
@functools.wraps(f)
async def exception_propagation_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
with _exception_propagation():
return await f(*args, **kwargs)

ret: T = await f(*args, **kwargs)
return ret
else:

@functools.wraps(f)
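Pieced together from the hunks above, a minimal sketch of the decorator's overall shape after this change (the stub context manager stands in for the real `_exception_propagation`, which maps errors to `HTTPException`):

```python
import functools
import inspect
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from typing import Awaitable, ParamSpec, TypeVar, Union

P = ParamSpec("P")
T = TypeVar("T")

@contextmanager
def _exception_propagation() -> Iterator[None]:
    yield  # stub; the real version translates exceptions for the client

def propagates(f: Callable[P, T]) -> Union[Callable[P, T], Callable[P, Awaitable[T]]]:
    if inspect.iscoroutinefunction(f):
        # For coroutines the wrapper must be async and the result awaited,
        # so the context manager is still active when exceptions surface.
        @functools.wraps(f)
        async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            with _exception_propagation():
                ret: T = await f(*args, **kwargs)
                return ret

        return async_wrapper

    @functools.wraps(f)
    def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        with _exception_propagation():
            return f(*args, **kwargs)

    return sync_wrapper
```

The diff's two-line `ret: T = await f(...)` form, replacing a direct `return await f(...)`, presumably gives mypy an explicit annotation to bind the awaited result to.
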
7 changes: 4 additions & 3 deletions llm-service/app/main.py
@@ -42,6 +42,7 @@
import time
from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
@@ -79,7 +80,7 @@ def _configure_logger() -> None:


@functools.cache
def _get_app_log_handler():
def _get_app_log_handler() -> logging.Handler:
"""Format and return a reusable handler for application logging."""
# match Java backend's formatting
formatter = logging.Formatter(
@@ -115,7 +116,7 @@ def _configure_app_logger(app_logger: logging.Logger) -> None:
app_logger.setLevel(settings.rag_log_level)


def initialize_logging():
def initialize_logging() -> None:
logger.info("Initializing logging.")

_configure_app_logger(logging.getLogger("uvicorn.access"))
@@ -125,7 +126,7 @@ def initialize_logging():


@asynccontextmanager
async def lifespan(app: FastAPI):
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
initialize_logging()
yield

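For context on the `AsyncGenerator[None, None]` annotation: an `async def` function containing a bare `yield` is an async generator, and `@asynccontextmanager` turns it into the startup/shutdown hook FastAPI expects. A minimal sketch, assuming only FastAPI itself:

```python
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    print("startup")   # runs before the app starts serving requests
    yield              # the application serves traffic while suspended here
    print("shutdown")  # runs when the app shuts down

app = FastAPI(lifespan=lifespan)
```
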
4 changes: 0 additions & 4 deletions llm-service/app/routers/index/__init__.py
@@ -55,7 +55,3 @@
# include this for legacy UI calls
router.include_router(amp_update.router, prefix="/index", deprecated=True)
router.include_router(models.router)




22 changes: 18 additions & 4 deletions llm-service/app/routers/index/amp_update/__init__.py
@@ -43,22 +43,36 @@
from .... import exceptions
from ....services.amp_update import check_amp_update_status

router = APIRouter(prefix="/amp-update" , tags=["AMP Update"])
router = APIRouter(prefix="/amp-update", tags=["AMP Update"])


@router.get("", summary="Returns a boolean for whether AMP needs updating.")
@exceptions.propagates
def amp_up_to_date_status() -> bool:
return check_amp_update_status()


@router.post("", summary="Updates AMP.")
@exceptions.propagates
def update_amp() -> str:
print(subprocess.run(["python /home/cdsw/llm-service/scripts/run_refresh_job.py"], shell=True, check=True))
print(
subprocess.run(
["python /home/cdsw/llm-service/scripts/run_refresh_job.py"],
shell=True,
check=True,
)
)
return "OK"


@router.get("/job-status", summary="Get AMP Status.")
@exceptions.propagates
def get_amp_status() -> str:
process: CompletedProcess[bytes] = subprocess.run(["python /home/cdsw/llm-service/scripts/get_job_run_status.py"], shell=True, check=True, capture_output=True)
stdout = process.stdout.decode('utf-8')
process: CompletedProcess[bytes] = subprocess.run(
["python /home/cdsw/llm-service/scripts/get_job_run_status.py"],
shell=True,
check=True,
capture_output=True,
)
stdout = process.stdout.decode("utf-8")
return stdout.strip()
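
A side note on the `subprocess.run` calls above: with `shell=True`, only the first element of the argument list is used as the shell command string, which is why the single-element list works. A common alternative, sketched here with the same script path, is the argument-list form with no shell involved:

```python
import subprocess
from subprocess import CompletedProcess

# Argument-list form: the command and its arguments are passed through
# unmodified, with no shell parsing in between.
process: CompletedProcess[bytes] = subprocess.run(
    ["python", "/home/cdsw/llm-service/scripts/get_job_run_status.py"],
    check=True,           # raise CalledProcessError on a nonzero exit code
    capture_output=True,  # collect stdout/stderr instead of inheriting them
)
print(process.stdout.decode("utf-8").strip())
```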
11 changes: 3 additions & 8 deletions llm-service/app/routers/index/data_source/__init__.py
@@ -100,7 +100,6 @@ def delete_document(data_source_id: int, doc_id: str) -> None:
doc_summaries.delete_document(data_source_id, doc_id)



class RagIndexDocumentRequest(BaseModel):
s3_bucket_name: str
s3_document_key: str
@@ -116,17 +115,13 @@ class RagIndexDocumentRequest(BaseModel):
)
@exceptions.propagates
def download_and_index(
data_source_id: int,
request: RagIndexDocumentRequest,
data_source_id: int,
request: RagIndexDocumentRequest,
) -> str:
with tempfile.TemporaryDirectory() as tmpdirname:
logger.debug("created temporary directory %s", tmpdirname)
s3.download(tmpdirname, request.s3_bucket_name, request.s3_document_key)
qdrant.download_and_index(
tmpdirname,
data_source_id,
request.configuration,
request.s3_document_key
tmpdirname, data_source_id, request.configuration, request.s3_document_key
)
return http.HTTPStatus.OK.phrase
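
As background for the pattern above, `tempfile.TemporaryDirectory` removes the scratch directory and everything downloaded into it when the `with` block exits, even on an exception; a minimal standalone sketch (the file write stands in for the real `s3.download`):

```python
import logging
import tempfile
from pathlib import Path

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

with tempfile.TemporaryDirectory() as tmpdirname:
    logger.debug("created temporary directory %s", tmpdirname)
    # Stand-in for s3.download(...): put a file into the scratch space.
    (Path(tmpdirname) / "example.txt").write_text("downloaded content")
    # ... indexing would read from tmpdirname here ...
# On exit the directory and its contents are gone, error or not.
```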

8 changes: 4 additions & 4 deletions llm-service/app/routers/index/models/__init__.py
@@ -35,7 +35,7 @@
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
# DATA.
#
from typing import Literal
from typing import Any, Dict, List, Literal

from fastapi import APIRouter

@@ -54,13 +54,13 @@

@router.get("/llm", summary="Get LLM Inference models.")
@exceptions.propagates
def get_llm_models() -> list:
def get_llm_models() -> List[Dict[str, Any]]:
return get_available_llm_models()


@router.get("/embeddings", summary="Get LLM Embedding models.")
@exceptions.propagates
def get_llm_embedding_models() -> list:
def get_llm_embedding_models() -> List[Dict[str, Any]]:
return get_available_embedding_models()


@@ -79,4 +79,4 @@ def llm_model_test(model_name: str) -> Literal["ok"]:
@router.get("/embedding/{model_name}/test", summary="Test Embedding model.")
@exceptions.propagates
def embedding_model_test(model_name: str) -> str:
return test_embedding_model(model_name)
return test_embedding_model(model_name)
51 changes: 33 additions & 18 deletions llm-service/app/routers/index/sessions/__init__.py
@@ -39,27 +39,35 @@
import uuid

from fastapi import APIRouter

from pydantic import BaseModel

from .... import exceptions
from ....rag_types import RagPredictConfiguration
from ....services import llm_completion, qdrant
from ....services.chat import generate_suggested_questions, v2_chat
from ....services.chat_store import RagStudioChatMessage, chat_store
from ....services import qdrant, llm_completion
from ....services.chat import (v2_chat, generate_suggested_questions)

router = APIRouter(prefix="/sessions/{session_id}", tags=["Sessions"])

@router.get("/chat-history", summary="Returns an array of chat messages for the provided session.")

@router.get(
"/chat-history",
summary="Returns an array of chat messages for the provided session.",
)
@exceptions.propagates
def chat_history(session_id: int) -> list[RagStudioChatMessage]:
return chat_store.retrieve_chat_history(session_id=session_id)

@router.delete("/chat-history", summary="Deletes the chat history for the provided session.")

@router.delete(
"/chat-history", summary="Deletes the chat history for the provided session."
)
@exceptions.propagates
def clear_chat_history(session_id: int) -> str:
chat_store.clear_chat_history(session_id=session_id)
return "Chat history cleared."


@router.delete("", summary="Deletes the requested session.")
@exceptions.propagates
def delete_chat_history(session_id: int) -> str:
@@ -70,54 +78,61 @@ def delete_chat_history(session_id: int) -> str:
class RagStudioChatRequest(BaseModel):
data_source_id: int
query: str
configuration: qdrant.RagPredictConfiguration
configuration: RagPredictConfiguration


@router.post("/chat", summary="Chat with your documents in the requested datasource")
@exceptions.propagates
def chat(
session_id: int,
request: RagStudioChatRequest,
session_id: int,
request: RagStudioChatRequest,
) -> RagStudioChatMessage:
if request.configuration.exclude_knowledge_base:
return llm_talk(session_id, request)
return v2_chat(session_id, request.data_source_id, request.query, request.configuration)
return v2_chat(
session_id, request.data_source_id, request.query, request.configuration
)


def llm_talk(
session_id: int,
request: RagStudioChatRequest,
session_id: int,
request: RagStudioChatRequest,
) -> RagStudioChatMessage:
chat_response = llm_completion.completion(session_id, request.query, request.configuration)
chat_response = llm_completion.completion(
session_id, request.query, request.configuration
)
new_chat_message = RagStudioChatMessage(
id=str(uuid.uuid4()),
source_nodes=[],
evaluations=[],
rag_message={
"user": request.query,
"assistant": chat_response.message.content,
"assistant": str(chat_response.message.content),
},
timestamp=time.time()
timestamp=time.time(),
)
chat_store.append_to_history(session_id, [new_chat_message])
return new_chat_message


class SuggestQuestionsRequest(BaseModel):
data_source_id: int
configuration: qdrant.RagPredictConfiguration = qdrant.RagPredictConfiguration()
configuration: RagPredictConfiguration = RagPredictConfiguration()


class RagSuggestedQuestionsResponse(BaseModel):
suggested_questions: list[str]


@router.post("/suggest-questions", summary="Suggest questions with context")
@exceptions.propagates
def suggest_questions(
session_id: int,
request: SuggestQuestionsRequest,
session_id: int,
request: SuggestQuestionsRequest,
) -> RagSuggestedQuestionsResponse:
data_source_size = qdrant.size_of(request.data_source_id)
qdrant.check_data_source_exists(data_source_size)
suggested_questions = generate_suggested_questions(
request.configuration, request.data_source_id, data_source_size, session_id
)
return RagSuggestedQuestionsResponse(suggested_questions=suggested_questions)
return RagSuggestedQuestionsResponse(suggested_questions=suggested_questions)
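
One detail worth calling out in `SuggestQuestionsRequest`: using a model instance as a field default is safe in Pydantic, which copies defaults per instance rather than sharing one mutable object across requests. A minimal sketch with a stand-in configuration model:

```python
from pydantic import BaseModel

class RagPredictConfiguration(BaseModel):  # stand-in for the real config model
    exclude_knowledge_base: bool = False

class SuggestQuestionsRequest(BaseModel):
    data_source_id: int
    # Pydantic copies this default for every request object, so mutating
    # one request's configuration never leaks into another's.
    configuration: RagPredictConfiguration = RagPredictConfiguration()

req = SuggestQuestionsRequest(data_source_id=1)
print(req.configuration.exclude_knowledge_base)  # False
```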