Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add partial validation (first pass) #56

Merged
merged 3 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions backend/app/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from datetime import datetime

from typing_extensions import TypedDict


class AssistantWithoutUserId(TypedDict):
"""Assistant model."""

assistant_id: str
"""The ID of the assistant."""
name: str
"""The name of the assistant."""
config: dict
"""The assistant config."""
updated_at: datetime
"""The last time the assistant was updated."""
public: bool
"""Whether the assistant is public."""


class Assistant(AssistantWithoutUserId):
"""Assistant model."""

user_id: str
"""The ID of the user that owns the assistant."""


class ThreadWithoutUserId(TypedDict):
thread_id: str
"""The ID of the thread."""
assistant_id: str
"""The assistant that was used in conjunction with this thread."""
name: str
"""The name of the thread."""
updated_at: datetime
"""The last time the thread was updated."""


class Thread(ThreadWithoutUserId):
"""Thread model."""

user_id: str
"""The ID of the user that owns the thread."""
113 changes: 83 additions & 30 deletions backend/app/server.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
import asyncio
import json
from pathlib import Path
from typing import Annotated, AsyncIterator, Optional, Sequence
from typing import Annotated, AsyncIterator, List, Optional, Sequence
from uuid import uuid4
from fastapi.exceptions import RequestValidationError

import orjson
from fastapi import BackgroundTasks, Cookie, FastAPI, Form, Request, UploadFile
from fastapi import BackgroundTasks, Cookie, FastAPI, Form, Query, Request, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.staticfiles import StaticFiles
from gizmo_agent import agent, ingest_runnable
from langserve.callbacks import AsyncEventAggregatorCallback
from langchain.pydantic_v1 import ValidationError
from langchain.schema.messages import AnyMessage, FunctionMessage
from langchain.schema.output import ChatGeneration
from langchain.schema.runnable import RunnableConfig
from langserve import add_routes
from langserve.server import _get_base_run_id_as_str, _unpack_input
from langserve.callbacks import AsyncEventAggregatorCallback
from langserve.serialization import WellKnownLCSerializer
from pydantic import BaseModel
from langserve.server import _get_base_run_id_as_str, _unpack_input
from pydantic import BaseModel, Field
from sse_starlette import EventSourceResponse
from typing_extensions import TypedDict

