Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chat backend #40

Merged
merged 8 commits into from
Apr 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,5 @@ playground/

# reserve path for a dev script
dev.sh

.vscode
45 changes: 43 additions & 2 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
import queue
from jupyter_server.extension.application import ExtensionApp
from .handlers import PromptAPIHandler, TaskAPIHandler
from langchain import ConversationChain
from .handlers import ChatHandler, ChatHistoryHandler, PromptAPIHandler, TaskAPIHandler, ChatAPIHandler
from importlib_metadata import entry_points
import inspect
from .engine import BaseModelEngine
from .providers import ChatOpenAIProvider
import os

from langchain.memory import ConversationBufferMemory
from langchain.prompts import (
ChatPromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate
)

class AiExtension(ExtensionApp):
name = "jupyter_ai"
handlers = [
("api/ai/prompt", PromptAPIHandler),
(r"api/ai/chat/?", ChatAPIHandler),
(r"api/ai/tasks/?", TaskAPIHandler),
(r"api/ai/tasks/([\w\-:]*)", TaskAPIHandler)
(r"api/ai/tasks/([\w\-:]*)", TaskAPIHandler),
(r"api/ai/chats/?", ChatHandler),
(r"api/ai/chats/history?", ChatHistoryHandler),
]

@property
Expand All @@ -18,6 +33,7 @@ def ai_engines(self):
self.settings["ai_engines"] = {}

return self.settings["ai_engines"]


def initialize_settings(self):
# EP := entry point
Expand Down Expand Up @@ -69,5 +85,30 @@ def initialize_settings(self):
self.settings["ai_default_tasks"] = default_tasks
self.log.info("Registered all default tasks.")

## load OpenAI chat provider
if ChatOpenAIProvider.auth_strategy.name in os.environ:
self.settings["openai_chat"] = ChatOpenAIProvider(model_id="gpt-3.5-turbo")
# Create a conversation memory
memory = ConversationBufferMemory(return_messages=True)
prompt_template = ChatPromptTemplate.from_messages([
SystemMessagePromptTemplate.from_template("The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know."),
MessagesPlaceholder(variable_name="history"),
HumanMessagePromptTemplate.from_template("{input}")
])
chain = ConversationChain(
llm=self.settings["openai_chat"],
prompt=prompt_template,
verbose=True,
memory=memory
)
self.settings["chat_provider"] = chain

self.log.info(f"Registered {self.name} server extension")

# Add a message queue to the settings to be used by the chat handler
self.settings["chat_message_queue"] = queue.Queue()

# Store chat clients in a dictionary
self.settings["chat_clients"] = {}


198 changes: 179 additions & 19 deletions packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
from dataclasses import asdict
import json
from typing import Optional

import tornado
from tornado.web import HTTPError
from pydantic import ValidationError

from jupyter_server.base.handlers import APIHandler
from tornado import web, websocket

from jupyter_server.base.handlers import APIHandler as BaseAPIHandler, JupyterHandler
from jupyter_server.utils import ensure_async

from .task_manager import TaskManager
from .models import PromptRequest
from .models import ChatHistory, PromptRequest, ChatRequest
from langchain.schema import _message_to_dict, HumanMessage, AIMessage

class PromptAPIHandler(APIHandler):
class APIHandler(BaseAPIHandler):
@property
def engines(self):
return self.settings["ai_engines"]
Expand All @@ -26,6 +32,11 @@ def task_manager(self):
self.settings["task_manager"] = TaskManager(engines=self.engines, default_tasks=self.default_tasks)
return self.settings["task_manager"]

@property
def openai_chat(self):
return self.settings["openai_chat"]

class PromptAPIHandler(APIHandler):
@tornado.web.authenticated
async def post(self):
try:
Expand All @@ -49,23 +60,27 @@ async def post(self):
"insertion_mode": task.insertion_mode
}))

class TaskAPIHandler(APIHandler):
@property
def engines(self):
return self.settings["ai_engines"]

@property
def default_tasks(self):
return self.settings["ai_default_tasks"]
class ChatAPIHandler(APIHandler):
@tornado.web.authenticated
async def post(self):
try:
request = ChatRequest(**self.get_json_body())
except ValidationError as e:
self.log.exception(e)
raise HTTPError(500, str(e)) from e

if not self.openai_chat:
raise HTTPError(500, "No chat models available.")

result = await ensure_async(self.openai_chat.agenerate([request.prompt]))
output = result.generations[0][0].text
self.openai_chat.append_exchange(request.prompt, output)

