Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP Stream message tokens #52

Merged
merged 7 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions backend/app/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from datetime import datetime

from typing_extensions import TypedDict


class AssistantWithoutUserId(TypedDict):
    """Assistant record without the owning user's ID."""

    # The ID of the assistant.
    assistant_id: str
    # The name of the assistant.
    name: str
    # The assistant config.
    config: dict
    # The last time the assistant was updated.
    updated_at: datetime
    # Whether the assistant is public.
    public: bool


class Assistant(AssistantWithoutUserId):
    """Assistant record including the ID of the user that owns it."""

    # The ID of the user that owns the assistant.
    user_id: str


class ThreadWithoutUserId(TypedDict):
    """Thread record without the owning user's ID."""

    # The ID of the thread.
    thread_id: str
    # The assistant that was used in conjunction with this thread.
    assistant_id: str
    # The name of the thread.
    name: str
    # The last time the thread was updated.
    updated_at: datetime


class Thread(ThreadWithoutUserId):
    """Thread record including the ID of the user that owns it."""

    # The ID of the user that owns the thread.
    user_id: str
235 changes: 209 additions & 26 deletions backend/app/server.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,36 @@
import asyncio
import json
from pathlib import Path
from typing import Annotated, Optional
from typing import Annotated, AsyncIterator, List, Optional, Sequence
from uuid import uuid4

import orjson
from fastapi import Cookie, FastAPI, Form, Request, UploadFile
from fastapi import BackgroundTasks, Cookie, FastAPI, Form, Query, Request, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.staticfiles import StaticFiles
from gizmo_agent import agent, ingest_runnable
from langchain.pydantic_v1 import ValidationError
from langchain.schema.messages import AnyMessage, FunctionMessage
from langchain.schema.output import ChatGeneration
from langchain.schema.runnable import RunnableConfig
from langserve import add_routes
from typing_extensions import TypedDict
from langserve.callbacks import AsyncEventAggregatorCallback
from langserve.serialization import WellKnownLCSerializer
from langserve.server import _get_base_run_id_as_str, _unpack_input
from pydantic import BaseModel, Field
from sse_starlette import EventSourceResponse

from app.schema import Assistant, AssistantWithoutUserId, Thread, ThreadWithoutUserId
from app.storage import (
get_assistant,
get_thread_messages,
list_assistants,
list_public_assistants,
list_threads,
put_assistant,
put_thread,
)
from app.stream import StreamMessagesHandler

# FastAPI application object; the endpoints in this module are registered on it.
app = FastAPI()

Expand All @@ -28,11 +42,32 @@
# Root of the app: the parent of the directory that contains this file.
# Used to point to the directory containing static files.
ROOT = Path(__file__).parent.parent

# Reusable parameter type: a string read from a cookie, with a description
# attached. Endpoints in this module annotate their `opengpts_user_id`
# parameter with it.
OpengptsUserId = Annotated[
    str,
    Cookie(
        description=(
            "A cookie that identifies the user. This is not an authentication "
            "mechanism that should be used in an actual production environment that "
            "contains sensitive information."
        )
    ),
]


def attach_user_id_to_config(
    config: RunnableConfig,
    request: Request,
) -> RunnableConfig:
    """Attach the user id to the runnable config.

    The config is modified in place and the same object is returned.

    Args:
        config: The runnable config. A missing ``configurable`` mapping is
            created rather than treated as an error.
        request: The request; its ``opengpts_user_id`` cookie supplies the id.

    Returns:
        The runnable config, with the id of the user who made the request
        stored in the `configurable.user_id` field.

    Raises:
        KeyError: If the request carries no ``opengpts_user_id`` cookie.
    """
    # Read the cookie first so a missing cookie leaves `config` untouched.
    user_id = request.cookies["opengpts_user_id"]
    # `configurable` may be absent from the config; create it on demand
    # instead of failing with a KeyError.
    config.setdefault("configurable", {})["user_id"] = user_id
    return config

Expand All @@ -45,71 +80,219 @@ def attach_user_id_to_config(
enable_feedback_endpoint=True,
)

# Shared serializer instance; create_run_endpoint uses its `dumps` to encode
# streamed chunks before they are sent as event data.
serializer = WellKnownLCSerializer()


class AgentInput(BaseModel):
    """An input into an agent."""

    # The messages passed to the agent.
    messages: Sequence[AnyMessage]


class CreateRunPayload(BaseModel):
    """Payload for creating a run.

    NOTE(review): the `/runs` endpoint visible in this file reads these same
    keys from the raw JSON body rather than taking this model as a parameter.
    Confirm whether the model is only meant as documentation of the payload.
    """

    # ID of the assistant to run.
    assistant_id: str
    # ID of the thread the run belongs to.
    thread_id: str
    # Whether the caller wants a streamed response.
    stream: bool
    # TODO make optional
    input: AgentInput