from app.schema import Assistant, AssistantWithoutUserId, Thread, ThreadWithoutUserId
from app.storage import (
get_assistant,
get_thread_messages,
Expand All @@ -42,11 +42,32 @@
# Get root of app, used to point to directory containing static files
ROOT = Path(__file__).parent.parent

OpengptsUserId = Annotated[
str,
Cookie(
description=(
"A cookie that identifies the user. This is not an authentication "
"mechanism that should be used in an actual production environment that "
"contains sensitive information."
)
),
]


def attach_user_id_to_config(
config: RunnableConfig,
request: Request,
) -> RunnableConfig:
"""Attach the user id to the runnable config.

Args:
config: The runnable config.
request: The request.

Returns:
A modified runnable config that contains information about the user
who made the request in the `configurable.user_id` field.
"""
config["configurable"]["user_id"] = request.cookies["opengpts_user_id"]
return config

Expand All @@ -63,10 +84,14 @@ def attach_user_id_to_config(


class AgentInput(BaseModel):
"""An input into an agent."""

messages: Sequence[AnyMessage]


class CreateRunPayload(BaseModel):
"""Payload for creating a run."""

assistant_id: str
thread_id: str
stream: bool
Expand All @@ -80,6 +105,7 @@ async def create_run_endpoint(
opengpts_user_id: Annotated[str, Cookie()],
background_tasks: BackgroundTasks,
):
"""Create a run."""
try:
body = await request.json()
except json.JSONDecodeError:
Expand Down Expand Up @@ -176,70 +202,97 @@ async def _stream() -> AsyncIterator[dict]:
return {"status": "ok"} # TODO add a run id


@app.post("/ingest")
def ingest_endpoint(files: list[UploadFile], config: str = Form(...)):
@app.post("/ingest", description="Upload files to the given user.")
def ingest_endpoint(files: list[UploadFile], config: str = Form(...)) -> None:
"""Ingest a list of files."""
config = orjson.loads(config)
return ingest_runnable.batch([file.file for file in files], config)


@app.get("/assistants/")
def list_assistants_endpoint(opengpts_user_id: Annotated[str, Cookie()]):
def list_assistants_endpoint(
opengpts_user_id: OpengptsUserId
) -> List[AssistantWithoutUserId]:
"""List all assistants for the current user."""
return list_assistants(opengpts_user_id)


@app.get("/assistants/public/")
def list_public_assistants_endpoint(shared_id: Optional[str] = None):
def list_public_assistants_endpoint(
shared_id: Annotated[
Optional[str], Query(description="ID of a publicly shared assistant.")
] = None,
) -> List[AssistantWithoutUserId]:
"""List all public assistants."""
return list_public_assistants(
FEATURED_PUBLIC_ASSISTANTS + ([shared_id] if shared_id else [])
)


class AssistantPayload(TypedDict):
name: str
config: dict
public: bool
class AssistantPayload(BaseModel):
"""Payload for creating an assistant."""

name: str = Field(..., description="The name of the assistant.")
config: dict = Field(..., description="The assistant config.")
public: bool = Field(default=False, description="Whether the assistant is public.")


AssistantID = Annotated[str, Path(description="The ID of the assistant.")]
ThreadID = Annotated[str, Path(description="The ID of the thread.")]


@app.put("/assistants/{aid}")
def put_assistant_endpoint(
aid: str,
payload: AssistantPayload,
opengpts_user_id: Annotated[str, Cookie()],
):
aid: AssistantID,
payload: AssistantPayload,
) -> Assistant:
"""Create or update an assistant."""
return put_assistant(
opengpts_user_id,
aid,
name=payload["name"],
config=payload["config"],
public=payload["public"],
name=payload.name,
config=payload.config,
public=payload.public,
)


@app.get("/threads/")
def list_threads_endpoint(opengpts_user_id: Annotated[str, Cookie()]):
def list_threads_endpoint(
opengpts_user_id: OpengptsUserId
) -> List[ThreadWithoutUserId]:
"""List all threads for the current user."""
return list_threads(opengpts_user_id)


@app.get("/threads/{tid}/messages")
def get_thread_messages_endpoint(opengpts_user_id: Annotated[str, Cookie()], tid: str):
def get_thread_messages_endpoint(
opengpts_user_id: OpengptsUserId,
tid: ThreadID,
):
"""Get all messages for a thread."""
return get_thread_messages(opengpts_user_id, tid)


class ThreadPayload(TypedDict):
name: str
assistant_id: Optional[str]
class ThreadPutRequest(BaseModel):
"""Payload for creating a thread."""

name: str = Field(..., description="The name of the thread.")
assistant_id: str = Field(..., description="The ID of the assistant to use.")


@app.put("/threads/{tid}")
def put_thread_endpoint(
opengpts_user_id: Annotated[str, Cookie()], tid: str, payload: ThreadPayload
):
opengpts_user_id: OpengptsUserId,
tid: ThreadID,
thread_put_request: ThreadPutRequest,
) -> Thread:
"""Update a thread."""
return put_thread(
opengpts_user_id,
tid,
assistant_id=payload.get("assistant_id"),
name=payload["name"],
assistant_id=thread_put_request.assistant_id,
name=thread_put_request.name,
)


Expand Down
57 changes: 40 additions & 17 deletions backend/app/storage.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
import os
from datetime import datetime
from typing import List, Sequence

import orjson
from agent_executor.checkpoint import RedisCheckpoint
from langchain.schema.messages import AnyMessage
from langchain.utilities.redis import get_client
from agent_executor.checkpoint import RedisCheckpoint
from permchain.channels.base import ChannelsManager
from permchain.channels import Topic
from permchain.channels.base import ChannelsManager
from redis.client import Redis as RedisType

from app.schema import Assistant, AssistantWithoutUserId, Thread, ThreadWithoutUserId


def assistants_list_key(user_id: str):
def assistants_list_key(user_id: str) -> str:
return f"opengpts:{user_id}:assistants"


def assistant_key(user_id: str, assistant_id: str):
def assistant_key(user_id: str, assistant_id: str) -> str:
return f"opengpts:{user_id}:assistant:{assistant_id}"


