Skip to content

Commit

Permalink
Add ability to provide preset response options in gr.Chatbot / `gr.…
Browse files Browse the repository at this point in the history
…ChatInterface` (#9989)

* options

* add changeset

* list

* types

* add changeset

* types

* docs

* changes

* more docs

* chatbot

* changes

* changes

* changes

* format

* notebooks

* typedict

* docs

* console logs

* docs

* docs

* styling

* docs

* Update guides/05_chatbots/01_creating-a-chatbot-fast.md

Co-authored-by: Hannah <hannahblair@users.noreply.github.com>

* Update guides/05_chatbots/01_creating-a-chatbot-fast.md

Co-authored-by: Hannah <hannahblair@users.noreply.github.com>

* Update guides/05_chatbots/01_creating-a-chatbot-fast.md

Co-authored-by: Hannah <hannahblair@users.noreply.github.com>

* Update guides/05_chatbots/01_creating-a-chatbot-fast.md

Co-authored-by: Hannah <hannahblair@users.noreply.github.com>

* Update guides/05_chatbots/02_chat_interface_examples.md

Co-authored-by: Hannah <hannahblair@users.noreply.github.com>

* Update guides/05_chatbots/01_creating-a-chatbot-fast.md

Co-authored-by: Hannah <hannahblair@users.noreply.github.com>

* restore

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Hannah <hannahblair@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 22, 2024
1 parent 74f22d5 commit 369a44e
Show file tree
Hide file tree
Showing 16 changed files with 447 additions and 206 deletions.
6 changes: 6 additions & 0 deletions .changeset/orange-cobras-suffer.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@gradio/chatbot": minor
"gradio": minor
---

feat:Add ability to provide preset response options in `gr.Chatbot` / `gr.ChatInterface`
1 change: 1 addition & 0 deletions demo/chatinterface_options/run.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_options"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "example_code = \"\"\"\n", "Here's the code I generated:\n", "\n", "```python\n", "def greet(x):\n", " return f\"Hello, {x}!\"\n", "```\n", "\n", "Is this correct?\n", "\"\"\"\n", "\n", "def chat(message, history):\n", " if message == \"Yes, that's correct.\":\n", " return \"Great!\"\n", " else:\n", " return {\n", " \"role\": \"assistant\",\n", " \"content\": example_code,\n", " \"options\": [\n", " {\"value\": \"Yes, that's correct.\", \"label\": \"Yes\"},\n", " {\"value\": \"No\"}\n", " ]\n", " }\n", "\n", "demo = gr.ChatInterface(\n", " chat,\n", " type=\"messages\",\n", " examples=[\"Write a Python function that takes a string and returns a greeting.\"]\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
34 changes: 34 additions & 0 deletions demo/chatinterface_options/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import gradio as gr

example_code = """
Here's the code I generated:
```python
def greet(x):
return f"Hello, {x}!"
```
Is this correct?
"""

def chat(message, history):
if message == "Yes, that's correct.":
return "Great!"
else:
return {
"role": "assistant",
"content": example_code,
"options": [
{"value": "Yes, that's correct.", "label": "Yes"},
{"value": "No"}
]
}

demo = gr.ChatInterface(
chat,
type="messages",
examples=["Write a Python function that takes a string and returns a greeting."]
)

if __name__ == "__main__":
demo.launch()
30 changes: 30 additions & 0 deletions gradio/chat_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,24 @@ def _setup_events(self) -> None:
queue=False,
)

self.chatbot.option_select(
self.option_clicked,
[self.chatbot],
[self.chatbot, self.saved_input],
show_api=False,
).then(
submit_fn,
[self.saved_input, self.chatbot],
[self.chatbot],
show_api=False,
concurrency_limit=cast(
Union[int, Literal["default"], None], self.concurrency_limit
),
show_progress=cast(
Literal["full", "minimal", "hidden"], self.show_progress
),
)

def _setup_stop_events(
self, event_triggers: list[Callable], events_to_cancel: list[Dependency]
) -> None:
Expand Down Expand Up @@ -686,6 +704,18 @@ async def _stream_fn(
self._append_history(history_with_input, response, first_response=False)
yield history_with_input

def option_clicked(
self, history: list[MessageDict], option: SelectData
) -> tuple[TupleFormat | list[MessageDict], str | MultimodalPostprocess]:
"""
When an option is clicked, the chat history is appended with the option value.
The saved input value is also set to option value. Note that event can only
be called if self.type is "messages" since options are only available for this
chatbot type.
"""
history.append({"role": "user", "content": option.value})
return history, option.value

def example_clicked(
self, example: SelectData
) -> tuple[TupleFormat | list[MessageDict], str | MultimodalPostprocess]:
Expand Down
13 changes: 11 additions & 2 deletions gradio/components/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
Any,
Literal,
Optional,
TypedDict,
Union,
cast,
)

from gradio_client import utils as client_utils
from gradio_client.documentation import document
from pydantic import Field
from typing_extensions import NotRequired
from typing_extensions import NotRequired, TypedDict

from gradio import utils
from gradio.component_meta import ComponentMeta
Expand All @@ -37,6 +36,11 @@ class MetadataDict(TypedDict):
title: Union[str, None]


class Option(TypedDict):
label: NotRequired[str]
value: str


class FileDataDict(TypedDict):
path: str # server filepath
url: NotRequired[Optional[str]] # normalised server url
Expand All @@ -51,6 +55,7 @@ class MessageDict(TypedDict):
content: str | FileDataDict | tuple | Component
role: Literal["user", "assistant", "system"]
metadata: NotRequired[MetadataDict]
options: NotRequired[list[Option]]


class FileMessage(GradioModel):
Expand Down Expand Up @@ -82,6 +87,7 @@ class Message(GradioModel):
role: str
metadata: Metadata = Field(default_factory=Metadata)
content: Union[str, FileMessage, ComponentMessage]
options: Optional[list[Option]] = None


class ExampleMessage(TypedDict):
Expand All @@ -102,6 +108,7 @@ class ChatMessage:
role: Literal["user", "assistant", "system"]
content: str | FileData | Component | FileDataDict | tuple | list
metadata: MetadataDict | Metadata = field(default_factory=Metadata)
options: Optional[list[Option]] = None


class ChatbotDataMessages(GradioRootModel):
Expand Down Expand Up @@ -150,6 +157,7 @@ class Chatbot(Component):
Events.retry,
Events.undo,
Events.example_select,
Events.option_select,
Events.clear,
Events.copy,
]
Expand Down Expand Up @@ -502,6 +510,7 @@ def _postprocess_message_messages(
role=message.role,
content=message.content, # type: ignore
metadata=message.metadata, # type: ignore
options=message.options,
)
elif isinstance(message, Message):
return message
Expand Down
6 changes: 4 additions & 2 deletions gradio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,10 +963,12 @@ class Events:
)
example_select = EventListener(
"example_select",
config_data=lambda: {"example_selectable": False},
callback=lambda block: setattr(block, "example_selectable", True),
doc="This listener is triggered when the user clicks on an example from within the {{ component }}. This event has SelectData of type gradio.SelectData that carries information, accessible through SelectData.index and SelectData.value. See SelectData documentation on how to use this event data.",
)
option_select = EventListener(
"option_select",
doc="This listener is triggered when the user clicks on an option from within the {{ component }}. This event has SelectData of type gradio.SelectData that carries information, accessible through SelectData.index and SelectData.value. See SelectData documentation on how to use this event data.",
)
load = EventListener(
"load",
doc="This listener is triggered when the {{ component }} initially loads in the browser.",
Expand Down
Loading

0 comments on commit 369a44e

Please sign in to comment.