Skip to content

Commit

Permalink
introduce engine for API (#434)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Jul 5, 2024
1 parent 425b8b8 commit 47d1d37
Show file tree
Hide file tree
Showing 25 changed files with 892 additions and 1,096 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ ignore_missing_imports = true

[[tool.mypy.overrides]]
module = [
"ragna.deploy._api.orm",
"ragna.deploy._orm",
]
# Our ORM schema doesn't really work with mypy. There are some other ways to define it
# to play ball. We should do that in the future.
Expand Down
16 changes: 2 additions & 14 deletions ragna/_cli/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from pathlib import Path
from typing import Annotated, Optional

import httpx
import rich
import typer
import uvicorn
Expand Down Expand Up @@ -74,13 +73,13 @@ def deploy(
*,
config: ConfigOption = "./ragna.toml", # type: ignore[assignment]
api: Annotated[
Optional[bool],
bool,
typer.Option(
"--api/--no-api",
help="Deploy the Ragna REST API.",
show_default="True if UI is not deployed and otherwise check availability",
),
] = None,
] = True,
ui: Annotated[
bool,
typer.Option(
Expand All @@ -101,19 +100,8 @@ def deploy(
typer.Option(help="Open a browser when Ragna is deployed."),
] = None,
) -> None:
def api_available() -> bool:
try:
return httpx.get(f"{config._url}/health").is_success
except httpx.ConnectError:
return False

if api is None:
api = not api_available() if ui else True

if not (api or ui):
raise Exception
elif ui and not api and not api_available():
raise Exception

if open_browser is None:
open_browser = ui
Expand Down
12 changes: 12 additions & 0 deletions ragna/core/_components.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import abc
import datetime
import enum
import functools
import inspect
import uuid
from typing import AsyncIterable, AsyncIterator, Iterator, Optional, Type, Union

import pydantic
Expand Down Expand Up @@ -157,6 +159,8 @@ def __init__(
*,
role: MessageRole = MessageRole.SYSTEM,
sources: Optional[list[Source]] = None,
id: Optional[uuid.UUID] = None,
timestamp: Optional[datetime.datetime] = None,
) -> None:
if isinstance(content, str):
self._content: str = content
Expand All @@ -166,6 +170,14 @@ def __init__(
self.role = role
self.sources = sources or []

if id is None:
id = uuid.uuid4()
self.id = id

if timestamp is None:
timestamp = datetime.datetime.utcnow()
self.timestamp = timestamp

async def __aiter__(self) -> AsyncIterator[str]:
if hasattr(self, "_content"):
yield self._content
Expand Down
7 changes: 6 additions & 1 deletion ragna/core/_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import jwt
from pydantic import BaseModel

import ragna

from ._utils import PackageRequirement, RagnaException, Requirement, RequirementsMixin

if TYPE_CHECKING:
Expand All @@ -24,6 +26,7 @@ class DocumentUploadParameters(BaseModel):
data: dict


# FIXME: this needs to become what local root is now
class Document(RequirementsMixin, abc.ABC):
"""Abstract base class for all documents."""

Expand Down Expand Up @@ -124,7 +127,9 @@ def from_path(

@property
def path(self) -> Path:
return Path(self.metadata["path"])
# FIXME
return ragna.local_root() / "documents" / str(self.id)
# return Path(self.metadata["path"])

def is_readable(self) -> bool:
return self.path.exists()
Expand Down
87 changes: 70 additions & 17 deletions ragna/core/_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import inspect
import uuid
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
Expand All @@ -12,21 +13,24 @@
Iterable,
Iterator,
Optional,
Type,
TypeVar,
Union,
cast,
)

import pydantic
from fastapi import status
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool

from ._components import Assistant, Component, Message, MessageRole, SourceStorage
from ._document import Document, LocalDocument
from ._utils import RagnaException, default_user, merge_models

if TYPE_CHECKING:
from ragna.deploy import Config

T = TypeVar("T")
C = TypeVar("C", bound=Component)
C = TypeVar("C", bound=Component, covariant=True)


class Rag(Generic[C]):
Expand All @@ -41,20 +45,69 @@ class Rag(Generic[C]):
```
"""

def __init__(self) -> None:
self._components: dict[Type[C], C] = {}
def __init__(
self,
*,
config: Optional[Config] = None,
ignore_unavailable_components: bool = False,
) -> None:
self._components: dict[type[C], C] = {}
self._display_name_map: dict[str, type[C]] = {}

if config is not None:
self._preload_components(
config=config,
ignore_unavailable_components=ignore_unavailable_components,
)

def _preload_components(
self, *, config: Config, ignore_unavailable_components: bool
) -> None:
for components in [config.source_storages, config.assistants]:
components = cast(list[type[Component]], components)
at_least_one = False
for component in components:
loaded_component = self._load_component(
component, # type: ignore[arg-type]
ignore_unavailable=ignore_unavailable_components,
)
if loaded_component is None:
print(
f"Ignoring {component.display_name()}, because it is not available."
)
else:
at_least_one = True

if not at_least_one:
raise RagnaException(
"No component available",
components=[component.display_name() for component in components],
)

def _load_component(
self, component: Union[Type[C], C], *, ignore_unavailable: bool = False
self, component: Union[C, type[C], str], *, ignore_unavailable: bool = False
) -> Optional[C]:
cls: Type[C]
cls: type[C]
instance: Optional[C]

if isinstance(component, Component):
instance = cast(C, component)
cls = type(instance)
elif isinstance(component, type) and issubclass(component, Component):
cls = component
instance = None
elif isinstance(component, str):
try:
cls = self._display_name_map[component]
except KeyError:
raise RagnaException(
"Unknown component",
display_name=component,
help="Did you forget to create the Rag() instance with a config?",
http_status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
http_detail=f"Unknown component '{component}'",
) from None

instance = None
else:
raise RagnaException
Expand All @@ -71,31 +124,33 @@ def _load_component(
instance = cls()

self._components[cls] = instance
self._display_name_map[cls.display_name()] = cls

return self._components[cls]

def chat(
self,
*,
documents: Iterable[Any],
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
source_storage: Union[SourceStorage, type[SourceStorage], str],
assistant: Union[Assistant, type[Assistant], str],
**params: Any,
) -> Chat:
"""Create a new [ragna.core.Chat][].
Args:
documents: Documents to use. If any item is not a [ragna.core.Document][],
[ragna.core.LocalDocument.from_path][] is invoked on it.
FIXME
source_storage: Source storage to use.
assistant: Assistant to use.
**params: Additional parameters passed to the source storage and assistant.
"""
return Chat(
self,
documents=documents,
source_storage=source_storage,
assistant=assistant,
source_storage=cast(SourceStorage, self._load_component(source_storage)), # type: ignore[arg-type]
assistant=cast(Assistant, self._load_component(assistant)), # type: ignore[arg-type]
**params,
)

Expand Down Expand Up @@ -146,17 +201,15 @@ def __init__(
rag: Rag,
*,
documents: Iterable[Any],
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
source_storage: SourceStorage,
assistant: Assistant,
**params: Any,
) -> None:
self._rag = rag

self.documents = self._parse_documents(documents)
self.source_storage = cast(
SourceStorage, self._rag._load_component(source_storage)
)
self.assistant = cast(Assistant, self._rag._load_component(assistant))
self.source_storage = source_storage
self.assistant = assistant

special_params = SpecialChatParams().model_dump()
special_params.update(params)
Expand Down Expand Up @@ -306,6 +359,6 @@ async def __aenter__(self) -> Chat:
return self

async def __aexit__(
self, exc_type: Type[Exception], exc: Exception, traceback: str
self, exc_type: type[Exception], exc: Exception, traceback: str
) -> None:
pass
Loading

0 comments on commit 47d1d37

Please sign in to comment.