def threads_list_key(user_id: str):
def threads_list_key(user_id: str) -> str:
return f"opengpts:{user_id}:threads"


def thread_key(user_id: str, thread_id: str):
def thread_key(user_id: str, thread_id: str) -> str:
return f"opengpts:{user_id}:thread:{thread_id}"


Expand All @@ -47,7 +50,8 @@ def _get_redis_client() -> RedisType:
return get_client(url)


def list_assistants(user_id: str):
def list_assistants(user_id: str) -> List[Assistant]:
"""List all assistants for the current user."""
client = _get_redis_client()
ids = [orjson.loads(id) for id in client.smembers(assistants_list_key(user_id))]
with client.pipeline() as pipe:
Expand All @@ -57,13 +61,17 @@ def list_assistants(user_id: str):
return [load(assistant_hash_keys, values) for values in assistants]


def get_assistant(user_id: str, assistant_id: str):
def get_assistant(user_id: str, assistant_id: str) -> Assistant:
"""Get an assistant by ID."""
client = _get_redis_client()
values = client.hmget(assistant_key(user_id, assistant_id), *assistant_hash_keys)
return load(assistant_hash_keys, values)


def list_public_assistants(assistant_ids: list[str]):
def list_public_assistants(
assistant_ids: Sequence[str]
) -> List[AssistantWithoutUserId]:
"""List all the public assistants."""
if not assistant_ids:
return []
client = _get_redis_client()
Expand All @@ -87,10 +95,22 @@ def list_public_assistants(assistant_ids: list[str]):

def put_assistant(
user_id: str, assistant_id: str, *, name: str, config: dict, public: bool = False
):
saved = {
"user_id": user_id,
"assistant_id": assistant_id,
) -> Assistant:
"""Modify an assistant.

Args:
user_id: The user ID.
assistant_id: The assistant ID.
name: The assistant name.
config: The assistant config.
public: Whether the assistant is public.

Returns:
return the assistant model if no exception is raised.
"""
saved: Assistant = {
"user_id": user_id, # TODO(Nuno): Could we remove this?
"assistant_id": assistant_id, # TODO(Nuno): remove this?
"name": name,
"config": config,
"updated_at": datetime.utcnow(),
Expand All @@ -107,7 +127,8 @@ def put_assistant(
return saved


def list_threads(user_id: str):
def list_threads(user_id: str) -> List[ThreadWithoutUserId]:
"""List all threads for the current user."""
client = _get_redis_client()
ids = [orjson.loads(id) for id in client.smembers(threads_list_key(user_id))]
with client.pipeline() as pipe:
Expand All @@ -118,6 +139,7 @@ def list_threads(user_id: str):


def get_thread_messages(user_id: str, thread_id: str):
"""Get all messages for a thread."""
client = RedisCheckpoint()
checkpoint = client.get(
{"configurable": {"user_id": user_id, "thread_id": thread_id}}
Expand All @@ -130,9 +152,10 @@ def get_thread_messages(user_id: str, thread_id: str):
return {k: v.get() for k, v in channels.items()}


def put_thread(user_id: str, thread_id: str, *, assistant_id: str, name: str):
saved = {
"user_id": user_id,
def put_thread(user_id: str, thread_id: str, *, assistant_id: str, name: str) -> Thread:
"""Modify a thread."""
saved: Thread = {
"user_id": user_id, # TODO(Nuno): Could we remove this?
"thread_id": thread_id,
"assistant_id": assistant_id,
"name": name,
Expand Down
8 changes: 8 additions & 0 deletions backend/packages/agent-executor/agent_executor/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,17 @@ def _convert_ingestion_input_to_blob(data: BinaryIO) -> Blob:


class IngestRunnable(RunnableSerializable[BinaryIO, List[str]]):
"""Runnable for ingesting files into a vectorstore."""

text_splitter: TextSplitter
"""Text splitter to use for splitting the text into chunks."""
vectorstore: VectorStore
"""Vectorstore to ingest into."""
assistant_id: Optional[str]
"""Ingested documents will be associated with this assistant id.

The assistant ID is used as the namespace, and is filtered on at query time.
"""

class Config:
arbitrary_types_allowed = True
Expand Down
Loading