Skip to content

Commit

Permalink
Create edgedb.ai package (#489)
Browse files Browse the repository at this point in the history
  • Loading branch information
fantix authored May 1, 2024
1 parent c30f062 commit e38833f
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 0 deletions.
32 changes: 32 additions & 0 deletions edgedb/ai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2024-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from .types import AIOptions, ChatParticipantRole, Prompt, QueryContext
from .core import create_ai, EdgeDBAI
from .core import create_async_ai, AsyncEdgeDBAI

__all__ = [
"AIOptions",
"ChatParticipantRole",
"Prompt",
"QueryContext",
"create_ai",
"EdgeDBAI",
"create_async_ai",
"AsyncEdgeDBAI",
]
174 changes: 174 additions & 0 deletions edgedb/ai/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2024-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations
import typing

import edgedb
import httpx
import httpx_sse

from . import types


def create_ai(client: edgedb.Client, **kwargs) -> EdgeDBAI:
client.ensure_connected()
return EdgeDBAI(client, types.AIOptions(**kwargs))


async def create_async_ai(
client: edgedb.AsyncIOClient, **kwargs
) -> AsyncEdgeDBAI:
await client.ensure_connected()
return AsyncEdgeDBAI(client, types.AIOptions(**kwargs))


class BaseEdgeDBAI:
options: types.AIOptions
context: types.QueryContext
client_cls = NotImplemented

def __init__(
self,
client: typing.Union[edgedb.Client, edgedb.AsyncIOClient],
options: types.AIOptions,
**kwargs,
):
pool = client._impl
host, port = pool._working_addr
params = pool._working_params
proto = "http" if params.tls_security == "insecure" else "https"
branch = params.branch
self.options = options
self.context = types.QueryContext(**kwargs)
args = dict(
base_url=f"{proto}://{host}:{port}/branch/{branch}/ext/ai",
verify=params.ssl_ctx,
)
if params.password is not None:
args["auth"] = (params.user, params.password)
elif params.secret_key is not None:
args["headers"] = {"Authorization": f"Bearer {params.secret_key}"}
self._init_client(**args)

def _init_client(self, **kwargs):
raise NotImplementedError

def with_config(self, **kwargs) -> typing.Self:
cls = type(self)
rv = cls.__new__(cls)
rv.options = self.options.derive(kwargs)
rv.context = self.context
rv.client = self.client
return rv

def with_context(self, **kwargs) -> typing.Self:
cls = type(self)
rv = cls.__new__(cls)
rv.options = self.options
rv.context = self.context.derive(kwargs)
rv.client = self.client
return rv


class EdgeDBAI(BaseEdgeDBAI):
client: httpx.Client

def _init_client(self, **kwargs):
self.client = httpx.Client(**kwargs)

def query_rag(
self, message: str, context: typing.Optional[types.QueryContext] = None
) -> str:
if context is None:
context = self.context
resp = self.client.post(
**types.RAGRequest(
model=self.options.model,
prompt=self.options.prompt,
context=context,
query=message,
stream=False,
).to_httpx_request()
)
resp.raise_for_status()
return resp.json()["response"]

def stream_rag(
self, message: str, context: typing.Optional[types.QueryContext] = None
):
if context is None:
context = self.context
with httpx_sse.connect_sse(
self.client,
"post",
**types.RAGRequest(
model=self.options.model,
prompt=self.options.prompt,
context=context,
query=message,
stream=True,
).to_httpx_request(),
) as event_source:
event_source.response.raise_for_status()
for sse in event_source.iter_sse():
yield sse.data


class AsyncEdgeDBAI(BaseEdgeDBAI):
client: httpx.AsyncClient

def _init_client(self, **kwargs):
self.client = httpx.AsyncClient(**kwargs)

async def query_rag(
self, message: str, context: typing.Optional[types.QueryContext] = None
) -> str:
if context is None:
context = self.context
resp = await self.client.post(
**types.RAGRequest(
model=self.options.model,
prompt=self.options.prompt,
context=context,
query=message,
stream=False,
).to_httpx_request()
)
resp.raise_for_status()
return resp.json()["response"]

async def stream_rag(
self, message: str, context: typing.Optional[types.QueryContext] = None
):
if context is None:
context = self.context
async with httpx_sse.aconnect_sse(
self.client,
"post",
**types.RAGRequest(
model=self.options.model,
prompt=self.options.prompt,
context=context,
query=message,
stream=True,
).to_httpx_request(),
) as event_source:
event_source.response.raise_for_status()
async for sse in event_source.aiter_sse():
yield sse.data
81 changes: 81 additions & 0 deletions edgedb/ai/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2024-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import typing

import dataclasses as dc
import enum


class ChatParticipantRole(enum.Enum):
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"


class Custom(typing.TypedDict):
role: ChatParticipantRole
content: str


class Prompt:
name: typing.Optional[str]
id: typing.Optional[str]
custom: typing.Optional[typing.List[Custom]]


@dc.dataclass
class AIOptions:
model: str
prompt: typing.Optional[Prompt] = None

def derive(self, kwargs):
return AIOptions(**{**dc.asdict(self), **kwargs})


@dc.dataclass
class QueryContext:
query: str = ""
variables: typing.Optional[typing.Dict[str, typing.Any]] = None
globals: typing.Optional[typing.Dict[str, typing.Any]] = None
max_object_count: typing.Optional[int] = None

def derive(self, kwargs):
return QueryContext(**{**dc.asdict(self), **kwargs})


@dc.dataclass
class RAGRequest:
model: str
prompt: typing.Optional[Prompt]
context: QueryContext
query: str
stream: typing.Optional[bool]

def to_httpx_request(self) -> typing.Dict[str, typing.Any]:
return dict(
url="/rag",
headers={
"Content-Type": "application/json",
"Accept": (
"text/event-stream" if self.stream else "application/json"
),
},
json=dc.asdict(self),
)
6 changes: 6 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,13 @@
'sphinx_rtd_theme~=1.0.0',
]

AI_DEPENDENCIES = [
'httpx~=0.27.0',
'httpx-sse~=0.4.0',
]

EXTRA_DEPENDENCIES = {
'ai': AI_DEPENDENCIES,
'docs': DOC_DEPENDENCIES,
'test': TEST_DEPENDENCIES,
# Dependencies required to develop edgedb.
Expand Down

0 comments on commit e38833f

Please sign in to comment.