-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
introduce engine for API #434
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
import inspect | ||
import uuid | ||
from typing import ( | ||
TYPE_CHECKING, | ||
Any, | ||
AsyncIterator, | ||
Awaitable, | ||
|
@@ -12,21 +13,24 @@ | |
Iterable, | ||
Iterator, | ||
Optional, | ||
Type, | ||
TypeVar, | ||
Union, | ||
cast, | ||
) | ||
|
||
import pydantic | ||
from fastapi import status | ||
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool | ||
|
||
from ._components import Assistant, Component, Message, MessageRole, SourceStorage | ||
from ._document import Document, LocalDocument | ||
from ._utils import RagnaException, default_user, merge_models | ||
|
||
if TYPE_CHECKING: | ||
from ragna.deploy import Config | ||
|
||
T = TypeVar("T") | ||
C = TypeVar("C", bound=Component) | ||
C = TypeVar("C", bound=Component, covariant=True) | ||
|
||
|
||
class Rag(Generic[C]): | ||
|
@@ -41,20 +45,69 @@ class Rag(Generic[C]): | |
``` | ||
""" | ||
|
||
def __init__(self) -> None: | ||
self._components: dict[Type[C], C] = {} | ||
def __init__( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Two new functionalities for the
We had this before, but only for the API. While refactoring, it made the code a lot cleaner when moving it here. |
||
self, | ||
*, | ||
config: Optional[Config] = None, | ||
ignore_unavailable_components: bool = False, | ||
) -> None: | ||
self._components: dict[type[C], C] = {} | ||
self._display_name_map: dict[str, type[C]] = {} | ||
|
||
if config is not None: | ||
self._preload_components( | ||
config=config, | ||
ignore_unavailable_components=ignore_unavailable_components, | ||
) | ||
|
||
def _preload_components( | ||
self, *, config: Config, ignore_unavailable_components: bool | ||
) -> None: | ||
for components in [config.source_storages, config.assistants]: | ||
components = cast(list[type[Component]], components) | ||
at_least_one = False | ||
for component in components: | ||
loaded_component = self._load_component( | ||
component, # type: ignore[arg-type] | ||
ignore_unavailable=ignore_unavailable_components, | ||
) | ||
if loaded_component is None: | ||
print( | ||
f"Ignoring {component.display_name()}, because it is not available." | ||
) | ||
else: | ||
at_least_one = True | ||
|
||
if not at_least_one: | ||
raise RagnaException( | ||
"No component available", | ||
components=[component.display_name() for component in components], | ||
) | ||
|
||
def _load_component( | ||
self, component: Union[Type[C], C], *, ignore_unavailable: bool = False | ||
self, component: Union[C, type[C], str], *, ignore_unavailable: bool = False | ||
) -> Optional[C]: | ||
cls: Type[C] | ||
cls: type[C] | ||
instance: Optional[C] | ||
|
||
if isinstance(component, Component): | ||
instance = cast(C, component) | ||
cls = type(instance) | ||
elif isinstance(component, type) and issubclass(component, Component): | ||
cls = component | ||
instance = None | ||
elif isinstance(component, str): | ||
try: | ||
cls = self._display_name_map[component] | ||
except KeyError: | ||
raise RagnaException( | ||
"Unknown component", | ||
display_name=component, | ||
help="Did you forget to create the Rag() instance with a config?", | ||
http_status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, | ||
http_detail=f"Unknown component '{component}'", | ||
) from None | ||
|
||
instance = None | ||
else: | ||
raise RagnaException | ||
|
@@ -71,31 +124,33 @@ def _load_component( | |
instance = cls() | ||
|
||
self._components[cls] = instance | ||
self._display_name_map[cls.display_name()] = cls | ||
|
||
return self._components[cls] | ||
|
||
def chat( | ||
self, | ||
*, | ||
documents: Iterable[Any], | ||
source_storage: Union[Type[SourceStorage], SourceStorage], | ||
assistant: Union[Type[Assistant], Assistant], | ||
source_storage: Union[SourceStorage, type[SourceStorage], str], | ||
assistant: Union[Assistant, type[Assistant], str], | ||
**params: Any, | ||
) -> Chat: | ||
"""Create a new [ragna.core.Chat][]. | ||
|
||
Args: | ||
documents: Documents to use. If any item is not a [ragna.core.Document][], | ||
[ragna.core.LocalDocument.from_path][] is invoked on it. | ||
FIXME | ||
source_storage: Source storage to use. | ||
assistant: Assistant to use. | ||
**params: Additional parameters passed to the source storage and assistant. | ||
""" | ||
return Chat( | ||
self, | ||
documents=documents, | ||
source_storage=source_storage, | ||
assistant=assistant, | ||
source_storage=cast(SourceStorage, self._load_component(source_storage)), # type: ignore[arg-type] | ||
assistant=cast(Assistant, self._load_component(assistant)), # type: ignore[arg-type] | ||
**params, | ||
) | ||
|
||
|
@@ -146,17 +201,15 @@ def __init__( | |
rag: Rag, | ||
*, | ||
documents: Iterable[Any], | ||
source_storage: Union[Type[SourceStorage], SourceStorage], | ||
assistant: Union[Type[Assistant], Assistant], | ||
source_storage: SourceStorage, | ||
assistant: Assistant, | ||
**params: Any, | ||
) -> None: | ||
self._rag = rag | ||
|
||
self.documents = self._parse_documents(documents) | ||
self.source_storage = cast( | ||
SourceStorage, self._rag._load_component(source_storage) | ||
) | ||
self.assistant = cast(Assistant, self._rag._load_component(assistant)) | ||
self.source_storage = source_storage | ||
self.assistant = assistant | ||
|
||
special_params = SpecialChatParams().model_dump() | ||
special_params.update(params) | ||
|
@@ -306,6 +359,6 @@ async def __aenter__(self) -> Chat: | |
return self | ||
|
||
async def __aexit__( | ||
self, exc_type: Type[Exception], exc: Exception, traceback: str | ||
self, exc_type: type[Exception], exc: Exception, traceback: str | ||
) -> None: | ||
pass |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
import uuid | ||
from typing import Annotated, AsyncIterator, cast | ||
|
||
import aiofiles | ||
import pydantic | ||
from fastapi import ( | ||
APIRouter, | ||
Body, | ||
Depends, | ||
Form, | ||
HTTPException, | ||
UploadFile, | ||
) | ||
from fastapi.responses import StreamingResponse | ||
|
||
import ragna | ||
import ragna.core | ||
from ragna._compat import anext | ||
from ragna.core._utils import default_user | ||
from ragna.deploy import Config | ||
|
||
from . import _schemas as schemas | ||
from ._engine import Engine | ||
|
||
|
||
def make_router(config: Config, engine: Engine) -> APIRouter: | ||
router = APIRouter(tags=["API"]) | ||
|
||
def get_user() -> str: | ||
return default_user() | ||
|
||
UserDependency = Annotated[str, Depends(get_user)] | ||
|
||
# TODO: the document endpoints do not go through the engine, because they'll change | ||
# quite drastically when the UI no longer depends on the API | ||
Comment on lines
+34
to
+35
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What the inline comment says. This will be addressed in the follow-up PR |
||
|
||
_database = engine._database | ||
|
||
@router.post("/document") | ||
async def create_document_upload_info( | ||
user: UserDependency, | ||
name: Annotated[str, Body(..., embed=True)], | ||
) -> schemas.DocumentUpload: | ||
with _database.get_session() as session: | ||
document = schemas.Document(name=name) | ||
metadata, parameters = await config.document.get_upload_info( | ||
config=config, user=user, id=document.id, name=document.name | ||
) | ||
document.metadata = metadata | ||
_database.add_document( | ||
session, user=user, document=document, metadata=metadata | ||
) | ||
return schemas.DocumentUpload(parameters=parameters, document=document) | ||
|
||
# TODO: Add UI support and documentation for this endpoint (#406) | ||
@router.post("/documents") | ||
async def create_documents_upload_info( | ||
user: UserDependency, | ||
names: Annotated[list[str], Body(..., embed=True)], | ||
) -> list[schemas.DocumentUpload]: | ||
with _database.get_session() as session: | ||
document_metadata_collection = [] | ||
document_upload_collection = [] | ||
for name in names: | ||
document = schemas.Document(name=name) | ||
metadata, parameters = await config.document.get_upload_info( | ||
config=config, user=user, id=document.id, name=document.name | ||
) | ||
document.metadata = metadata | ||
document_metadata_collection.append((document, metadata)) | ||
document_upload_collection.append( | ||
schemas.DocumentUpload(parameters=parameters, document=document) | ||
) | ||
|
||
_database.add_documents( | ||
session, | ||
user=user, | ||
document_metadata_collection=document_metadata_collection, | ||
) | ||
return document_upload_collection | ||
|
||
# TODO: Add new endpoint for batch uploading documents (#407) | ||
@router.put("/document") | ||
async def upload_document( | ||
token: Annotated[str, Form()], file: UploadFile | ||
) -> schemas.Document: | ||
if not issubclass(config.document, ragna.core.LocalDocument): | ||
raise HTTPException( | ||
status_code=400, | ||
detail="Ragna configuration does not support local upload", | ||
) | ||
with _database.get_session() as session: | ||
user, id = ragna.core.LocalDocument.decode_upload_token(token) | ||
document = _database.get_document(session, user=user, id=id) | ||
|
||
core_document = cast( | ||
ragna.core.LocalDocument, engine._to_core.document(document) | ||
) | ||
core_document.path.parent.mkdir(parents=True, exist_ok=True) | ||
async with aiofiles.open(core_document.path, "wb") as document_file: | ||
while content := await file.read(1024): | ||
await document_file.write(content) | ||
|
||
return document | ||
|
||
@router.get("/components") | ||
def get_components(_: UserDependency) -> schemas.Components: | ||
return engine.get_components() | ||
|
||
@router.post("/chats") | ||
async def create_chat( | ||
user: UserDependency, | ||
chat_metadata: schemas.ChatMetadata, | ||
) -> schemas.Chat: | ||
return engine.create_chat(user=user, chat_metadata=chat_metadata) | ||
|
||
@router.get("/chats") | ||
async def get_chats(user: UserDependency) -> list[schemas.Chat]: | ||
return engine.get_chats(user=user) | ||
|
||
@router.get("/chats/{id}") | ||
async def get_chat(user: UserDependency, id: uuid.UUID) -> schemas.Chat: | ||
return engine.get_chat(user=user, id=id) | ||
|
||
@router.post("/chats/{id}/prepare") | ||
async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message: | ||
return await engine.prepare_chat(user=user, id=id) | ||
|
||
@router.post("/chats/{id}/answer") | ||
async def answer( | ||
user: UserDependency, | ||
id: uuid.UUID, | ||
prompt: Annotated[str, Body(..., embed=True)], | ||
stream: Annotated[bool, Body(..., embed=True)] = False, | ||
) -> schemas.Message: | ||
message_stream = engine.answer_stream(user=user, chat_id=id, prompt=prompt) | ||
answer = await anext(message_stream) | ||
|
||
if not stream: | ||
content_chunks = [chunk.content async for chunk in message_stream] | ||
answer.content += "".join(content_chunks) | ||
return answer | ||
|
||
async def message_chunks() -> AsyncIterator[schemas.Message]: | ||
yield answer | ||
async for chunk in message_stream: | ||
yield chunk | ||
|
||
async def to_jsonl( | ||
models: AsyncIterator[pydantic.BaseModel], | ||
) -> AsyncIterator[str]: | ||
async for model in models: | ||
yield f"{model.model_dump_json()}\n" | ||
|
||
return StreamingResponse( # type: ignore[return-value] | ||
to_jsonl(message_chunks()) | ||
) | ||
|
||
@router.delete("/chats/{id}") | ||
async def delete_chat(user: UserDependency, id: uuid.UUID) -> None: | ||
engine.delete_chat(user=user, id=id) | ||
|
||
return router |
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It was kinda weird to not have these fields here given that we have them in the database as well as in the schema. Without them, converting back and forth became harder, because we'd need to keep track of this information manually.