@property
def task_manager(self):
# we have to create the TaskManager lazily, since no event loop is
# running in ServerApp.initialize_settings().
if "task_manager" not in self.settings:
self.settings["task_manager"] = TaskManager(engines=self.engines, default_tasks=self.default_tasks)
return self.settings["task_manager"]

self.finish(json.dumps({
"output": output,
}))

class TaskAPIHandler(APIHandler):
@tornado.web.authenticated
async def get(self, id=None):
if id is None:
Expand All @@ -78,3 +93,148 @@ async def get(self, id=None):
raise HTTPError(404, f"Task not found with ID: {id}")

self.finish(json.dumps(describe_task_response.dict()))


class ChatHistoryHandler(BaseAPIHandler):
    """Handler to return and clear the chat message history.

    GET returns the full transcript held by the chat provider's memory as a
    ``ChatHistory`` JSON document; DELETE clears that memory.
    """

    _chat_provider = None
    _messages = []

    @property
    def chat_provider(self):
        # Lazily resolve the conversation chain registered by the extension
        # in initialize_settings().
        if self._chat_provider is None:
            self._chat_provider = self.settings["chat_provider"]
        return self._chat_provider

    @property
    def messages(self):
        # Re-read from the chain's memory on every access so the history
        # reflects exchanges made since the last request.
        self._messages = self.chat_provider.memory.chat_memory.messages or []
        return self._messages

    @tornado.web.authenticated
    async def get(self):
        """Return the chat transcript as JSON."""
        # No need to copy element-by-element; ChatHistory accepts the list.
        history = ChatHistory(messages=list(self.messages))
        self.finish(history.json(models_as_dict=False))

    @tornado.web.authenticated
    async def delete(self):
        """Clear the conversation memory and respond with 204 No Content."""
        self.chat_provider.memory.chat_memory.clear()
        # BUG FIX: `messages` is a read-only property (it has no setter), so
        # the original `self.messages = []` raised AttributeError. Reset the
        # backing attribute instead.
        self._messages = []
        self.set_status(204)
        self.finish()


class ChatHandler(
    JupyterHandler,
    websocket.WebSocketHandler
):
    """
    A websocket handler for chat.

    Each connected client is registered in ``self.settings["chat_clients"]``
    keyed by username. An incoming prompt is broadcast to the other clients,
    fed through the chat provider chain, and the AI response is then
    broadcast to all clients.
    """

    _chat_provider = None
    _chat_message_queue = None
    _messages = []

    @property
    def chat_provider(self):
        # Lazily resolve the conversation chain registered by the extension.
        if self._chat_provider is None:
            self._chat_provider = self.settings["chat_provider"]
        return self._chat_provider

    @property
    def chat_message_queue(self):
        # Lazily resolve the shared queue created in initialize_settings().
        if self._chat_message_queue is None:
            self._chat_message_queue = self.settings["chat_message_queue"]
        return self._chat_message_queue

    @property
    def messages(self):
        # Re-read from the chain's memory on every access so the history
        # reflects exchanges made since the last read.
        self._messages = self.chat_provider.memory.chat_memory.messages or []
        return self._messages

    def add_chat_client(self, username):
        """Register this websocket connection under *username*."""
        self.settings["chat_clients"][username] = self
        self.log.debug("Clients are : %s", self.settings["chat_clients"].keys())

    def remove_chat_client(self, username):
        """Deregister *username*'s connection.

        The entry is removed outright rather than set to None, so the client
        registry does not accumulate dead entries across reconnects.
        """
        self.settings["chat_clients"].pop(username, None)
        self.log.debug("Chat clients: %s", self.settings['chat_clients'].keys())

    def initialize(self):
        self.log.debug("Initializing websocket connection %s", self.request.path)

    def pre_get(self):
        """Handles authentication/authorization.

        Raises:
            web.HTTPError: 403 if the request is unauthenticated or the user
                is not authorized to execute events.
        """
        # authenticate the request before opening the websocket
        user = self.current_user
        if user is None:
            self.log.warning("Couldn't authenticate WebSocket connection")
            raise web.HTTPError(403)

        # authorize the user.
        if not self.authorizer.is_authorized(self, user, "execute", "events"):
            raise web.HTTPError(403)

    async def get(self, *args, **kwargs):
        """Get an event socket."""
        self.pre_get()
        res = super().get(*args, **kwargs)
        await res

    def open(self):
        """Register the newly-connected client."""
        self.log.debug("Client with user %s connected...", self.current_user.username)
        self.add_chat_client(self.current_user.username)

    def broadcast_message(self, message: str, exclude_current_user: Optional[bool] = False):
        """Broadcasts message to all connected clients,
        optionally excluding the current user.

        Note: `message` is a JSON-encoded string produced by the caller;
        the original annotation used the builtin `any`, which is a function,
        not a type.
        """

        self.log.debug("Broadcasting message: %s to all clients...", message)
        client_names = self.settings["chat_clients"].keys()
        if exclude_current_user:
            client_names = client_names - [self.current_user.username]

        for username in client_names:
            client = self.settings["chat_clients"][username]
            if client:
                client.write_message(message)

    def on_message(self, message):
        """Validate an incoming prompt, echo it to other clients, run it
        through the chat provider, and broadcast the AI response."""
        # typo fix: "recieved" -> "received"
        self.log.debug("Message received: %s", message)

        try:
            message = json.loads(message)
            chat_request = ChatRequest(**message)
        except ValidationError as e:
            # Malformed payloads are logged and dropped; the socket stays open.
            self.log.error(e)
            return

        # Attach the sending user so other clients can attribute the message.
        message = HumanMessage(
            content=chat_request.prompt,
            additional_kwargs=dict(user=asdict(self.current_user))
        )
        data = json.dumps(_message_to_dict(message))
        # broadcast the message to other clients
        self.broadcast_message(message=data, exclude_current_user=True)

        # process the message
        response = self.chat_provider.predict(input=message.content)

        response = AIMessage(
            content=response
        )
        # broadcast to all clients
        self.broadcast_message(message=json.dumps(_message_to_dict(response)))

    def on_close(self):
        """Deregister the client on disconnect."""
        self.log.debug("Disconnecting client with user %s", self.current_user.username)
        self.remove_chat_client(self.current_user.username)
