diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
index 2da303e5a..b5c4fa38b 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -73,7 +73,11 @@ async def process_message(self, message: HumanChatMessage):
         try:
             with self.pending("Searching learned documents", message):
                 assert self.llm_chain
-                result = await self.llm_chain.acall({"question": query})
+                # TODO: migrate this class to use a LCEL `Runnable` instead of
+                # `Chain`, then remove the below ignore comment.
+                result = await self.llm_chain.acall(  # type:ignore[attr-defined]
+                    {"question": query}
+                )
                 response = result["answer"]
             self.reply(response, message)
         except AssertionError as e:
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
index fc09d8f19..107b5a000 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -1,11 +1,12 @@
 import argparse
+import asyncio
 import contextlib
 import os
 import time
 import traceback
-from asyncio import Event
 from typing import (
     TYPE_CHECKING,
+    Any,
     Awaitable,
     ClassVar,
     Dict,
@@ -20,10 +21,13 @@
 from uuid import uuid4
 
 from dask.distributed import Client as DaskClient
+from jupyter_ai.callback_handlers import MetadataCallbackHandler
 from jupyter_ai.config_manager import ConfigManager, Logger
 from jupyter_ai.history import WrappedBoundedChatHistory
 from jupyter_ai.models import (
     AgentChatMessage,
+    AgentStreamChunkMessage,
+    AgentStreamMessage,
     ChatMessage,
     ClosePendingMessage,
     HumanChatMessage,
@@ -32,8 +36,12 @@
 )
 from jupyter_ai_magics import Persona
 from jupyter_ai_magics.providers import BaseProvider
-from langchain.chains import LLMChain
 from langchain.pydantic_v1 import BaseModel
+from langchain_core.messages import AIMessageChunk
+from langchain_core.runnables import Runnable
+from langchain_core.runnables.config import RunnableConfig
+from langchain_core.runnables.config import merge_configs as merge_runnable_configs
+from langchain_core.runnables.utils import Input
 
 if TYPE_CHECKING:
     from jupyter_ai.context_providers import BaseCommandContextProvider
@@ -129,7 +137,7 @@ class BaseChatHandler:
     """Dictionary of context providers. Allows chat handlers to reference
     context providers, which can be used to provide context to the LLM."""
 
-    message_interrupted: Dict[str, Event]
+    message_interrupted: Dict[str, asyncio.Event]
     """Dictionary mapping an agent message identifier to an asyncio Event
     which indicates if the message generation/streaming was interrupted."""
 
@@ -147,7 +155,7 @@ def __init__(
         help_message_template: str,
         chat_handlers: Dict[str, "BaseChatHandler"],
         context_providers: Dict[str, "BaseCommandContextProvider"],
-        message_interrupted: Dict[str, Event],
+        message_interrupted: Dict[str, asyncio.Event],
     ):
         self.log = log
         self.config_manager = config_manager
@@ -173,7 +181,7 @@ def __init__(
 
         self.llm: Optional[BaseProvider] = None
         self.llm_params: Optional[dict] = None
-        self.llm_chain: Optional[LLMChain] = None
+        self.llm_chain: Optional[Runnable] = None
 
     async def on_message(self, message: HumanChatMessage):
         """
@@ -471,3 +479,131 @@ def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> Non
         )
 
         self.broadcast_message(help_message)
+
+    def _start_stream(self, human_msg: HumanChatMessage) -> str:
+        """
+        Sends an `agent-stream` message to indicate the start of a response
+        stream. Returns the ID of the message, denoted as the `stream_id`.
+        """
+        stream_id = uuid4().hex
+        stream_msg = AgentStreamMessage(
+            id=stream_id,
+            time=time.time(),
+            body="",
+            reply_to=human_msg.id,
+            persona=self.persona,
+            complete=False,
+        )
+
+        self.broadcast_message(stream_msg)
+        return stream_id
+
+    def _send_stream_chunk(
+        self,
+        stream_id: str,
+        content: str,
+        complete: bool = False,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """
+        Sends an `agent-stream-chunk` message containing content that should be
+        appended to an existing `agent-stream` message with ID `stream_id`.
+        """
+        if not metadata:
+            metadata = {}
+
+        stream_chunk_msg = AgentStreamChunkMessage(
+            id=stream_id, content=content, stream_complete=complete, metadata=metadata
+        )
+        self.broadcast_message(stream_chunk_msg)
+
+    async def stream_reply(
+        self,
+        input: Input,
+        human_msg: HumanChatMessage,
+        config: Optional[RunnableConfig] = None,
+    ):
+        """
+        Streams a reply to a human message by invoking
+        `self.llm_chain.astream()`. A LangChain `Runnable` instance must be
+        bound to `self.llm_chain` before invoking this method.
+
+        Arguments
+        ---------
+        - `input`: The input to your runnable. The type of `input` depends on
+        the runnable in `self.llm_chain`, but is usually a dictionary whose keys
+        refer to input variables in your prompt template.
+
+        - `human_msg`: The `HumanChatMessage` being replied to.
+
+        - `config` (optional): A `RunnableConfig` object that specifies
+        additional configuration when streaming from the runnable.
+        """
+        assert self.llm_chain
+        assert isinstance(self.llm_chain, Runnable)
+
+        received_first_chunk = False
+        metadata_handler = MetadataCallbackHandler()
+        base_config: RunnableConfig = {
+            "configurable": {"last_human_msg": human_msg},
+            "callbacks": [metadata_handler],
+        }
+        merged_config: RunnableConfig = merge_runnable_configs(base_config, config)
+
+        # start with a pending message
+        with self.pending("Generating response", human_msg) as pending_message:
+            # stream response in chunks. this works even if a provider does not
+            # implement streaming, as `astream()` defaults to yielding `_call()`
+            # when `_stream()` is not implemented on the LLM class.
+            chunk_generator = self.llm_chain.astream(input, config=merged_config)
+            stream_interrupted = False
+            async for chunk in chunk_generator:
+                if not received_first_chunk:
+                    # when receiving the first chunk, close the pending message and
+                    # start the stream.
+                    self.close_pending(pending_message)
+                    stream_id = self._start_stream(human_msg=human_msg)
+                    received_first_chunk = True
+                    self.message_interrupted[stream_id] = asyncio.Event()
+
+                if self.message_interrupted[stream_id].is_set():
+                    try:
+                        # notify the model provider that streaming was interrupted
+                        # (this is essential to allow the model to stop generating)
+                        #
+                        # note: `mypy` flags this line, claiming that `athrow` is
+                        # not defined on `AsyncIterator`. This is why an ignore
+                        # comment is placed here.
+                        await chunk_generator.athrow(  # type:ignore[attr-defined]
+                            GenerationInterrupted()
+                        )
+                    except GenerationInterrupted:
+                        # do not let the exception bubble up in case if
+                        # the provider did not handle it
+                        pass
+                    stream_interrupted = True
+                    break
+
+                if isinstance(chunk, AIMessageChunk) and isinstance(chunk.content, str):
+                    self._send_stream_chunk(stream_id, chunk.content)
+                elif isinstance(chunk, str):
+                    self._send_stream_chunk(stream_id, chunk)
+                else:
+                    self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}")
+                    break
+
+            # complete stream after all chunks have been streamed
+            stream_tombstone = (
+                "\n\n(AI response stopped by user)" if stream_interrupted else ""
+            )
+            self._send_stream_chunk(
+                stream_id,
+                stream_tombstone,
+                complete=True,
+                metadata=metadata_handler.jai_metadata,
+            )
+            del self.message_interrupted[stream_id]
+
+
+class GenerationInterrupted(asyncio.CancelledError):
+    """Exception raised when streaming is cancelled by the user"""
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
index 46606d994..266ad73ad 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -1,16 +1,8 @@
 import asyncio
-import time
-from typing import Any, Dict, Type
-from uuid import uuid4
+from typing import Dict, Type
 
-from jupyter_ai.callback_handlers import MetadataCallbackHandler
-from jupyter_ai.models import (
-    AgentStreamChunkMessage,
-    AgentStreamMessage,
-    HumanChatMessage,
-)
+from jupyter_ai.models import HumanChatMessage
 from jupyter_ai_magics.providers import BaseProvider
-from langchain_core.messages import AIMessageChunk
 from langchain_core.runnables import ConfigurableFieldSpec
 from langchain_core.runnables.history import RunnableWithMessageHistory
 
@@ -18,10 +10,6 @@
 from .base import BaseChatHandler, SlashCommandRoutingType
 
 
-class GenerationInterrupted(asyncio.CancelledError):
-    """Exception raised when streaming is cancelled by the user"""
-
-
 class DefaultChatHandler(BaseChatHandler):
     id = "default"
     name = "Default"
@@ -65,55 +53,8 @@ def create_llm_chain(
         )
         self.llm_chain = runnable
 
-    def _start_stream(self, human_msg: HumanChatMessage) -> str:
-        """
-        Sends an `agent-stream` message to indicate the start of a response
-        stream. Returns the ID of the message, denoted as the `stream_id`.
-        """
-        stream_id = uuid4().hex
-        stream_msg = AgentStreamMessage(
-            id=stream_id,
-            time=time.time(),
-            body="",
-            reply_to=human_msg.id,
-            persona=self.persona,
-            complete=False,
-        )
-
-        for handler in self._root_chat_handlers.values():
-            if not handler:
-                continue
-
-            handler.broadcast_message(stream_msg)
-            break
-
-        return stream_id
-
-    def _send_stream_chunk(
-        self,
-        stream_id: str,
-        content: str,
-        complete: bool = False,
-        metadata: Dict[str, Any] = {},
-    ):
-        """
-        Sends an `agent-stream-chunk` message containing content that should be
-        appended to an existing `agent-stream` message with ID `stream_id`.
-        """
-        stream_chunk_msg = AgentStreamChunkMessage(
-            id=stream_id, content=content, stream_complete=complete, metadata=metadata
-        )
-
-        for handler in self._root_chat_handlers.values():
-            if not handler:
-                continue
-
-            handler.broadcast_message(stream_chunk_msg)
-            break
-
     async def process_message(self, message: HumanChatMessage):
         self.get_llm_chain()
-        received_first_chunk = False
         assert self.llm_chain
 
         inputs = {"input": message.body}
@@ -127,60 +68,7 @@ async def process_message(self, message: HumanChatMessage):
             inputs["context"] = context_prompt
             inputs["input"] = self.replace_prompt(inputs["input"])
 
-        # start with a pending message
-        with self.pending("Generating response", message) as pending_message:
-            # stream response in chunks. this works even if a provider does not
-            # implement streaming, as `astream()` defaults to yielding `_call()`
-            # when `_stream()` is not implemented on the LLM class.
-            metadata_handler = MetadataCallbackHandler()
-            chunk_generator = self.llm_chain.astream(
-                inputs,
-                config={
-                    "configurable": {"last_human_msg": message},
-                    "callbacks": [metadata_handler],
-                },
-            )
-            stream_interrupted = False
-            async for chunk in chunk_generator:
-                if not received_first_chunk:
-                    # when receiving the first chunk, close the pending message and
-                    # start the stream.
-                    self.close_pending(pending_message)
-                    stream_id = self._start_stream(human_msg=message)
-                    received_first_chunk = True
-                    self.message_interrupted[stream_id] = asyncio.Event()
-
-                if self.message_interrupted[stream_id].is_set():
-                    try:
-                        # notify the model provider that streaming was interrupted
-                        # (this is essential to allow the model to stop generating)
-                        await chunk_generator.athrow(GenerationInterrupted())
-                    except GenerationInterrupted:
-                        # do not let the exception bubble up in case if
-                        # the provider did not handle it
-                        pass
-                    stream_interrupted = True
-                    break
-
-                if isinstance(chunk, AIMessageChunk) and isinstance(chunk.content, str):
-                    self._send_stream_chunk(stream_id, chunk.content)
-                elif isinstance(chunk, str):
-                    self._send_stream_chunk(stream_id, chunk)
-                else:
-                    self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}")
-                    break
-
-            # complete stream after all chunks have been streamed
-            stream_tombstone = (
-                "\n\n(AI response stopped by user)" if stream_interrupted else ""
-            )
-            self._send_stream_chunk(
-                stream_id,
-                stream_tombstone,
-                complete=True,
-                metadata=metadata_handler.jai_metadata,
-            )
-            del self.message_interrupted[stream_id]
+        await self.stream_reply(inputs, message)
 
     async def make_context_prompt(self, human_msg: HumanChatMessage) -> str:
         return "\n\n".join(
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
index 1056e592c..4daf70e03 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
@@ -75,7 +75,11 @@ def create_llm_chain(
         llm = provider(**unified_parameters)
 
         self.llm = llm
-        self.llm_chain = LLMChain(llm=llm, prompt=FIX_PROMPT_TEMPLATE, verbose=True)
+        # TODO: migrate this class to use a LCEL `Runnable` instead of
+        # `Chain`, then remove the below ignore comment.
+        self.llm_chain = LLMChain(  # type:ignore[arg-type]
+            llm=llm, prompt=FIX_PROMPT_TEMPLATE, verbose=True
+        )
 
     async def process_message(self, message: HumanChatMessage):
         if not (message.selection and message.selection.type == "cell-with-error"):
@@ -94,7 +98,9 @@ async def process_message(self, message: HumanChatMessage):
         self.get_llm_chain()
         with self.pending("Analyzing error", message):
             assert self.llm_chain
-            response = await self.llm_chain.apredict(
+            # TODO: migrate this class to use a LCEL `Runnable` instead of
+            # `Chain`, then remove the below ignore comment.
+            response = await self.llm_chain.apredict(  # type:ignore[attr-defined]
                 extra_instructions=extra_instructions,
                 stop=["\nHuman:"],
                 cell_content=selection.source,
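
Note (not part of the patch): a minimal sketch of how a custom chat handler could reuse the new `BaseChatHandler.stream_reply()` helper added to base.py above, mirroring the pattern `DefaultChatHandler.process_message()` now follows. The handler name, slash command, trivial prompt template, and simplified provider construction are hypothetical and for illustration only; real handlers also pass unified model parameters when constructing the provider.

    # Hypothetical example -- not part of the diff. Binds an LCEL `Runnable` to
    # `self.llm_chain` and delegates streaming to `BaseChatHandler.stream_reply()`.
    from typing import Dict, Type

    from jupyter_ai.models import HumanChatMessage
    from jupyter_ai_magics.providers import BaseProvider
    from langchain_core.prompts import ChatPromptTemplate

    from .base import BaseChatHandler, SlashCommandRoutingType


    class EchoChatHandler(BaseChatHandler):
        """Hypothetical handler used only to illustrate `stream_reply()`."""

        id = "echo"
        name = "Echo"
        help = "Streams a reply through the shared stream_reply() helper"
        routing_type = SlashCommandRoutingType(slash_id="echo")

        def create_llm_chain(
            self, provider: Type[BaseProvider], provider_params: Dict[str, str]
        ):
            llm = provider(**provider_params)
            self.llm = llm
            # Any LangChain `Runnable` works here; a minimal prompt | model
            # pipeline keeps the sketch short.
            self.llm_chain = ChatPromptTemplate.from_template("{input}") | llm

        async def process_message(self, message: HumanChatMessage):
            self.get_llm_chain()
            assert self.llm_chain
            # `stream_reply()` owns the pending message, chunk streaming,
            # user interruption, and the final "stream complete" chunk.
            await self.stream_reply({"input": message.body}, message)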