@app.post("/runs")
async def create_run_endpoint(
    request: Request,
    opengpts_user_id: Annotated[str, Cookie()],
    background_tasks: BackgroundTasks,
):
    """Create a run.

    Reads ``assistant_id``, ``thread_id``, ``stream`` and ``input`` from the
    JSON body, loads the assistant and the thread's stored messages, and runs
    the agent with the assistant's config plus the thread, assistant and user
    ids.

    When ``stream`` is truthy the agent runs in a background asyncio task and
    the response is an event stream with ``metadata`` (run id), ``data``
    (serialized chunk), ``error`` and ``end`` events. Otherwise the agent is
    scheduled as a background task and ``{"status": "ok"}`` is returned
    immediately.

    Raises:
        RequestValidationError: If the body is not valid JSON or ``input``
            fails validation against the agent's input schema.
    """
    try:
        body = await request.json()
    except json.JSONDecodeError as err:
        raise RequestValidationError(errors=["Invalid JSON body"]) from err

    # The storage helpers are synchronous, so run both lookups concurrently
    # in the default executor rather than blocking the event loop.
    loop = asyncio.get_running_loop()
    assistant, state = await asyncio.gather(
        loop.run_in_executor(
            None, get_assistant, opengpts_user_id, body["assistant_id"]
        ),
        loop.run_in_executor(
            None, get_thread_messages, opengpts_user_id, body["thread_id"]
        ),
    )
    config: RunnableConfig = attach_user_id_to_config(
        {
            **assistant["config"],
            "configurable": {
                **assistant["config"]["configurable"],
                "thread_id": body["thread_id"],
                "assistant_id": body["assistant_id"],
            },
        },
        request,
    )
    try:
        input_ = _unpack_input(agent.get_input_schema(config).validate(body["input"]))
    except ValidationError as err:
        raise RequestValidationError(err.errors(), body=body) from err

    if body["stream"]:
        streamer = StreamMessagesHandler(state["messages"] + input_["messages"])
        event_aggregator = AsyncEventAggregatorCallback()
        config["callbacks"] = [streamer, event_aggregator]

        # Call the runnable in streaming mode,
        # add each chunk to the output stream
        async def consume_astream() -> None:
            try:
                async for chunk in agent.astream(input_, config):
                    await streamer.send_stream.send(chunk)
                    # hack: function messages aren't generated by chat model
                    # so the callback handler doesn't know about them
                    message = chunk["messages"][-1]
                    if isinstance(message, FunctionMessage):
                        streamer.output[uuid4()] = ChatGeneration(message=message)
            except Exception as e:
                # Forward the failure to the consumer, which turns it into an
                # "error" event.
                await streamer.send_stream.send(e)
            finally:
                await streamer.send_stream.aclose()

        # Start the runnable in the background
        task = asyncio.create_task(consume_astream())

        # Consume the stream into an EventSourceResponse
        async def _stream() -> AsyncIterator[dict]:
            has_sent_metadata = False

            async for chunk in streamer.receive_stream:
                if isinstance(chunk, BaseException):
                    yield {
                        "event": "error",
                        # Do not expose the error message to the client since
                        # the message may contain sensitive information.
                        # We'll add client side errors for validation as well.
                        "data": orjson.dumps(
                            {"status_code": 500, "message": "Internal Server Error"}
                        ).decode(),
                    }
                    raise chunk
                else:
                    # Emit the run id once, as soon as the aggregator has
                    # recorded at least one callback event.
                    if not has_sent_metadata and event_aggregator.callback_events:
                        yield {
                            "event": "metadata",
                            "data": orjson.dumps(
                                {"run_id": _get_base_run_id_as_str(event_aggregator)}
                            ).decode(),
                        }
                        has_sent_metadata = True

                    yield {
                        # EventSourceResponse expects a string for data
                        # so after serializing into bytes, we decode into utf-8
                        # to get a string.
                        "data": serializer.dumps(chunk).decode("utf-8"),
                        "event": "data",
                    }

            # Send an end event to signal the end of the stream
            yield {"event": "end"}
            # Wait for the runnable to finish
            await task

        return EventSourceResponse(_stream())
    else:
        background_tasks.add_task(agent.ainvoke, input_, config)
        return {"status": "ok"}  # TODO add a run id


