Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

introduce engine for API #434

Merged
merged 4 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ ignore_missing_imports = true

[[tool.mypy.overrides]]
module = [
"ragna.deploy._api.orm",
"ragna.deploy._orm",
]
# Our ORM schema doesn't really work with mypy. There are some other ways to define it
# to play ball. We should do that in the future.
Expand Down
12 changes: 12 additions & 0 deletions ragna/core/_components.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import abc
import datetime
import enum
import functools
import inspect
import uuid
from typing import AsyncIterable, AsyncIterator, Iterator, Optional, Type, Union

import pydantic
Expand Down Expand Up @@ -157,6 +159,8 @@ def __init__(
*,
role: MessageRole = MessageRole.SYSTEM,
sources: Optional[list[Source]] = None,
id: Optional[uuid.UUID] = None,
timestamp: Optional[datetime.datetime] = None,
Comment on lines +162 to +163
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was somewhat odd not to have these fields here, given that we have them in the database as well as in the schema. Without them, converting back and forth was harder, because we would need to keep track of this information manually.

) -> None:
if isinstance(content, str):
self._content: str = content
Expand All @@ -166,6 +170,14 @@ def __init__(
self.role = role
self.sources = sources or []

if id is None:
id = uuid.uuid4()
self.id = id

if timestamp is None:
timestamp = datetime.datetime.utcnow()
self.timestamp = timestamp

async def __aiter__(self) -> AsyncIterator[str]:
if hasattr(self, "_content"):
yield self._content
Expand Down
87 changes: 70 additions & 17 deletions ragna/core/_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import inspect
import uuid
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
Expand All @@ -12,21 +13,24 @@
Iterable,
Iterator,
Optional,
Type,
TypeVar,
Union,
cast,
)

import pydantic
from fastapi import status
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool

from ._components import Assistant, Component, Message, MessageRole, SourceStorage
from ._document import Document, LocalDocument
from ._utils import RagnaException, default_user, merge_models

if TYPE_CHECKING:
from ragna.deploy import Config

T = TypeVar("T")
C = TypeVar("C", bound=Component)
C = TypeVar("C", bound=Component, covariant=True)