18 changes: 16 additions & 2 deletions packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from pydantic import BaseModel
from typing import Dict, List
from pydantic import BaseModel, validator
from typing import Dict, List, Literal

from langchain.schema import BaseMessage, _message_to_dict

class PromptRequest(BaseModel):
    """Request body for the prompt API: run task ``task_id`` on engine
    ``engine_id``, filling the task's prompt template with
    ``prompt_variables``."""
    task_id: str
    engine_id: str
    prompt_variables: Dict[str, str]

class ChatRequest(BaseModel):
    """Request body for the chat API and websocket: a single user prompt."""
    prompt: str

class ListEnginesEntry(BaseModel):
id: str
name: str
Expand All @@ -22,3 +27,12 @@ class DescribeTaskResponse(BaseModel):
insertion_mode: str
prompt_template: str
engines: List[ListEnginesEntry]

class ChatHistory(BaseModel):
    """History of chat messages"""
    # Full conversation transcript, in chronological order.
    messages: List[BaseMessage]

    class Config:
        # Serialize LangChain messages via their canonical dict form
        # (type + data) instead of pydantic's default model dump; requires
        # callers to serialize with .json(models_as_dict=False).
        json_encoders = {
            BaseMessage: lambda v: _message_to_dict(v)
        }
21 changes: 19 additions & 2 deletions packages/jupyter-ai/jupyter_ai/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
Cohere,
HuggingFaceHub,
OpenAI,
OpenAIChat,
SagemakerEndpoint
)

from pydantic import BaseModel, Extra
from langchain.chat_models import ChatOpenAI


class EnvAuthStrategy(BaseModel):
"""Require one auth token via an environment variable."""
Expand Down Expand Up @@ -153,7 +155,7 @@ class OpenAIProvider(BaseProvider, OpenAI):
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

class ChatOpenAIProvider(BaseProvider, OpenAIChat):
class ChatOpenAIProvider(BaseProvider, ChatOpenAI):
id = "openai-chat"
name = "OpenAI"
models = [
Expand All @@ -168,6 +170,21 @@ class ChatOpenAIProvider(BaseProvider, OpenAIChat):
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

    def __init__(self, *args, **kwargs):
        # Pure pass-through constructor; adds no behavior beyond the bases'.
        super().__init__(*args, **kwargs)

    def append_exchange(self, prompt: str, output: str):
        """Appends a conversational exchange between user and an OpenAI Chat
        model to a transcript that will be included in future exchanges.

        NOTE(review): ``prefix_messages`` was an attribute of the old
        ``OpenAIChat`` base class; this diff switches the base to
        ``langchain.chat_models.ChatOpenAI``, which may not define it —
        confirm this method still works against the new base.
        """
        # OpenAI chat-completion message dicts: the user's prompt...
        self.prefix_messages.append({
            "role": "user",
            "content": prompt
        })
        # ...followed by the model's reply.
        self.prefix_messages.append({
            "role": "assistant",
            "content": output
        })

class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
id = "sagemaker-endpoint"
name = "Sagemaker Endpoint"
Expand Down
Loading