@app.post("/ingest", description="Upload files to the given user.")
def ingest_endpoint(files: list[UploadFile], config: str = Form(...)) -> None:
    """Ingest a list of files.

    `config` arrives as a JSON-encoded form field; it is decoded and passed,
    together with the uploaded file objects, to `ingest_runnable.batch`.

    NOTE(review): the function returns the result of `batch` although it is
    annotated `-> None`. Confirm which of the two is intended.
    """
    parsed_config = orjson.loads(config)
    file_objects = [upload.file for upload in files]
    return ingest_runnable.batch(file_objects, parsed_config)


@app.get("/assistants/")
def list_assistants_endpoint(
    opengpts_user_id: OpengptsUserId
) -> List[AssistantWithoutUserId]:
    """List all assistants for the current user.

    Args:
        opengpts_user_id: The user id read from the request cookie.

    Returns:
        Whatever `list_assistants` returns for that user.
    """
    return list_assistants(opengpts_user_id)


@app.get("/assistants/public/")
def list_public_assistants_endpoint(
    shared_id: Annotated[
        Optional[str], Query(description="ID of a publicly shared assistant.")
    ] = None,
) -> List[AssistantWithoutUserId]:
    """List all public assistants.

    Looks up the ids in `FEATURED_PUBLIC_ASSISTANTS`, plus `shared_id` when a
    non-empty one is supplied.
    """
    return list_public_assistants(
        FEATURED_PUBLIC_ASSISTANTS + ([shared_id] if shared_id else [])
    )


class AssistantPayload(BaseModel):
    """Payload for creating an assistant."""

    name: str = Field(..., description="The name of the assistant.")
    config: dict = Field(..., description="The assistant config.")
    # Optional in the request body; assistants are private unless stated otherwise.
    public: bool = Field(default=False, description="Whether the assistant is public.")


# `Path` in this module is `pathlib.Path` (it is used to build ROOT), so
# FastAPI's path-parameter marker is imported here under a distinct alias.
# Using `pathlib.Path(description=...)` would not attach the description to
# the path parameter.
from fastapi import Path as PathParam  # noqa: E402

# Path-parameter types shared by the assistant and thread endpoints.
AssistantID = Annotated[str, PathParam(description="The ID of the assistant.")]
ThreadID = Annotated[str, PathParam(description="The ID of the thread.")]


@app.put("/assistants/{aid}")
def put_assistant_endpoint(
    opengpts_user_id: Annotated[str, Cookie()],
    aid: AssistantID,
    payload: AssistantPayload,
) -> Assistant:
    """Create or update an assistant.

    Args:
        opengpts_user_id: The user id read from the request cookie.
        aid: The ID of the assistant, taken from the URL path.
        payload: Name, config and public flag for the assistant.

    Returns:
        Whatever `put_assistant` returns for the stored assistant.
    """
    # `payload` is a pydantic model, so its fields are read as attributes.
    return put_assistant(
        opengpts_user_id,
        aid,
        name=payload.name,
        config=payload.config,
        public=payload.public,
    )


@app.get("/threads/")
def list_threads_endpoint(
    opengpts_user_id: OpengptsUserId
) -> List[ThreadWithoutUserId]:
    """List all threads for the current user.

    Args:
        opengpts_user_id: The user id read from the request cookie.

    Returns:
        Whatever `list_threads` returns for that user.
    """
    return list_threads(opengpts_user_id)


@app.get("/threads/{tid}/messages")
def get_thread_messages_endpoint(
    opengpts_user_id: OpengptsUserId,
    tid: ThreadID,
):
    """Get all messages for a thread.

    Args:
        opengpts_user_id: The user id read from the request cookie.
        tid: The ID of the thread, taken from the URL path.

    Returns:
        Whatever `get_thread_messages` returns for that user and thread.
    """
    return get_thread_messages(opengpts_user_id, tid)


class ThreadPutRequest(BaseModel):
    """Payload for creating a thread."""

    name: str = Field(..., description="The name of the thread.")
    assistant_id: str = Field(..., description="The ID of the assistant to use.")


@app.put("/threads/{tid}")
def put_thread_endpoint(
    opengpts_user_id: OpengptsUserId,
    tid: ThreadID,
    thread_put_request: ThreadPutRequest,
) -> Thread:
    """Update a thread.

    Args:
        opengpts_user_id: The user id read from the request cookie.
        tid: The ID of the thread, taken from the URL path.
        thread_put_request: Name and assistant id for the thread.

    Returns:
        Whatever `put_thread` returns for the stored thread.
    """
    # The request body is a pydantic model, so fields are read as attributes.
    return put_thread(
        opengpts_user_id,
        tid,
        assistant_id=thread_put_request.assistant_id,
        name=thread_put_request.name,
    )


Expand Down
Loading