class Rag(Generic[C]):
Expand All @@ -41,20 +45,69 @@ class Rag(Generic[C]):
```
"""

def __init__(self) -> None:
self._components: dict[Type[C], C] = {}
def __init__(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two new features for the ragna.Rag object:

  1. When creating a new chat with Rag().chat(), one can now also pass the display name of any loaded component instead of an instance or the class.
  2. The object can now be created with a configuration object, which loads all configured source storages and assistants.

We had this before, but only for the API. While refactoring, it made the code a lot cleaner when moving it here.

self,
*,
config: Optional[Config] = None,
ignore_unavailable_components: bool = False,
) -> None:
self._components: dict[type[C], C] = {}
self._display_name_map: dict[str, type[C]] = {}

if config is not None:
self._preload_components(
config=config,
ignore_unavailable_components=ignore_unavailable_components,
)

def _preload_components(
    self, *, config: Config, ignore_unavailable_components: bool
) -> None:
    """Load every source storage and assistant listed in the configuration.

    The two component groups, ``config.source_storages`` and
    ``config.assistants``, are processed one after the other. Each component is
    handed to ``self._load_component``; for every component for which that call
    returns ``None`` a notice is printed.

    Args:
        config: Configuration providing the component classes to load.
        ignore_unavailable_components: Forwarded to ``self._load_component`` as
            ``ignore_unavailable``.

    Raises:
        RagnaException: If not a single component of a group could be loaded.
    """
    component_groups = [config.source_storages, config.assistants]
    for group in component_groups:
        candidates = cast(list[type[Component]], group)
        loaded_count = 0
        for candidate in candidates:
            result = self._load_component(
                candidate,  # type: ignore[arg-type]
                ignore_unavailable=ignore_unavailable_components,
            )
            if result is not None:
                loaded_count += 1
                continue

            print(
                f"Ignoring {candidate.display_name()}, because it is not available."
            )

        # A group without any usable component leaves the Rag object unusable.
        if loaded_count == 0:
            raise RagnaException(
                "No component available",
                components=[candidate.display_name() for candidate in candidates],
            )

def _load_component(
self, component: Union[Type[C], C], *, ignore_unavailable: bool = False
self, component: Union[C, type[C], str], *, ignore_unavailable: bool = False
) -> Optional[C]:
cls: Type[C]
cls: type[C]
instance: Optional[C]

if isinstance(component, Component):
instance = cast(C, component)
cls = type(instance)
elif isinstance(component, type) and issubclass(component, Component):
cls = component
instance = None
elif isinstance(component, str):
try:
cls = self._display_name_map[component]
except KeyError:
raise RagnaException(
"Unknown component",
display_name=component,
help="Did you forget to create the Rag() instance with a config?",
http_status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
http_detail=f"Unknown component '{component}'",
) from None

instance = None
else:
raise RagnaException
Expand All @@ -71,31 +124,33 @@ def _load_component(
instance = cls()

self._components[cls] = instance
self._display_name_map[cls.display_name()] = cls

return self._components[cls]

def chat(
self,
*,
documents: Iterable[Any],
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
source_storage: Union[SourceStorage, type[SourceStorage], str],
assistant: Union[Assistant, type[Assistant], str],
**params: Any,
) -> Chat:
"""Create a new [ragna.core.Chat][].

Args:
documents: Documents to use. If any item is not a [ragna.core.Document][],
[ragna.core.LocalDocument.from_path][] is invoked on it.
FIXME
source_storage: Source storage to use. Can be an instance, a class, or the
display name of a source storage that was already loaded by this object.
assistant: Assistant to use. Can be an instance, a class, or the display name
of an assistant that was already loaded by this object.
**params: Additional parameters passed to the source storage and assistant.

Returns:
The new chat. Its source storage and assistant are resolved to component
instances before the chat is created.
"""
return Chat(
self,
documents=documents,
source_storage=source_storage,
assistant=assistant,
source_storage=cast(SourceStorage, self._load_component(source_storage)), # type: ignore[arg-type]
assistant=cast(Assistant, self._load_component(assistant)), # type: ignore[arg-type]
**params,
)

Expand Down Expand Up @@ -146,17 +201,15 @@ def __init__(
rag: Rag,
*,
documents: Iterable[Any],
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
source_storage: SourceStorage,
assistant: Assistant,
**params: Any,
) -> None:
self._rag = rag

self.documents = self._parse_documents(documents)
self.source_storage = cast(
SourceStorage, self._rag._load_component(source_storage)
)
self.assistant = cast(Assistant, self._rag._load_component(assistant))
self.source_storage = source_storage
self.assistant = assistant

special_params = SpecialChatParams().model_dump()
special_params.update(params)
Expand Down Expand Up @@ -306,6 +359,6 @@ async def __aenter__(self) -> Chat:
return self

async def __aexit__(
self, exc_type: Type[Exception], exc: Exception, traceback: str
self, exc_type: type[Exception], exc: Exception, traceback: str
) -> None:
# Nothing is cleaned up on exit. Returning None means that an exception raised
# inside the `async with` body is not suppressed.
# NOTE(review): `traceback` is annotated as `str`, but the context manager
# protocol passes a traceback object here -- confirm the annotation.
pass
163 changes: 163 additions & 0 deletions ragna/deploy/_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import uuid
from typing import Annotated, AsyncIterator, cast

import aiofiles
import pydantic
from fastapi import (
APIRouter,
Body,
Depends,
Form,
HTTPException,
UploadFile,
)
from fastapi.responses import StreamingResponse

import ragna
import ragna.core
from ragna._compat import anext
from ragna.core._utils import default_user
from ragna.deploy import Config

from . import _schemas as schemas
from ._engine import Engine


def make_router(config: Config, engine: Engine) -> APIRouter:
router = APIRouter(tags=["API"])

def get_user() -> str:
    """Resolve the user for a request; this is always the default user."""
    user = default_user()
    return user

UserDependency = Annotated[str, Depends(get_user)]

# TODO: the document endpoints do not go through the engine, because they'll change
# quite drastically when the UI no longer depends on the API
Comment on lines +34 to +35
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As the inline comment says: this will be addressed in the follow-up PR.


_database = engine._database

@router.post("/document")
async def create_document_upload_info(
    user: UserDependency,
    name: Annotated[str, Body(..., embed=True)],
) -> schemas.DocumentUpload:
    # Register a new document called `name` for `user` and return what is needed
    # to upload its content. The configured document class supplies both the
    # metadata stored with the document and the upload parameters.
    with _database.get_session() as session:
        new_document = schemas.Document(name=name)
        upload_info = await config.document.get_upload_info(
            config=config,
            user=user,
            id=new_document.id,
            name=new_document.name,
        )
        metadata, parameters = upload_info
        new_document.metadata = metadata
        _database.add_document(
            session,
            user=user,
            document=new_document,
            metadata=metadata,
        )
        return schemas.DocumentUpload(
            parameters=parameters,
            document=new_document,
        )

# TODO: Add UI support and documentation for this endpoint (#406)
@router.post("/documents")
async def create_documents_upload_info(
    user: UserDependency,
    names: Annotated[list[str], Body(..., embed=True)],
) -> list[schemas.DocumentUpload]:
    # Batch variant of the single document registration: one document is created
    # per entry in `names`, all of them are stored with a single database call,
    # and the upload information is returned in the order of `names`.
    with _database.get_session() as session:
        documents_with_metadata = []
        uploads = []
        for document_name in names:
            new_document = schemas.Document(name=document_name)
            upload_info = await config.document.get_upload_info(
                config=config,
                user=user,
                id=new_document.id,
                name=new_document.name,
            )
            metadata, parameters = upload_info
            new_document.metadata = metadata

            documents_with_metadata.append((new_document, metadata))
            uploads.append(
                schemas.DocumentUpload(
                    parameters=parameters,
                    document=new_document,
                )
            )

        _database.add_documents(
            session,
            user=user,
            document_metadata_collection=documents_with_metadata,
        )
        return uploads

# TODO: Add new endpoint for batch uploading documents (#407)
@router.put("/document")
async def upload_document(
token: Annotated[str, Form()], file: UploadFile
) -> schemas.Document:
# Receive the content of a previously registered document and write it to the
# path of the corresponding core document. Responds with HTTP 400 when the
# configured document class is not a subclass of `ragna.core.LocalDocument`.
if not issubclass(config.document, ragna.core.LocalDocument):
raise HTTPException(
status_code=400,
detail="Ragna configuration does not support local upload",
)
with _database.get_session() as session:
# The upload token is decoded into the user and the document id.
user, id = ragna.core.LocalDocument.decode_upload_token(token)
document = _database.get_document(session, user=user, id=id)

core_document = cast(
ragna.core.LocalDocument, engine._to_core.document(document)
)
# Make sure the target directory exists before the file is opened for writing.
core_document.path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(core_document.path, "wb") as document_file:
# Copy the upload in 1024-byte reads instead of loading it all at once.
# NOTE(review): 1 KiB is a small chunk size; a larger one would need fewer
# read/write round trips -- confirm whether this value is intentional.
while content := await file.read(1024):
await document_file.write(content)

return document

@router.get("/components")
def get_components(_: UserDependency) -> schemas.Components:
    # Thin wrapper: the engine produces the component listing. The user argument
    # is accepted but not used.
    components = engine.get_components()
    return components

@router.post("/chats")
async def create_chat(
    user: UserDependency,
    chat_metadata: schemas.ChatMetadata,
) -> schemas.Chat:
    # Chat creation is delegated to the engine.
    chat = engine.create_chat(
        user=user,
        chat_metadata=chat_metadata,
    )
    return chat

@router.get("/chats")
async def get_chats(user: UserDependency) -> list[schemas.Chat]:
    # The engine returns the chats for `user`.
    chats = engine.get_chats(user=user)
    return chats

@router.get("/chats/{id}")
async def get_chat(user: UserDependency, id: uuid.UUID) -> schemas.Chat:
    # Look up a single chat of `user` by its id through the engine.
    chat = engine.get_chat(user=user, id=id)
    return chat

@router.post("/chats/{id}/prepare")
async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message:
    # Preparation is delegated to the engine; the message it produces is returned.
    message = await engine.prepare_chat(user=user, id=id)
    return message

@router.post("/chats/{id}/answer")
async def answer(
    user: UserDependency,
    id: uuid.UUID,
    prompt: Annotated[str, Body(..., embed=True)],
    stream: Annotated[bool, Body(..., embed=True)] = False,
) -> schemas.Message:
    # Answer `prompt` in the chat `id`. The engine yields the answer as a stream
    # of messages. Without `stream`, the content of all following messages is
    # appended to the first message, which is then returned. With `stream`, every
    # message is sent as one line of JSON.
    remaining_messages = engine.answer_stream(user=user, chat_id=id, prompt=prompt)
    # The first message is always pulled eagerly, regardless of `stream`.
    first_message = await anext(remaining_messages)

    if stream:

        async def all_messages() -> AsyncIterator[schemas.Message]:
            yield first_message
            async for message in remaining_messages:
                yield message

        async def as_json_lines(
            messages: AsyncIterator[pydantic.BaseModel],
        ) -> AsyncIterator[str]:
            async for message in messages:
                yield f"{message.model_dump_json()}\n"

        return StreamingResponse(  # type: ignore[return-value]
            as_json_lines(all_messages())
        )

    remaining_content = [message.content async for message in remaining_messages]
    first_message.content += "".join(remaining_content)
    return first_message

@router.delete("/chats/{id}")
async def delete_chat(user: UserDependency, id: uuid.UUID) -> None:
    # Deletion is delegated to the engine; nothing is returned to the client.
    engine.delete_chat(id=id, user=user)
    return None

return router
1 change: 0 additions & 1 deletion ragna/deploy/_api/__init__.py

This file was deleted.

Loading
Loading