diff --git a/pyproject.toml b/pyproject.toml index f85a71a4..6df079a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -162,7 +162,7 @@ ignore_missing_imports = true [[tool.mypy.overrides]] module = [ - "ragna.deploy._api.orm", + "ragna.deploy._orm", ] # Our ORM schema doesn't really work with mypy. There are some other ways to define it # to play ball. We should do that in the future. diff --git a/ragna/core/_components.py b/ragna/core/_components.py index d237c1b8..2f987910 100644 --- a/ragna/core/_components.py +++ b/ragna/core/_components.py @@ -1,9 +1,11 @@ from __future__ import annotations import abc +import datetime import enum import functools import inspect +import uuid from typing import AsyncIterable, AsyncIterator, Iterator, Optional, Type, Union import pydantic @@ -157,6 +159,8 @@ def __init__( *, role: MessageRole = MessageRole.SYSTEM, sources: Optional[list[Source]] = None, + id: Optional[uuid.UUID] = None, + timestamp: Optional[datetime.datetime] = None, ) -> None: if isinstance(content, str): self._content: str = content @@ -166,6 +170,14 @@ def __init__( self.role = role self.sources = sources or [] + if id is None: + id = uuid.uuid4() + self.id = id + + if timestamp is None: + timestamp = datetime.datetime.utcnow() + self.timestamp = timestamp + async def __aiter__(self) -> AsyncIterator[str]: if hasattr(self, "_content"): yield self._content diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index 6cdff127..98c32a42 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -4,6 +4,7 @@ import inspect import uuid from typing import ( + TYPE_CHECKING, Any, AsyncIterator, Awaitable, @@ -12,21 +13,24 @@ Iterable, Iterator, Optional, - Type, TypeVar, Union, cast, ) import pydantic +from fastapi import status from starlette.concurrency import iterate_in_threadpool, run_in_threadpool from ._components import Assistant, Component, Message, MessageRole, SourceStorage from ._document import Document, LocalDocument from ._utils import RagnaException, default_user, merge_models +if TYPE_CHECKING: + from ragna.deploy import Config + T = TypeVar("T") -C = TypeVar("C", bound=Component) +C = TypeVar("C", bound=Component, covariant=True) class Rag(Generic[C]): @@ -41,13 +45,49 @@ class Rag(Generic[C]): ``` """ - def __init__(self) -> None: - self._components: dict[Type[C], C] = {} + def __init__( + self, + *, + config: Optional[Config] = None, + ignore_unavailable_components: bool = False, + ) -> None: + self._components: dict[type[C], C] = {} + self._display_name_map: dict[str, type[C]] = {} + + if config is not None: + self._preload_components( + config=config, + ignore_unavailable_components=ignore_unavailable_components, + ) + + def _preload_components( + self, *, config: Config, ignore_unavailable_components: bool + ) -> None: + for components in [config.source_storages, config.assistants]: + components = cast(list[type[Component]], components) + at_least_one = False + for component in components: + loaded_component = self._load_component( + component, # type: ignore[arg-type] + ignore_unavailable=ignore_unavailable_components, + ) + if loaded_component is None: + print( + f"Ignoring {component.display_name()}, because it is not available." + ) + else: + at_least_one = True + + if not at_least_one: + raise RagnaException( + "No component available", + components=[component.display_name() for component in components], + ) def _load_component( - self, component: Union[Type[C], C], *, ignore_unavailable: bool = False + self, component: Union[C, type[C], str], *, ignore_unavailable: bool = False ) -> Optional[C]: - cls: Type[C] + cls: type[C] instance: Optional[C] if isinstance(component, Component): @@ -55,6 +95,19 @@ def _load_component( cls = type(instance) elif isinstance(component, type) and issubclass(component, Component): cls = component + instance = None + elif isinstance(component, str): + try: + cls = self._display_name_map[component] + except KeyError: + raise RagnaException( + "Unknown component", + display_name=component, + help="Did you forget to create the Rag() instance with a config?", + http_status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + http_detail=f"Unknown component '{component}'", + ) from None + instance = None else: raise RagnaException @@ -71,6 +124,7 @@ def _load_component( instance = cls() self._components[cls] = instance + self._display_name_map[cls.display_name()] = cls return self._components[cls] @@ -78,8 +132,8 @@ def chat( self, *, documents: Iterable[Any], - source_storage: Union[Type[SourceStorage], SourceStorage], - assistant: Union[Type[Assistant], Assistant], + source_storage: Union[SourceStorage, type[SourceStorage], str], + assistant: Union[Assistant, type[Assistant], str], **params: Any, ) -> Chat: """Create a new [ragna.core.Chat][]. @@ -87,6 +141,7 @@ def chat( Args: documents: Documents to use. If any item is not a [ragna.core.Document][], [ragna.core.LocalDocument.from_path][] is invoked on it. + FIXME source_storage: Source storage to use. assistant: Assistant to use. **params: Additional parameters passed to the source storage and assistant. @@ -94,8 +149,8 @@ def chat( return Chat( self, documents=documents, - source_storage=source_storage, - assistant=assistant, + source_storage=cast(SourceStorage, self._load_component(source_storage)), # type: ignore[arg-type] + assistant=cast(Assistant, self._load_component(assistant)), # type: ignore[arg-type] **params, ) @@ -146,17 +201,15 @@ def __init__( rag: Rag, *, documents: Iterable[Any], - source_storage: Union[Type[SourceStorage], SourceStorage], - assistant: Union[Type[Assistant], Assistant], + source_storage: SourceStorage, + assistant: Assistant, **params: Any, ) -> None: self._rag = rag self.documents = self._parse_documents(documents) - self.source_storage = cast( - SourceStorage, self._rag._load_component(source_storage) - ) - self.assistant = cast(Assistant, self._rag._load_component(assistant)) + self.source_storage = source_storage + self.assistant = assistant special_params = SpecialChatParams().model_dump() special_params.update(params) @@ -306,6 +359,6 @@ async def __aenter__(self) -> Chat: return self async def __aexit__( - self, exc_type: Type[Exception], exc: Exception, traceback: str + self, exc_type: type[Exception], exc: Exception, traceback: str ) -> None: pass diff --git a/ragna/deploy/_api.py b/ragna/deploy/_api.py new file mode 100644 index 00000000..d3194064 --- /dev/null +++ b/ragna/deploy/_api.py @@ -0,0 +1,163 @@ +import uuid +from typing import Annotated, AsyncIterator, cast + +import aiofiles +import pydantic +from fastapi import ( + APIRouter, + Body, + Depends, + Form, + HTTPException, + UploadFile, +) +from fastapi.responses import StreamingResponse + +import ragna +import ragna.core +from ragna._compat import anext +from ragna.core._utils import default_user +from ragna.deploy import Config + +from . import _schemas as schemas +from ._engine import Engine + + +def make_router(config: Config, engine: Engine) -> APIRouter: + router = APIRouter(tags=["API"]) + + def get_user() -> str: + return default_user() + + UserDependency = Annotated[str, Depends(get_user)] + + # TODO: the document endpoints do not go through the engine, because they'll change + # quite drastically when the UI no longer depends on the API + + _database = engine._database + + @router.post("/document") + async def create_document_upload_info( + user: UserDependency, + name: Annotated[str, Body(..., embed=True)], + ) -> schemas.DocumentUpload: + with _database.get_session() as session: + document = schemas.Document(name=name) + metadata, parameters = await config.document.get_upload_info( + config=config, user=user, id=document.id, name=document.name + ) + document.metadata = metadata + _database.add_document( + session, user=user, document=document, metadata=metadata + ) + return schemas.DocumentUpload(parameters=parameters, document=document) + + # TODO: Add UI support and documentation for this endpoint (#406) + @router.post("/documents") + async def create_documents_upload_info( + user: UserDependency, + names: Annotated[list[str], Body(..., embed=True)], + ) -> list[schemas.DocumentUpload]: + with _database.get_session() as session: + document_metadata_collection = [] + document_upload_collection = [] + for name in names: + document = schemas.Document(name=name) + metadata, parameters = await config.document.get_upload_info( + config=config, user=user, id=document.id, name=document.name + ) + document.metadata = metadata + document_metadata_collection.append((document, metadata)) + document_upload_collection.append( + schemas.DocumentUpload(parameters=parameters, document=document) + ) + + _database.add_documents( + session, + user=user, + document_metadata_collection=document_metadata_collection, + ) + return document_upload_collection + + # TODO: Add new endpoint for batch uploading documents (#407) + @router.put("/document") + async def upload_document( + token: Annotated[str, Form()], file: UploadFile + ) -> schemas.Document: + if not issubclass(config.document, ragna.core.LocalDocument): + raise HTTPException( + status_code=400, + detail="Ragna configuration does not support local upload", + ) + with _database.get_session() as session: + user, id = ragna.core.LocalDocument.decode_upload_token(token) + document = _database.get_document(session, user=user, id=id) + + core_document = cast( + ragna.core.LocalDocument, engine._to_core.document(document) + ) + core_document.path.parent.mkdir(parents=True, exist_ok=True) + async with aiofiles.open(core_document.path, "wb") as document_file: + while content := await file.read(1024): + await document_file.write(content) + + return document + + @router.get("/components") + def get_components(_: UserDependency) -> schemas.Components: + return engine.get_components() + + @router.post("/chats") + async def create_chat( + user: UserDependency, + chat_metadata: schemas.ChatMetadata, + ) -> schemas.Chat: + return engine.create_chat(user=user, chat_metadata=chat_metadata) + + @router.get("/chats") + async def get_chats(user: UserDependency) -> list[schemas.Chat]: + return engine.get_chats(user=user) + + @router.get("/chats/{id}") + async def get_chat(user: UserDependency, id: uuid.UUID) -> schemas.Chat: + return engine.get_chat(user=user, id=id) + + @router.post("/chats/{id}/prepare") + async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message: + return await engine.prepare_chat(user=user, id=id) + + @router.post("/chats/{id}/answer") + async def answer( + user: UserDependency, + id: uuid.UUID, + prompt: Annotated[str, Body(..., embed=True)], + stream: Annotated[bool, Body(..., embed=True)] = False, + ) -> schemas.Message: + message_stream = engine.answer_stream(user=user, chat_id=id, prompt=prompt) + answer = await anext(message_stream) + + if not stream: + content_chunks = [chunk.content async for chunk in message_stream] + answer.content += "".join(content_chunks) + return answer + + async def message_chunks() -> AsyncIterator[schemas.Message]: + yield answer + async for chunk in message_stream: + yield chunk + + async def to_jsonl( + models: AsyncIterator[pydantic.BaseModel], + ) -> AsyncIterator[str]: + async for model in models: + yield f"{model.model_dump_json()}\n" + + return StreamingResponse( # type: ignore[return-value] + to_jsonl(message_chunks()) + ) + + @router.delete("/chats/{id}") + async def delete_chat(user: UserDependency, id: uuid.UUID) -> None: + engine.delete_chat(user=user, id=id) + + return router diff --git a/ragna/deploy/_api/__init__.py b/ragna/deploy/_api/__init__.py deleted file mode 100644 index f99fb828..00000000 --- a/ragna/deploy/_api/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .core import make_router diff --git a/ragna/deploy/_api/core.py b/ragna/deploy/_api/core.py deleted file mode 100644 index 9bb9c682..00000000 --- a/ragna/deploy/_api/core.py +++ /dev/null @@ -1,318 +0,0 @@ -import contextlib -import uuid -from typing import Annotated, Any, AsyncIterator, Iterator, Type, cast - -import aiofiles -from fastapi import ( - APIRouter, - Body, - Depends, - Form, - HTTPException, - UploadFile, - status, -) -from fastapi.responses import StreamingResponse -from pydantic import BaseModel - -import ragna -import ragna.core -from ragna._compat import aiter, anext -from ragna.core import Assistant, Component, Rag, RagnaException, SourceStorage -from ragna.core._rag import SpecialChatParams -from ragna.core._utils import default_user -from ragna.deploy import Config - -from . import database, schemas - - -def make_router(config: Config, ignore_unavailable_components: bool) -> APIRouter: - router = APIRouter(tags=["API"]) - - rag = Rag() # type: ignore[var-annotated] - components_map: dict[str, Component] = {} - for components in [config.source_storages, config.assistants]: - components = cast(list[Type[Component]], components) - at_least_one = False - for component in components: - loaded_component = rag._load_component( - component, ignore_unavailable=ignore_unavailable_components - ) - if loaded_component is None: - print( - f"Ignoring {component.display_name()}, because it is not available." - ) - else: - at_least_one = True - components_map[component.display_name()] = loaded_component - - if not at_least_one: - raise RagnaException( - "No component available", - components=[component.display_name() for component in components], - ) - - def get_component(display_name: str) -> Component: - component = components_map.get(display_name) - if component is None: - raise RagnaException( - "Unknown component", - display_name=display_name, - http_status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - http_detail=RagnaException.MESSAGE, - ) - - return component - - @router.get("/") - async def version() -> str: - return ragna.__version__ - - def get_user() -> str: - return default_user() - - UserDependency = Annotated[str, Depends(get_user)] - - def _get_component_json_schema( - component: Type[Component], - ) -> dict[str, dict[str, Any]]: - json_schema = component._protocol_model().model_json_schema() - # FIXME: there is likely a better way to exclude certain fields builtin in - # pydantic - for special_param in SpecialChatParams.model_fields: - if ( - "properties" in json_schema - and special_param in json_schema["properties"] - ): - del json_schema["properties"][special_param] - if "required" in json_schema and special_param in json_schema["required"]: - json_schema["required"].remove(special_param) - return json_schema - - @router.get("/components") - async def get_components(_: UserDependency) -> schemas.Components: - return schemas.Components( - documents=sorted(config.document.supported_suffixes()), - source_storages=[ - _get_component_json_schema(type(source_storage)) - for source_storage in components_map.values() - if isinstance(source_storage, SourceStorage) - ], - assistants=[ - _get_component_json_schema(type(assistant)) - for assistant in components_map.values() - if isinstance(assistant, Assistant) - ], - ) - - make_session = database.get_sessionmaker(config.database_url) - - @contextlib.contextmanager - def get_session() -> Iterator[database.Session]: - with make_session() as session: # type: ignore[attr-defined] - yield session - - @router.post("/document") - async def create_document_upload_info( - user: UserDependency, - name: Annotated[str, Body(..., embed=True)], - ) -> schemas.DocumentUpload: - with get_session() as session: - document = schemas.Document(name=name) - metadata, parameters = await config.document.get_upload_info( - config=config, user=user, id=document.id, name=document.name - ) - database.add_document( - session, user=user, document=document, metadata=metadata - ) - return schemas.DocumentUpload(parameters=parameters, document=document) - - # TODO: Add UI support and documentation for this endpoint (#406) - @router.post("/documents") - async def create_documents_upload_info( - user: UserDependency, - names: Annotated[list[str], Body(..., embed=True)], - ) -> list[schemas.DocumentUpload]: - with get_session() as session: - document_metadata_collection = [] - document_upload_collection = [] - for name in names: - document = schemas.Document(name=name) - metadata, parameters = await config.document.get_upload_info( - config=config, user=user, id=document.id, name=document.name - ) - document_metadata_collection.append((document, metadata)) - document_upload_collection.append( - schemas.DocumentUpload(parameters=parameters, document=document) - ) - - database.add_documents( - session, - user=user, - document_metadata_collection=document_metadata_collection, - ) - return document_upload_collection - - # TODO: Add new endpoint for batch uploading documents (#407) - @router.put("/document") - async def upload_document( - token: Annotated[str, Form()], file: UploadFile - ) -> schemas.Document: - if not issubclass(config.document, ragna.core.LocalDocument): - raise HTTPException( - status_code=400, - detail="Ragna configuration does not support local upload", - ) - with get_session() as session: - user, id = ragna.core.LocalDocument.decode_upload_token(token) - document, metadata = database.get_document(session, user=user, id=id) - - core_document = ragna.core.LocalDocument( - id=document.id, name=document.name, metadata=metadata - ) - core_document.path.parent.mkdir(parents=True, exist_ok=True) - async with aiofiles.open(core_document.path, "wb") as document_file: - while content := await file.read(1024): - await document_file.write(content) - - return document - - def schema_to_core_chat( - session: database.Session, *, user: str, chat: schemas.Chat - ) -> ragna.core.Chat: - core_chat = rag.chat( - documents=[ - config.document( - id=document.id, - name=document.name, - metadata=database.get_document( - session, - user=user, - id=document.id, - )[1], - ) - for document in chat.metadata.documents - ], - source_storage=get_component(chat.metadata.source_storage), # type: ignore[arg-type] - assistant=get_component(chat.metadata.assistant), # type: ignore[arg-type] - user=user, - chat_id=chat.id, - chat_name=chat.metadata.name, - **chat.metadata.params, - ) - # FIXME: We need to reconstruct the previous messages here. Right now this is - # not needed, because the chat itself never accesses past messages. However, - # if we implement a chat history feature, i.e. passing past messages to - # the assistant, this becomes crucial. - core_chat._messages = [] - core_chat._prepared = chat.prepared - - return core_chat - - @router.post("/chats") - async def create_chat( - user: UserDependency, - chat_metadata: schemas.ChatMetadata, - ) -> schemas.Chat: - with get_session() as session: - chat = schemas.Chat(metadata=chat_metadata) - - # Although we don't need the actual ragna.core.Chat object here, - # we use it to validate the documents and metadata. - schema_to_core_chat(session, user=user, chat=chat) - - database.add_chat(session, user=user, chat=chat) - return chat - - @router.get("/chats") - async def get_chats(user: UserDependency) -> list[schemas.Chat]: - with get_session() as session: - return database.get_chats(session, user=user) - - @router.get("/chats/{id}") - async def get_chat(user: UserDependency, id: uuid.UUID) -> schemas.Chat: - with get_session() as session: - return database.get_chat(session, user=user, id=id) - - @router.post("/chats/{id}/prepare") - async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message: - with get_session() as session: - chat = database.get_chat(session, user=user, id=id) - - core_chat = schema_to_core_chat(session, user=user, chat=chat) - - welcome = schemas.Message.from_core(await core_chat.prepare()) - - chat.prepared = True - chat.messages.append(welcome) - database.update_chat(session, user=user, chat=chat) - - return welcome - - @router.post("/chats/{id}/answer") - async def answer( - user: UserDependency, - id: uuid.UUID, - prompt: Annotated[str, Body(..., embed=True)], - stream: Annotated[bool, Body(..., embed=True)] = False, - ) -> schemas.Message: - with get_session() as session: - chat = database.get_chat(session, user=user, id=id) - chat.messages.append( - schemas.Message(content=prompt, role=ragna.core.MessageRole.USER) - ) - core_chat = schema_to_core_chat(session, user=user, chat=chat) - - core_answer = await core_chat.answer(prompt, stream=stream) - - if stream: - - async def message_chunks() -> AsyncIterator[BaseModel]: - core_answer_stream = aiter(core_answer) - content_chunk = await anext(core_answer_stream) - - answer = schemas.Message( - content=content_chunk, - role=core_answer.role, - sources=[ - schemas.Source.from_core(source) - for source in core_answer.sources - ], - ) - yield answer - - # Avoid sending the sources multiple times - answer_chunk = answer.model_copy(update=dict(sources=None)) - content_chunks = [answer_chunk.content] - async for content_chunk in core_answer_stream: - content_chunks.append(content_chunk) - answer_chunk.content = content_chunk - yield answer_chunk - - with get_session() as session: - answer.content = "".join(content_chunks) - chat.messages.append(answer) - database.update_chat(session, user=user, chat=chat) - - async def to_jsonl(models: AsyncIterator[Any]) -> AsyncIterator[str]: - async for model in models: - yield f"{model.model_dump_json()}\n" - - return StreamingResponse( # type: ignore[return-value] - to_jsonl(message_chunks()) - ) - else: - answer = schemas.Message.from_core(core_answer) - - with get_session() as session: - chat.messages.append(answer) - database.update_chat(session, user=user, chat=chat) - - return answer - - @router.delete("/chats/{id}") - async def delete_chat(user: UserDependency, id: uuid.UUID) -> None: - with get_session() as session: - database.delete_chat(session, user=user, id=id) - - return router diff --git a/ragna/deploy/_api/database.py b/ragna/deploy/_api/database.py deleted file mode 100644 index 2a61b048..00000000 --- a/ragna/deploy/_api/database.py +++ /dev/null @@ -1,270 +0,0 @@ -from __future__ import annotations - -import functools -import uuid -from typing import Any, Callable, Optional, cast -from urllib.parse import urlsplit - -from sqlalchemy import create_engine, select -from sqlalchemy.orm import Session, joinedload -from sqlalchemy.orm import sessionmaker as _sessionmaker - -from ragna.core import RagnaException - -from . import orm, schemas - - -def get_sessionmaker(database_url: str) -> Callable[[], Session]: - components = urlsplit(database_url) - if components.scheme == "sqlite": - connect_args = dict(check_same_thread=False) - else: - connect_args = dict() - engine = create_engine(database_url, connect_args=connect_args) - orm.Base.metadata.create_all(bind=engine) - return _sessionmaker(autocommit=False, autoflush=False, bind=engine) - - -@functools.lru_cache(maxsize=1024) -def _get_user_id(session: Session, username: str) -> uuid.UUID: - user: Optional[orm.User] = session.execute( - select(orm.User).where(orm.User.name == username) - ).scalar_one_or_none() - - if user is None: - # Add a new user if the current username is not registered yet. Since this is - # behind the authentication layer, we don't need any extra security here. - user = orm.User(id=uuid.uuid4(), name=username) - session.add(user) - session.commit() - - return cast(uuid.UUID, user.id) - - -def add_document( - session: Session, *, user: str, document: schemas.Document, metadata: dict[str, Any] -) -> None: - session.add( - orm.Document( - id=document.id, - user_id=_get_user_id(session, user), - name=document.name, - metadata_=metadata, - ) - ) - session.commit() - - -def add_documents( - session: Session, - *, - user: str, - document_metadata_collection: list[tuple[schemas.Document, dict[str, Any]]], -) -> None: - """ - Add multiple documents to the database. - - This function allows adding multiple documents at once by calling `add_all`. This is - important when there is non-negligible latency attached to each database operation. - """ - user_id = _get_user_id(session, user) - documents = [ - orm.Document( - id=document.id, - user_id=user_id, - name=document.name, - metadata_=metadata, - ) - for document, metadata in document_metadata_collection - ] - session.add_all(documents) - session.commit() - - -def _orm_to_schema_document(document: orm.Document) -> schemas.Document: - return schemas.Document(id=document.id, name=document.name) - - -@functools.lru_cache(maxsize=1024) -def get_document( - session: Session, *, user: str, id: uuid.UUID -) -> tuple[schemas.Document, dict[str, Any]]: - document = session.execute( - select(orm.Document).where( - (orm.Document.user_id == _get_user_id(session, user)) - & (orm.Document.id == id) - ) - ).scalar_one_or_none() - return _orm_to_schema_document(document), document.metadata_ - - -def add_chat(session: Session, *, user: str, chat: schemas.Chat) -> None: - document_ids = {document.id for document in chat.metadata.documents} - documents = ( - session.execute(select(orm.Document).where(orm.Document.id.in_(document_ids))) - .scalars() - .all() - ) - if len(documents) != len(document_ids): - raise RagnaException( - str(set(document_ids) - {document.id for document in documents}) - ) - session.add( - orm.Chat( - id=chat.id, - user_id=_get_user_id(session, user), - name=chat.metadata.name, - documents=documents, - source_storage=chat.metadata.source_storage, - assistant=chat.metadata.assistant, - params=chat.metadata.params, - prepared=chat.prepared, - ) - ) - session.commit() - - -def _orm_to_schema_chat(chat: orm.Chat) -> schemas.Chat: - documents = [ - schemas.Document(id=document.id, name=document.name) - for document in chat.documents - ] - messages = [ - schemas.Message( - id=message.id, - role=message.role, - content=message.content, - sources=[ - schemas.Source( - id=source.id, - document=_orm_to_schema_document(source.document), - location=source.location, - content=source.content, - num_tokens=source.num_tokens, - ) - for source in message.sources - ], - timestamp=message.timestamp, - ) - for message in chat.messages - ] - return schemas.Chat( - id=chat.id, - metadata=schemas.ChatMetadata( - name=chat.name, - documents=documents, - source_storage=chat.source_storage, - assistant=chat.assistant, - params=chat.params, - ), - messages=messages, - prepared=chat.prepared, - ) - - -def _select_chat(*, eager: bool = False) -> Any: - selector = select(orm.Chat) - if eager: - selector = selector.options( # type: ignore[attr-defined] - joinedload(orm.Chat.messages).joinedload(orm.Message.sources), - joinedload(orm.Chat.documents), - ) - return selector - - -def get_chats(session: Session, *, user: str) -> list[schemas.Chat]: - return [ - _orm_to_schema_chat(chat) - for chat in session.execute( - _select_chat(eager=True).where( - orm.Chat.user_id == _get_user_id(session, user) - ) - ) - .scalars() - .unique() - .all() - ] - - -def _get_orm_chat( - session: Session, *, user: str, id: uuid.UUID, eager: bool = False -) -> orm.Chat: - chat: Optional[orm.Chat] = ( - session.execute( - _select_chat(eager=eager).where( - (orm.Chat.id == id) & (orm.Chat.user_id == _get_user_id(session, user)) - ) - ) - .unique() - .scalar_one_or_none() - ) - if chat is None: - raise RagnaException() - return chat - - -def get_chat(session: Session, *, user: str, id: uuid.UUID) -> schemas.Chat: - return _orm_to_schema_chat(_get_orm_chat(session, user=user, id=id, eager=True)) - - -def _schema_to_orm_source(session: Session, source: schemas.Source) -> orm.Source: - orm_source: Optional[orm.Source] = session.execute( - select(orm.Source).where(orm.Source.id == source.id) - ).scalar_one_or_none() - - if orm_source is None: - orm_source = orm.Source( - id=source.id, - document_id=source.document.id, - location=source.location, - content=source.content, - num_tokens=source.num_tokens, - ) - session.add(orm_source) - session.commit() - session.refresh(orm_source) - - return orm_source - - -def _schema_to_orm_message( - session: Session, chat_id: uuid.UUID, message: schemas.Message -) -> orm.Message: - orm_message: Optional[orm.Message] = session.execute( - select(orm.Message).where(orm.Message.id == message.id) - ).scalar_one_or_none() - if orm_message is None: - orm_message = orm.Message( - id=message.id, - chat_id=chat_id, - content=message.content, - role=message.role, - sources=[ - _schema_to_orm_source(session, source=source) - for source in message.sources - ], - timestamp=message.timestamp, - ) - session.add(orm_message) - session.commit() - session.refresh(orm_message) - - return orm_message - - -def update_chat(session: Session, user: str, chat: schemas.Chat) -> None: - orm_chat = _get_orm_chat(session, user=user, id=chat.id) - - orm_chat.prepared = chat.prepared - orm_chat.messages = [ # type: ignore[assignment] - _schema_to_orm_message(session, chat_id=chat.id, message=message) - for message in chat.messages - ] - - session.commit() - - -def delete_chat(session: Session, user: str, id: uuid.UUID) -> None: - orm_chat = _get_orm_chat(session, user=user, id=id) - session.delete(orm_chat) # type: ignore[no-untyped-call] - session.commit() diff --git a/ragna/deploy/_core.py b/ragna/deploy/_core.py index fa516e13..4bc37477 100644 --- a/ragna/deploy/_core.py +++ b/ragna/deploy/_core.py @@ -9,6 +9,7 @@ from ._api import make_router as make_api_router from ._config import Config +from ._engine import Engine from ._ui import app as make_ui_app from ._utils import handle_localhost_origins, redirect, set_redirect_root_path @@ -33,14 +34,13 @@ def make_app( allow_headers=["*"], ) + engine = Engine( + config=config, + ignore_unavailable_components=ignore_unavailable_components, + ) + if api: - app.include_router( - make_api_router( - config, - ignore_unavailable_components=ignore_unavailable_components, - ), - prefix="/api", - ) + app.include_router(make_api_router(config, engine), prefix="/api") if ui: panel_app = make_ui_app(config=config) diff --git a/ragna/deploy/_database.py b/ragna/deploy/_database.py new file mode 100644 index 00000000..323ccd21 --- /dev/null +++ b/ragna/deploy/_database.py @@ -0,0 +1,274 @@ +from __future__ import annotations + +import uuid +from typing import Any, Optional +from urllib.parse import urlsplit + +from sqlalchemy import create_engine, select +from sqlalchemy.orm import Session, joinedload, sessionmaker + +from ragna.core import RagnaException + +from . import _orm as orm +from . import _schemas as schemas + + +class Database: + def __init__(self, url: str) -> None: + components = urlsplit(url) + if components.scheme == "sqlite": + connect_args = dict(check_same_thread=False) + else: + connect_args = dict() + engine = create_engine(url, connect_args=connect_args) + orm.Base.metadata.create_all(bind=engine) + + self.get_session = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + self._to_orm = SchemaToOrmConverter() + self._to_schema = OrmToSchemaConverter() + + def _get_user(self, session: Session, *, username: str) -> orm.User: + user: Optional[orm.User] = session.execute( + select(orm.User).where(orm.User.name == username) + ).scalar_one_or_none() + + if user is None: + # Add a new user if the current username is not registered yet. Since this + # is behind the authentication layer, we don't need any extra security here. + user = orm.User(id=uuid.uuid4(), name=username) + session.add(user) + session.commit() + + return user + + def add_document( + self, + session: Session, + *, + user: str, + document: schemas.Document, + metadata: dict[str, Any], + ) -> None: + session.add( + orm.Document( + id=document.id, + user_id=self._get_user(session, username=user).id, + name=document.name, + metadata_=metadata, + ) + ) + session.commit() + + def add_documents( + self, + session: Session, + *, + user: str, + document_metadata_collection: list[tuple[schemas.Document, dict[str, Any]]], + ) -> None: + """ + Add multiple documents to the database. + + This function allows adding multiple documents at once by calling `add_all`. This is + important when there is non-negligible latency attached to each database operation. + """ + documents = [ + orm.Document( + id=document.id, + user_id=self._get_user(session, username=user).id, + name=document.name, + metadata_=metadata, + ) + for document, metadata in document_metadata_collection + ] + session.add_all(documents) + session.commit() + + def get_document( + self, session: Session, *, user: str, id: uuid.UUID + ) -> schemas.Document: + document = session.execute( + select(orm.Document).where( + (orm.Document.user_id == self._get_user(session, username=user).id) + & (orm.Document.id == id) + ) + ).scalar_one_or_none() + return self._to_schema.document(document) + + def add_chat(self, session: Session, *, user: str, chat: schemas.Chat) -> None: + document_ids = {document.id for document in chat.metadata.documents} + # FIXME also check if the user is allowed to access the documents? + documents = ( + session.execute( + select(orm.Document).where(orm.Document.id.in_(document_ids)) + ) + .scalars() + .all() + ) + if len(documents) != len(document_ids): + raise RagnaException( + str(document_ids - {document.id for document in documents}) + ) + + orm_chat = self._to_orm.chat( + chat, + user_id=self._get_user(session, username=user).id, + # We have to pass the documents here, because SQLAlchemy does not allow a + # second instance of orm.Document with the same primary key in the session. + documents=documents, + ) + session.add(orm_chat) + session.commit() + + def _select_chat(self, *, eager: bool = False) -> Any: + selector = select(orm.Chat) + if eager: + selector = selector.options( # type: ignore[attr-defined] + joinedload(orm.Chat.messages).joinedload(orm.Message.sources), + joinedload(orm.Chat.documents), + ) + return selector + + def get_chats(self, session: Session, *, user: str) -> list[schemas.Chat]: + return [ + self._to_schema.chat(chat) + for chat in session.execute( + self._select_chat(eager=True).where( + orm.Chat.user_id == self._get_user(session, username=user).id + ) + ) + .scalars() + .unique() + .all() + ] + + def _get_orm_chat( + self, session: Session, *, user: str, id: uuid.UUID, eager: bool = False + ) -> orm.Chat: + chat: Optional[orm.Chat] = ( + session.execute( + self._select_chat(eager=eager).where( + (orm.Chat.id == id) + & (orm.Chat.user_id == self._get_user(session, username=user).id) + ) + ) + .unique() + .scalar_one_or_none() + ) + if chat is None: + raise RagnaException() + return chat + + def get_chat(self, session: Session, *, user: str, id: uuid.UUID) -> schemas.Chat: + return self._to_schema.chat( + (self._get_orm_chat(session, user=user, id=id, eager=True)) + ) + + def update_chat(self, session: Session, user: str, chat: schemas.Chat) -> None: + orm_chat = self._to_orm.chat( + chat, user_id=self._get_user(session, username=user).id + ) + session.merge(orm_chat) + session.commit() + + def delete_chat(self, session: Session, user: str, id: uuid.UUID) -> None: + orm_chat = self._get_orm_chat(session, user=user, id=id) + session.delete(orm_chat) # type: ignore[no-untyped-call] + session.commit() + + +class SchemaToOrmConverter: + def document( + self, document: schemas.Document, *, user_id: uuid.UUID + ) -> orm.Document: + return orm.Document( + id=document.id, + user_id=user_id, + name=document.name, + metadata_=document.metadata, + ) + + def source(self, source: schemas.Source) -> orm.Source: + return orm.Source( + id=source.id, + document_id=source.document.id, + location=source.location, + content=source.content, + num_tokens=source.num_tokens, + ) + + def message(self, message: schemas.Message, *, chat_id: uuid.UUID) -> orm.Message: + return orm.Message( + id=message.id, + chat_id=chat_id, + content=message.content, + role=message.role, + sources=[self.source(source) for source in message.sources], + timestamp=message.timestamp, + ) + + def chat( + self, + chat: schemas.Chat, + *, + user_id: uuid.UUID, + documents: Optional[list[orm.Document]] = None, + ) -> orm.Chat: + if documents is None: + documents = [ + self.document(document, user_id=user_id) + for document in chat.metadata.documents + ] + return orm.Chat( + id=chat.id, + user_id=user_id, + name=chat.metadata.name, + documents=documents, + source_storage=chat.metadata.source_storage, + assistant=chat.metadata.assistant, + params=chat.metadata.params, + messages=[ + self.message(message, chat_id=chat.id) for message in chat.messages + ], + prepared=chat.prepared, + ) + + +class OrmToSchemaConverter: + def document(self, document: orm.Document) -> schemas.Document: + return schemas.Document( + id=document.id, name=document.name, metadata=document.metadata_ + ) + + def source(self, source: orm.Source) -> schemas.Source: + return schemas.Source( + id=source.id, + document=self.document(source.document), + location=source.location, + content=source.content, + num_tokens=source.num_tokens, + ) + + def message(self, message: orm.Message) -> schemas.Message: + return schemas.Message( + id=message.id, + role=message.role, # type: ignore[arg-type] + content=message.content, + sources=[self.source(source) for source in message.sources], + timestamp=message.timestamp, + ) + + def chat(self, chat: orm.Chat) -> schemas.Chat: + return schemas.Chat( + id=chat.id, + metadata=schemas.ChatMetadata( + name=chat.name, + documents=[self.document(document) for document in chat.documents], + source_storage=chat.source_storage, + assistant=chat.assistant, + params=chat.params, + ), + messages=[self.message(message) for message in chat.messages], + prepared=chat.prepared, + ) diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py new file mode 100644 index 00000000..847f7a93 --- /dev/null +++ b/ragna/deploy/_engine.py @@ -0,0 +1,205 @@ +import uuid +from typing import Any, AsyncIterator, Optional, Type + +from ragna import Rag, core +from ragna._compat import aiter, anext +from ragna.core._rag import SpecialChatParams +from ragna.deploy import Config + +from . import _schemas as schemas +from ._database import Database + + +class Engine: + def __init__(self, *, config: Config, ignore_unavailable_components: bool) -> None: + self._config = config + + self._database = Database(url=config.database_url) + + self._rag: Rag = Rag( + config=config, + ignore_unavailable_components=ignore_unavailable_components, + ) + + self._to_core = SchemaToCoreConverter(config=config, rag=self._rag) + self._to_schema = CoreToSchemaConverter() + + def _get_component_json_schema( + self, + component: Type[core.Component], + ) -> dict[str, dict[str, Any]]: + json_schema = component._protocol_model().model_json_schema() + # FIXME: there is likely a better way to exclude certain fields builtin in + # pydantic + for special_param in SpecialChatParams.model_fields: + if ( + "properties" in json_schema + and special_param in json_schema["properties"] + ): + del json_schema["properties"][special_param] + if "required" in json_schema and special_param in json_schema["required"]: + json_schema["required"].remove(special_param) + return json_schema + + def get_components(self) -> schemas.Components: + return schemas.Components( + documents=sorted(self._config.document.supported_suffixes()), + source_storages=[ + self._get_component_json_schema(source_storage) + for source_storage in self._rag._components.keys() + if issubclass(source_storage, core.SourceStorage) + ], + assistants=[ + self._get_component_json_schema(assistant) + for assistant in self._rag._components.keys() + if issubclass(assistant, core.Assistant) + ], + ) + + def create_chat( + self, *, user: str, chat_metadata: schemas.ChatMetadata + ) -> schemas.Chat: + chat = schemas.Chat(metadata=chat_metadata) + + # Although we don't need the actual core.Chat here, this just performs the input + # validation. + self._to_core.chat(chat, user=user) + + with self._database.get_session() as session: + self._database.add_chat(session, user=user, chat=chat) + + return chat + + def get_chats(self, *, user: str) -> list[schemas.Chat]: + with self._database.get_session() as session: + return self._database.get_chats(session, user=user) + + def get_chat(self, *, user: str, id: uuid.UUID) -> schemas.Chat: + with self._database.get_session() as session: + return self._database.get_chat(session, user=user, id=id) + + async def prepare_chat(self, *, user: str, id: uuid.UUID) -> schemas.Message: + core_chat = self._to_core.chat(self.get_chat(user=user, id=id), user=user) + core_message = await core_chat.prepare() + + with self._database.get_session() as session: + self._database.update_chat( + session, chat=self._to_schema.chat(core_chat), user=user + ) + + return self._to_schema.message(core_message) + + async def answer_stream( + self, *, user: str, chat_id: uuid.UUID, prompt: str + ) -> AsyncIterator[schemas.Message]: + core_chat = self._to_core.chat(self.get_chat(user=user, id=chat_id), user=user) + core_message = await core_chat.answer(prompt, stream=True) + + content_stream = aiter(core_message) + content_chunk = await anext(content_stream) + message = self._to_schema.message(core_message, content_override=content_chunk) + yield message + + # Avoid sending the sources multiple times + message_chunk = message.model_copy(update=dict(sources=None)) + async for content_chunk in content_stream: + message_chunk.content = content_chunk + yield message_chunk + + with self._database.get_session() as session: + self._database.update_chat( + session, chat=self._to_schema.chat(core_chat), user=user + ) + + def delete_chat(self, *, user: str, id: uuid.UUID) -> None: + with self._database.get_session() as session: + self._database.delete_chat(session, user=user, id=id) + + +class SchemaToCoreConverter: + def __init__(self, *, config: Config, rag: Rag) -> None: + self._config = config + self._rag = rag + + def document(self, document: schemas.Document) -> core.Document: + return self._config.document( + id=document.id, + name=document.name, + metadata=document.metadata, + ) + + def source(self, source: schemas.Source) -> core.Source: + return core.Source( + id=source.id, + document=self.document(source.document), + location=source.location, + content=source.content, + num_tokens=source.num_tokens, + ) + + def message(self, message: schemas.Message) -> core.Message: + return core.Message( + message.content, + role=message.role, + sources=[self.source(source) for source in message.sources], + ) + + def chat(self, chat: schemas.Chat, *, user: str) -> core.Chat: + core_chat = self._rag.chat( + documents=[self.document(document) for document in chat.metadata.documents], + source_storage=chat.metadata.source_storage, + assistant=chat.metadata.assistant, + user=user, + chat_id=chat.id, + chat_name=chat.metadata.name, + **chat.metadata.params, + ) + core_chat._messages = [self.message(message) for message in chat.messages] + core_chat._prepared = chat.prepared + + return core_chat + + +class CoreToSchemaConverter: + def document(self, document: core.Document) -> schemas.Document: + return schemas.Document( + id=document.id, + name=document.name, + metadata=document.metadata, + ) + + def source(self, source: core.Source) -> schemas.Source: + return schemas.Source( + id=source.id, + document=self.document(source.document), + location=source.location, + content=source.content, + num_tokens=source.num_tokens, + ) + + def message( + self, message: core.Message, *, content_override: Optional[str] = None + ) -> schemas.Message: + return schemas.Message( + id=message.id, + content=content_override or message.content, + role=message.role, + sources=[self.source(source) for source in message.sources], + timestamp=message.timestamp, + ) + + def chat(self, chat: core.Chat) -> schemas.Chat: + params = chat.params.copy() + del params["user"] + return schemas.Chat( + id=params.pop("chat_id"), + metadata=schemas.ChatMetadata( + name=params.pop("chat_name"), + source_storage=chat.source_storage.display_name(), + assistant=chat.assistant.display_name(), + params=params, + documents=[self.document(document) for document in chat.documents], + ), + messages=[self.message(message) for message in chat._messages], + prepared=chat._prepared, + ) diff --git a/ragna/deploy/_api/orm.py b/ragna/deploy/_orm.py similarity index 100% rename from ragna/deploy/_api/orm.py rename to ragna/deploy/_orm.py diff --git a/ragna/deploy/_api/schemas.py b/ragna/deploy/_schemas.py similarity index 62% rename from ragna/deploy/_api/schemas.py rename to ragna/deploy/_schemas.py index 53957a74..55ae333f 100644 --- a/ragna/deploy/_api/schemas.py +++ b/ragna/deploy/_schemas.py @@ -18,13 +18,7 @@ class Components(BaseModel): class Document(BaseModel): id: uuid.UUID = Field(default_factory=uuid.uuid4) name: str - - @classmethod - def from_core(cls, document: ragna.core.Document) -> Document: - return cls( - id=document.id, - name=document.name, - ) + metadata: dict[str, Any] = Field(default_factory=dict) class DocumentUpload(BaseModel): @@ -40,16 +34,6 @@ class Source(BaseModel): content: str num_tokens: int - @classmethod - def from_core(cls, source: ragna.core.Source) -> Source: - return cls( - id=source.id, - document=Document.from_core(source.document), - location=source.location, - content=source.content, - num_tokens=source.num_tokens, - ) - class Message(BaseModel): id: uuid.UUID = Field(default_factory=uuid.uuid4) @@ -58,14 +42,6 @@ class Message(BaseModel): sources: list[Source] = Field(default_factory=list) timestamp: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) - @classmethod - def from_core(cls, message: ragna.core.Message) -> Message: - return cls( - content=message.content, - role=message.role, - sources=[Source.from_core(source) for source in message.sources], - ) - class ChatMetadata(BaseModel): name: str diff --git a/tests/deploy/api/conftest.py b/tests/deploy/api/conftest.py new file mode 100644 index 00000000..4bc8053c --- /dev/null +++ b/tests/deploy/api/conftest.py @@ -0,0 +1,41 @@ +import contextlib +import json + +import httpx +import pytest + + +@pytest.fixture(scope="package", autouse=True) +def enhance_raise_for_status(package_mocker): + raise_for_status = httpx.Response.raise_for_status + + def enhanced_raise_for_status(self): + __tracebackhide__ = True + + try: + return raise_for_status(self) + except httpx.HTTPStatusError as error: + content = None + with contextlib.suppress(Exception): + content = error.response.read() + content = content.decode() + content = "\n" + json.dumps(json.loads(content), indent=2) + + if content is None: + raise error + + message = f"{error}\nResponse content: {content}" + raise httpx.HTTPStatusError( + message, request=error.request, response=error.response + ) from None + + yield package_mocker.patch( + ".".join( + [ + httpx.Response.__module__, + httpx.Response.__name__, + raise_for_status.__name__, + ] + ), + new=enhanced_raise_for_status, + )