Skip to content

Commit

Permalink
feat: Update interceptor_to_embeddings_handler (#5)
Browse files Browse the repository at this point in the history
Co-authored-by: Ihor Lahutin <ihor_lahutin@epam.com>
  • Loading branch information
igorlagutin and Ihor Lahutin authored Oct 2, 2024
1 parent 67ab88b commit e948187
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 26 deletions.
2 changes: 1 addition & 1 deletion aidial_interceptors_sdk/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .adapter import interceptor_to_embeddings_handler
from .adapter import interceptor_to_embeddings
from .base import EmbeddingsInterceptor, EmbeddingsNoOpInterceptor
51 changes: 30 additions & 21 deletions aidial_interceptors_sdk/embeddings/adapter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Type

from fastapi import Request
from aidial_sdk.embeddings import Embeddings
from aidial_sdk.embeddings.request import Request
from aidial_sdk.embeddings.response import Response
from openai.types import CreateEmbeddingResponse

from aidial_interceptors_sdk.dial_client import DialClient
Expand All @@ -10,29 +12,36 @@
from aidial_interceptors_sdk.utils._reflection import call_with_extra_body


def interceptor_to_embeddings_handler(cls: Type[EmbeddingsInterceptor]):
@dial_exception_decorator
async def _handler(request: Request) -> dict:
def interceptor_to_embeddings(cls: Type[EmbeddingsInterceptor]) -> Embeddings:

dial_client = await DialClient.create(
api_key=request.headers.get("api-key"),
api_version=request.query_params.get("api-version"),
authorization=request.headers.get("authorization"),
)
class Impl(Embeddings):
@dial_exception_decorator
async def embeddings(self, request: Request) -> Response:

interceptor = cls(dial_client=dial_client, **request.path_params)
dial_client = await DialClient.create(
api_key=request.api_key,
api_version=request.api_version,
authorization=request.jwt,
)

body = await request.json()
body = await debug_logging("request")(interceptor.modify_request)(body)
interceptor = cls(
dial_client=dial_client,
**request.original_request.path_params,
)

response: CreateEmbeddingResponse = await call_with_extra_body(
dial_client.client.embeddings.create, body
)
body = await request.original_request.json()
body = await debug_logging("request")(interceptor.modify_request)(
body
)

resp = response.to_dict()
resp = await debug_logging("response")(interceptor.modify_response)(
resp
)
return resp
response: CreateEmbeddingResponse = await call_with_extra_body(
dial_client.client.embeddings.create, body
)

return _handler
response_dict = await debug_logging("response")(
interceptor.modify_response
)(response.to_dict())

return Response.parse_obj(response_dict)

return Impl()
6 changes: 2 additions & 4 deletions aidial_interceptors_sdk/examples/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from aidial_interceptors_sdk.chat_completion import (
interceptor_to_chat_completion,
)
from aidial_interceptors_sdk.embeddings import interceptor_to_embeddings_handler
from aidial_interceptors_sdk.embeddings.adapter import interceptor_to_embeddings
from aidial_interceptors_sdk.examples.registry import (
chat_completion_interceptors,
embeddings_interceptors,
Expand All @@ -23,9 +23,7 @@
configure_loggers()

for id, cls in embeddings_interceptors.items():
app.post(f"/openai/deployments/{id}/embeddings")(
interceptor_to_embeddings_handler(cls)
)
app.add_embeddings(id, interceptor_to_embeddings(cls))

for id, cls in chat_completion_interceptors.items():
app.add_chat_completion(id, interceptor_to_chat_completion(cls))

0 comments on commit e948187

Please sign in to comment.