Skip to content

Commit

Permalink
Chat backend (#40)
Browse files Browse the repository at this point in the history
* implement chat UI

* WIP: Chat backend

* Working version with reply

* Working chat backend with websocket

* Fixed error

* Added types for messages

* Aligned with main branch, removed chat UI components.

* Removed yarn.lock updates

---------

Co-authored-by: David L. Qiu <david@qiu.dev>
  • Loading branch information
3coins and dlqqq authored Apr 8, 2023
1 parent b06526b commit bae1734
Show file tree
Hide file tree
Showing 7 changed files with 406 additions and 25 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,5 @@ playground/

# reserve path for a dev script
dev.sh

.vscode
45 changes: 43 additions & 2 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
import queue
from jupyter_server.extension.application import ExtensionApp
from .handlers import PromptAPIHandler, TaskAPIHandler
from langchain import ConversationChain
from .handlers import ChatHandler, ChatHistoryHandler, PromptAPIHandler, TaskAPIHandler, ChatAPIHandler
from importlib_metadata import entry_points
import inspect
from .engine import BaseModelEngine
from .providers import ChatOpenAIProvider
import os

from langchain.memory import ConversationBufferMemory
from langchain.prompts import (
ChatPromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate
)

class AiExtension(ExtensionApp):
name = "jupyter_ai"
handlers = [
("api/ai/prompt", PromptAPIHandler),
(r"api/ai/chat/?", ChatAPIHandler),
(r"api/ai/tasks/?", TaskAPIHandler),
(r"api/ai/tasks/([\w\-:]*)", TaskAPIHandler)
(r"api/ai/tasks/([\w\-:]*)", TaskAPIHandler),
(r"api/ai/chats/?", ChatHandler),
(r"api/ai/chats/history?", ChatHistoryHandler),
]

@property
Expand All @@ -18,6 +33,7 @@ def ai_engines(self):
self.settings["ai_engines"] = {}

return self.settings["ai_engines"]


def initialize_settings(self):
# EP := entry point
Expand Down Expand Up @@ -69,5 +85,30 @@ def initialize_settings(self):
self.settings["ai_default_tasks"] = default_tasks
self.log.info("Registered all default tasks.")

## load OpenAI chat provider
if ChatOpenAIProvider.auth_strategy.name in os.environ:
self.settings["openai_chat"] = ChatOpenAIProvider(model_id="gpt-3.5-turbo")
# Create a conversation memory
memory = ConversationBufferMemory(return_messages=True)
prompt_template = ChatPromptTemplate.from_messages([
SystemMessagePromptTemplate.from_template("The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know."),
MessagesPlaceholder(variable_name="history"),
HumanMessagePromptTemplate.from_template("{input}")
])
chain = ConversationChain(
llm=self.settings["openai_chat"],
prompt=prompt_template,
verbose=True,
memory=memory
)
self.settings["chat_provider"] = chain

self.log.info(f"Registered {self.name} server extension")

# Add a message queue to the settings to be used by the chat handler
self.settings["chat_message_queue"] = queue.Queue()

# Store chat clients in a dictionary
self.settings["chat_clients"] = {}


198 changes: 179 additions & 19 deletions packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
from dataclasses import asdict
import json
from typing import Optional

import tornado
from tornado.web import HTTPError
from pydantic import ValidationError

from jupyter_server.base.handlers import APIHandler
from tornado import web, websocket

from jupyter_server.base.handlers import APIHandler as BaseAPIHandler, JupyterHandler
from jupyter_server.utils import ensure_async

from .task_manager import TaskManager
from .models import PromptRequest
from .models import ChatHistory, PromptRequest, ChatRequest
from langchain.schema import _message_to_dict, HumanMessage, AIMessage

class PromptAPIHandler(APIHandler):
class APIHandler(BaseAPIHandler):
@property
def engines(self):
return self.settings["ai_engines"]
Expand All @@ -26,6 +32,11 @@ def task_manager(self):
self.settings["task_manager"] = TaskManager(engines=self.engines, default_tasks=self.default_tasks)
return self.settings["task_manager"]

@property
def openai_chat(self):
return self.settings["openai_chat"]

class PromptAPIHandler(APIHandler):
@tornado.web.authenticated
async def post(self):
try:
Expand All @@ -49,23 +60,27 @@ async def post(self):
"insertion_mode": task.insertion_mode
}))

class TaskAPIHandler(APIHandler):
@property
def engines(self):
return self.settings["ai_engines"]

@property
def default_tasks(self):
return self.settings["ai_default_tasks"]
class ChatAPIHandler(APIHandler):
@tornado.web.authenticated
async def post(self):
try:
request = ChatRequest(**self.get_json_body())
except ValidationError as e:
self.log.exception(e)
raise HTTPError(500, str(e)) from e

