Migrate streaming logic to BaseChatHandler #1039

Merged · 5 commits · Oct 21, 2024
6 changes: 5 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -73,7 +73,11 @@ async def process_message(self, message: HumanChatMessage):
try:
with self.pending("Searching learned documents", message):
assert self.llm_chain
result = await self.llm_chain.acall({"question": query})
# TODO: migrate this class to use a LCEL `Runnable` instead of
# `Chain`, then remove the below ignore comment.
result = await self.llm_chain.acall( # type:ignore[attr-defined]
{"question": query}
)
response = result["answer"]
self.reply(response, message)
except AssertionError as e:
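
Note: the TODO above concerns replacing the legacy `Chain` interface with an LCEL `Runnable`. A minimal sketch of what that migration could look like, assuming a retriever and prompt template are available to the handler (the names `build_ask_runnable`, `retriever`, and `prompt_template` below are illustrative, not part of this PR):

```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable, RunnablePassthrough


def build_ask_runnable(llm, retriever, prompt_template) -> Runnable:
    """Compose an LCEL pipeline: retrieve context documents, fill the prompt,
    call the model, and parse the output into a plain string."""
    return (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt_template
        | llm
        | StrOutputParser()
    )
```

With a `Runnable` bound to `self.llm_chain`, the `acall()` call above would become an `await self.llm_chain.ainvoke(query)` that returns the answer string directly, and the `type:ignore` comment could be dropped.
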
146 changes: 141 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -1,11 +1,12 @@
import argparse
import asyncio
import contextlib
import os
import time
import traceback
from asyncio import Event
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
ClassVar,
Dict,
@@ -20,10 +21,13 @@
from uuid import uuid4

from dask.distributed import Client as DaskClient
from jupyter_ai.callback_handlers import MetadataCallbackHandler
from jupyter_ai.config_manager import ConfigManager, Logger
from jupyter_ai.history import WrappedBoundedChatHistory
from jupyter_ai.models import (
AgentChatMessage,
AgentStreamChunkMessage,
AgentStreamMessage,
ChatMessage,
ClosePendingMessage,
HumanChatMessage,
@@ -32,8 +36,12 @@
)
from jupyter_ai_magics import Persona
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import LLMChain
from langchain.pydantic_v1 import BaseModel
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables import Runnable
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.config import merge_configs as merge_runnable_configs
from langchain_core.runnables.utils import Input

if TYPE_CHECKING:
from jupyter_ai.context_providers import BaseCommandContextProvider
@@ -129,7 +137,7 @@ class BaseChatHandler:
"""Dictionary of context providers. Allows chat handlers to reference
context providers, which can be used to provide context to the LLM."""

message_interrupted: Dict[str, Event]
message_interrupted: Dict[str, asyncio.Event]
"""Dictionary mapping an agent message identifier to an asyncio Event
which indicates if the message generation/streaming was interrupted."""

Expand All @@ -147,7 +155,7 @@ def __init__(
help_message_template: str,
chat_handlers: Dict[str, "BaseChatHandler"],
context_providers: Dict[str, "BaseCommandContextProvider"],
message_interrupted: Dict[str, Event],
message_interrupted: Dict[str, asyncio.Event],
):
self.log = log
self.config_manager = config_manager
@@ -173,7 +181,7 @@ def __init__(

self.llm: Optional[BaseProvider] = None
self.llm_params: Optional[dict] = None
self.llm_chain: Optional[LLMChain] = None
self.llm_chain: Optional[Runnable] = None

async def on_message(self, message: HumanChatMessage):
"""
@@ -471,3 +479,131 @@ def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> None:
)

self.broadcast_message(help_message)

def _start_stream(self, human_msg: HumanChatMessage) -> str:
"""
Sends an `agent-stream` message to indicate the start of a response
stream. Returns the ID of the message, denoted as the `stream_id`.
"""
stream_id = uuid4().hex
stream_msg = AgentStreamMessage(
id=stream_id,
time=time.time(),
body="",
reply_to=human_msg.id,
persona=self.persona,
complete=False,
)

self.broadcast_message(stream_msg)
return stream_id

def _send_stream_chunk(
self,
stream_id: str,
content: str,
complete: bool = False,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""
Sends an `agent-stream-chunk` message containing content that should be
appended to an existing `agent-stream` message with ID `stream_id`.
"""
if not metadata:
metadata = {}

stream_chunk_msg = AgentStreamChunkMessage(
id=stream_id, content=content, stream_complete=complete, metadata=metadata
)
self.broadcast_message(stream_chunk_msg)

async def stream_reply(
self,
input: Input,
human_msg: HumanChatMessage,
config: Optional[RunnableConfig] = None,
):
"""
Streams a reply to a human message by invoking
`self.llm_chain.astream()`. A LangChain `Runnable` instance must be
bound to `self.llm_chain` before invoking this method.

Arguments
---------
- `input`: The input to your runnable. The type of `input` depends on
the runnable in `self.llm_chain`, but is usually a dictionary whose keys
refer to input variables in your prompt template.

- `human_msg`: The `HumanChatMessage` being replied to.

- `config` (optional): A `RunnableConfig` object that specifies
additional configuration when streaming from the runnable.
"""
assert self.llm_chain
assert isinstance(self.llm_chain, Runnable)

received_first_chunk = False
metadata_handler = MetadataCallbackHandler()
base_config: RunnableConfig = {
"configurable": {"last_human_msg": human_msg},
"callbacks": [metadata_handler],
}
merged_config: RunnableConfig = merge_runnable_configs(base_config, config)

# start with a pending message
with self.pending("Generating response", human_msg) as pending_message:
# stream response in chunks. this works even if a provider does not
# implement streaming, as `astream()` defaults to yielding `_call()`
# when `_stream()` is not implemented on the LLM class.
chunk_generator = self.llm_chain.astream(input, config=merged_config)
stream_interrupted = False
async for chunk in chunk_generator:
if not received_first_chunk:
# when receiving the first chunk, close the pending message and
# start the stream.
self.close_pending(pending_message)
stream_id = self._start_stream(human_msg=human_msg)
received_first_chunk = True
self.message_interrupted[stream_id] = asyncio.Event()

if self.message_interrupted[stream_id].is_set():
try:
# notify the model provider that streaming was interrupted
# (this is essential to allow the model to stop generating)
#
# note: `mypy` flags this line, claiming that `athrow` is
# not defined on `AsyncIterator`. This is why an ignore
# comment is placed here.
await chunk_generator.athrow( # type:ignore[attr-defined]
GenerationInterrupted()
)
except GenerationInterrupted:
# do not let the exception bubble up in case if
# the provider did not handle it
pass
stream_interrupted = True
break

if isinstance(chunk, AIMessageChunk) and isinstance(chunk.content, str):
self._send_stream_chunk(stream_id, chunk.content)
elif isinstance(chunk, str):
self._send_stream_chunk(stream_id, chunk)
else:
self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}")
break

# complete stream after all chunks have been streamed
stream_tombstone = (
"\n\n(AI response stopped by user)" if stream_interrupted else ""
)
self._send_stream_chunk(
stream_id,
stream_tombstone,
complete=True,
metadata=metadata_handler.jai_metadata,
)
del self.message_interrupted[stream_id]


class GenerationInterrupted(asyncio.CancelledError):
"""Exception raised when streaming is cancelled by the user"""
118 changes: 3 additions & 115 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -1,27 +1,15 @@
import asyncio
import time
from typing import Any, Dict, Type
from uuid import uuid4
from typing import Dict, Type

from jupyter_ai.callback_handlers import MetadataCallbackHandler
from jupyter_ai.models import (
AgentStreamChunkMessage,
AgentStreamMessage,
HumanChatMessage,
)
from jupyter_ai.models import HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables import ConfigurableFieldSpec
from langchain_core.runnables.history import RunnableWithMessageHistory

from ..context_providers import ContextProviderException, find_commands
from .base import BaseChatHandler, SlashCommandRoutingType


class GenerationInterrupted(asyncio.CancelledError):
"""Exception raised when streaming is cancelled by the user"""


class DefaultChatHandler(BaseChatHandler):
id = "default"
name = "Default"
@@ -65,55 +53,8 @@ def create_llm_chain(
)
self.llm_chain = runnable

def _start_stream(self, human_msg: HumanChatMessage) -> str:
"""
Sends an `agent-stream` message to indicate the start of a response
stream. Returns the ID of the message, denoted as the `stream_id`.
"""
stream_id = uuid4().hex
stream_msg = AgentStreamMessage(
id=stream_id,
time=time.time(),
body="",
reply_to=human_msg.id,
persona=self.persona,
complete=False,
)

for handler in self._root_chat_handlers.values():
if not handler:
continue

handler.broadcast_message(stream_msg)
break

return stream_id

def _send_stream_chunk(
self,
stream_id: str,
content: str,
complete: bool = False,
metadata: Dict[str, Any] = {},
):
"""
Sends an `agent-stream-chunk` message containing content that should be
appended to an existing `agent-stream` message with ID `stream_id`.
"""
stream_chunk_msg = AgentStreamChunkMessage(
id=stream_id, content=content, stream_complete=complete, metadata=metadata
)

for handler in self._root_chat_handlers.values():
if not handler:
continue

handler.broadcast_message(stream_chunk_msg)
break

async def process_message(self, message: HumanChatMessage):
self.get_llm_chain()
received_first_chunk = False
assert self.llm_chain

inputs = {"input": message.body}
@@ -127,60 +68,7 @@ async def process_message(self, message: HumanChatMessage):
inputs["context"] = context_prompt
inputs["input"] = self.replace_prompt(inputs["input"])

# start with a pending message
with self.pending("Generating response", message) as pending_message:
# stream response in chunks. this works even if a provider does not
# implement streaming, as `astream()` defaults to yielding `_call()`
# when `_stream()` is not implemented on the LLM class.
metadata_handler = MetadataCallbackHandler()
chunk_generator = self.llm_chain.astream(
inputs,
config={
"configurable": {"last_human_msg": message},
"callbacks": [metadata_handler],
},
)
stream_interrupted = False
async for chunk in chunk_generator:
if not received_first_chunk:
# when receiving the first chunk, close the pending message and
# start the stream.
self.close_pending(pending_message)
stream_id = self._start_stream(human_msg=message)
received_first_chunk = True
self.message_interrupted[stream_id] = asyncio.Event()

if self.message_interrupted[stream_id].is_set():
try:
# notify the model provider that streaming was interrupted
# (this is essential to allow the model to stop generating)
await chunk_generator.athrow(GenerationInterrupted())
except GenerationInterrupted:
# do not let the exception bubble up in case if
# the provider did not handle it
pass
stream_interrupted = True
break

if isinstance(chunk, AIMessageChunk) and isinstance(chunk.content, str):
self._send_stream_chunk(stream_id, chunk.content)
elif isinstance(chunk, str):
self._send_stream_chunk(stream_id, chunk)
else:
self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}")
break

# complete stream after all chunks have been streamed
stream_tombstone = (
"\n\n(AI response stopped by user)" if stream_interrupted else ""
)
self._send_stream_chunk(
stream_id,
stream_tombstone,
complete=True,
metadata=metadata_handler.jai_metadata,
)
del self.message_interrupted[stream_id]
await self.stream_reply(inputs, message)

async def make_context_prompt(self, human_msg: HumanChatMessage) -> str:
return "\n\n".join(
10 changes: 8 additions & 2 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
@@ -75,7 +75,11 @@ def create_llm_chain(
llm = provider(**unified_parameters)

self.llm = llm
self.llm_chain = LLMChain(llm=llm, prompt=FIX_PROMPT_TEMPLATE, verbose=True)
# TODO: migrate this class to use a LCEL `Runnable` instead of
# `Chain`, then remove the below ignore comment.
self.llm_chain = LLMChain( # type:ignore[arg-type]
llm=llm, prompt=FIX_PROMPT_TEMPLATE, verbose=True
)

async def process_message(self, message: HumanChatMessage):
if not (message.selection and message.selection.type == "cell-with-error"):
@@ -94,7 +98,9 @@ async def process_message(self, message: HumanChatMessage):
self.get_llm_chain()
with self.pending("Analyzing error", message):
assert self.llm_chain
response = await self.llm_chain.apredict(
# TODO: migrate this class to use a LCEL `Runnable` instead of
# `Chain`, then remove the below ignore comment.
response = await self.llm_chain.apredict( # type:ignore[attr-defined]
extra_instructions=extra_instructions,
stop=["\nHuman:"],
cell_content=selection.source,
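
The same TODO applies here. A rough sketch of a possible LCEL form, assuming `FIX_PROMPT_TEMPLATE` is a standard prompt template and the stop sequence moves onto the model via `.bind()`; parameter handling is elided and only the chain construction changes:

```python
from langchain_core.output_parsers import StrOutputParser


def create_llm_chain(self, provider, provider_params):
    llm = provider(**provider_params)
    self.llm = llm
    # Pipe the prompt directly into the model; no `LLMChain` wrapper or
    # `type:ignore` comment is needed once `self.llm_chain` is a `Runnable`.
    self.llm_chain = (
        FIX_PROMPT_TEMPLATE | llm.bind(stop=["\nHuman:"]) | StrOutputParser()
    )
```

`process_message()` would then call `await self.llm_chain.ainvoke(...)` with a dictionary of the same template variables currently passed to `apredict()`.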