if not self.openai_chat:
raise HTTPError(500, "No chat models available.")

result = await ensure_async(self.openai_chat.agenerate([request.prompt]))
output = result.generations[0][0].text
self.openai_chat.append_exchange(request.prompt, output)

@property
def task_manager(self):
# we have to create the TaskManager lazily, since no event loop is
# running in ServerApp.initialize_settings().
if "task_manager" not in self.settings:
self.settings["task_manager"] = TaskManager(engines=self.engines, default_tasks=self.default_tasks)
return self.settings["task_manager"]

self.finish(json.dumps({
"output": output,
}))

class TaskAPIHandler(APIHandler):
@tornado.web.authenticated
async def get(self, id=None):
if id is None:
Expand All @@ -78,3 +93,148 @@ async def get(self, id=None):
raise HTTPError(404, f"Task not found with ID: {id}")

self.finish(json.dumps(describe_task_response.dict()))


class ChatHistoryHandler(BaseAPIHandler):
    """REST handler for the chat transcript.

    GET returns the full message history as JSON; DELETE clears the
    underlying conversation memory.
    """

    # Cached reference to the LangChain conversation chain; resolved lazily
    # from server settings on first access.
    _chat_provider = None

    @property
    def chat_provider(self):
        """The conversation chain registered in ``settings["chat_provider"]``."""
        if self._chat_provider is None:
            self._chat_provider = self.settings["chat_provider"]
        return self._chat_provider

    @property
    def messages(self):
        """Messages currently held in the chain's conversation memory.

        Computed from the provider's memory on every access, so there is no
        separate handler-side state to keep in sync.
        """
        return self.chat_provider.memory.chat_memory.messages or []

    @tornado.web.authenticated
    async def get(self):
        history = ChatHistory(messages=list(self.messages))
        self.finish(history.json(models_as_dict=False))

    @tornado.web.authenticated
    async def delete(self):
        # Clearing the chain's memory is sufficient: `messages` is derived
        # from it. (The previous implementation also assigned to the
        # read-only `messages` property, which raised AttributeError on
        # every DELETE request.)
        self.chat_provider.memory.chat_memory.clear()
        self.set_status(204)
        self.finish()


class ChatHandler(
    JupyterHandler,
    websocket.WebSocketHandler
):
    """
    A websocket handler for chat.

    Each connected client is registered in ``settings["chat_clients"]`` under
    its username. An inbound message is echoed to the other clients, fed to
    the chat provider (a LangChain conversation chain), and the AI reply is
    broadcast to all clients.
    """

    # Lazily-resolved references to shared server-side objects.
    _chat_provider = None
    _chat_message_queue = None

    @property
    def chat_provider(self):
        """The conversation chain registered in ``settings["chat_provider"]``."""
        if self._chat_provider is None:
            self._chat_provider = self.settings["chat_provider"]
        return self._chat_provider

    @property
    def chat_message_queue(self):
        """The shared message queue created in ``AiExtension.initialize_settings``."""
        if self._chat_message_queue is None:
            self._chat_message_queue = self.settings["chat_message_queue"]
        return self._chat_message_queue

    @property
    def messages(self):
        """Messages currently held in the chain's conversation memory."""
        return self.chat_provider.memory.chat_memory.messages or []

    def add_chat_client(self, username):
        """Register this websocket under *username* for broadcasts."""
        self.settings["chat_clients"][username] = self
        self.log.debug("Clients are : %s", self.settings["chat_clients"].keys())

    def remove_chat_client(self, username):
        """Unregister *username* on disconnect.

        Uses ``pop`` so the key is actually removed; the previous
        implementation stored ``None``, leaving stale keys to accumulate
        for every client that ever connected.
        """
        self.settings["chat_clients"].pop(username, None)
        self.log.debug("Chat clients: %s", self.settings['chat_clients'].keys())

    def initialize(self):
        self.log.debug("Initializing websocket connection %s", self.request.path)

    def pre_get(self):
        """Handles authentication/authorization before opening the websocket."""
        # authenticate the request before opening the websocket
        user = self.current_user
        if user is None:
            self.log.warning("Couldn't authenticate WebSocket connection")
            raise web.HTTPError(403)

        # authorize the user.
        if not self.authorizer.is_authorized(self, user, "execute", "events"):
            raise web.HTTPError(403)

    async def get(self, *args, **kwargs):
        """Get an event socket."""
        self.pre_get()
        res = super().get(*args, **kwargs)
        await res

    def open(self):
        self.log.debug("Client with user %s connected...", self.current_user.username)
        self.add_chat_client(self.current_user.username)

    def broadcast_message(self, message: str, exclude_current_user: Optional[bool] = False):
        """Broadcasts message to all connected clients,
        optionally excluding the current user.
        """
        self.log.debug("Broadcasting message: %s to all clients...", message)
        client_names = self.settings["chat_clients"].keys()
        if exclude_current_user:
            # dict_keys is set-like; difference yields the remaining names.
            client_names = client_names - {self.current_user.username}

        for username in client_names:
            client = self.settings["chat_clients"][username]
            if client:
                client.write_message(message)

    def on_message(self, message):
        """Handle one inbound chat frame from this client."""
        self.log.debug("Message received: %s", message)

        try:
            payload = json.loads(message)
            chat_request = ChatRequest(**payload)
        except (json.JSONDecodeError, ValidationError) as e:
            # Malformed frame: log and ignore instead of letting the
            # exception tear down the websocket. (json.JSONDecodeError was
            # previously uncaught.)
            self.log.error(e)
            return

        human_message = HumanMessage(
            content=chat_request.prompt,
            additional_kwargs=dict(user=asdict(self.current_user))
        )
        # broadcast the human message to the other clients
        self.broadcast_message(
            message=json.dumps(_message_to_dict(human_message)),
            exclude_current_user=True,
        )

        # process the message through the conversation chain, then broadcast
        # the AI reply to all clients (including the sender)
        reply = self.chat_provider.predict(input=human_message.content)
        ai_message = AIMessage(content=reply)
        self.broadcast_message(message=json.dumps(_message_to_dict(ai_message)))

    def on_close(self):
        self.log.debug("Disconnecting client with user %s", self.current_user.username)
        self.remove_chat_client(self.current_user.username)
18 changes: 16 additions & 2 deletions packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from pydantic import BaseModel
from typing import Dict, List
from pydantic import BaseModel, validator
from typing import Dict, List, Literal

from langchain.schema import BaseMessage, _message_to_dict

class PromptRequest(BaseModel):
    """Request body for the prompt API (posted to ``api/ai/prompt``)."""
    # ID of the task to execute
    task_id: str
    # ID of the model engine to run the task with
    engine_id: str
    # values substituted into the task's prompt template
    prompt_variables: Dict[str, str]

class ChatRequest(BaseModel):
    """Request body for a chat message: the user's prompt text."""
    prompt: str

class ListEnginesEntry(BaseModel):
id: str
name: str
Expand All @@ -22,3 +27,12 @@ class DescribeTaskResponse(BaseModel):
insertion_mode: str
prompt_template: str
engines: List[ListEnginesEntry]

class ChatHistory(BaseModel):
    """History of chat messages"""
    messages: List[BaseMessage]

    class Config:
        # Serialize LangChain messages with their own dict form rather than
        # pydantic's default model dump (requires .json(models_as_dict=False)).
        json_encoders = {
            BaseMessage: lambda v: _message_to_dict(v)
        }
21 changes: 19 additions & 2 deletions packages/jupyter-ai/jupyter_ai/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
Cohere,
HuggingFaceHub,
OpenAI,
OpenAIChat,
SagemakerEndpoint
)

from pydantic import BaseModel, Extra
from langchain.chat_models import ChatOpenAI


class EnvAuthStrategy(BaseModel):
"""Require one auth token via an environment variable."""
Expand Down Expand Up @@ -153,7 +155,7 @@ class OpenAIProvider(BaseProvider, OpenAI):
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

class ChatOpenAIProvider(BaseProvider, OpenAIChat):
class ChatOpenAIProvider(BaseProvider, ChatOpenAI):
id = "openai-chat"
name = "OpenAI"
models = [
Expand All @@ -168,6 +170,21 @@ class ChatOpenAIProvider(BaseProvider, OpenAIChat):
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

def __init__(self, *args, **kwargs):
    # No extra initialization; delegates entirely to the parent classes.
    super().__init__(*args, **kwargs)

def append_exchange(self, prompt: str, output: str):
    """Appends a conversational exchange between user and an OpenAI Chat
    model to a transcript that will be included in future exchanges.

    NOTE(review): this relies on ``self.prefix_messages``, which was an
    attribute of the legacy ``OpenAIChat`` base class; verify it still
    exists now that this provider derives from ``ChatOpenAI``.
    """
    # Record the user turn followed by the assistant turn, in
    # OpenAI chat-completions message format.
    self.prefix_messages.append({
        "role": "user",
        "content": prompt
    })
    self.prefix_messages.append({
        "role": "assistant",
        "content": output
    })

class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
id = "sagemaker-endpoint"
name = "Sagemaker Endpoint"
Expand Down
Loading

0 comments on commit bae1734

Please sign in to comment.