Skip to content

Commit

Permalink
Refactor Chatinterface to use Chatbot instead of gr.State variables (#…
Browse files Browse the repository at this point in the history
…8847)

* add code

* add changeset

* fix

* fix multimodal case

* e2e test fix

* fix: wrong named param check for js client (#8820)

* fix: wrong named param check for js client

* rearrange type imports

* add changeset

* fix: workaround for undefined endpoint_info

---------

Co-authored-by: Hannah <hannahblair@users.noreply.github.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* trigger ci

* Add code

* Fix

* Add code

* code

* code'

* Add code

* Wait for upload

* fix tests

* trigger ci

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: JacobLinCool <jacoblincool@gmail.com>
Co-authored-by: Hannah <hannahblair@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
5 people authored Aug 5, 2024
1 parent f32ed12 commit 4d8a473
Show file tree
Hide file tree
Showing 10 changed files with 77 additions and 81 deletions.
6 changes: 6 additions & 0 deletions .changeset/happy-lands-matter.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@gradio/chatbot": patch
"gradio": patch
---

fix:Refactor Chatinterface to use Chatbot instead of gr.State variables
6 changes: 6 additions & 0 deletions .changeset/odd-apes-run.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@gradio/client": patch
"gradio": patch
---

fix:fix: wrong named param check for js client
97 changes: 38 additions & 59 deletions gradio/chat_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def _setup_events(self) -> None:
if self.submit_btn
else [self.textbox.submit]
)

submit_event = (
on(
submit_triggers,
Expand All @@ -351,15 +352,15 @@ def _setup_events(self) -> None:
)
.then(
self._display_input,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
[self.saved_input, self.chatbot],
[self.chatbot],
show_api=False,
queue=False,
)
.then(
submit_fn,
[self.saved_input, self.chatbot_state] + self.additional_inputs,
[self.chatbot, self.chatbot_state],
[self.saved_input, self.chatbot] + self.additional_inputs,
[self.chatbot],
show_api=False,
concurrency_limit=cast(
Union[int, Literal["default"], None], self.concurrency_limit
Expand All @@ -375,22 +376,22 @@ def _setup_events(self) -> None:
retry_event = (
self.retry_btn.click(
self._delete_prev_fn,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.saved_input, self.chatbot_state],
[self.saved_input, self.chatbot],
[self.chatbot, self.saved_input],
show_api=False,
queue=False,
)
.then(
self._display_input,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
[self.saved_input, self.chatbot],
[self.chatbot],
show_api=False,
queue=False,
)
.then(
submit_fn,
[self.saved_input, self.chatbot_state] + self.additional_inputs,
[self.chatbot, self.chatbot_state],
[self.saved_input, self.chatbot] + self.additional_inputs,
[self.chatbot],
show_api=False,
concurrency_limit=cast(
Union[int, Literal["default"], None], self.concurrency_limit
Expand All @@ -411,8 +412,8 @@ async def format_textbox(data: str | MultimodalData) -> str | dict:
if self.undo_btn:
self.undo_btn.click(
self._delete_prev_fn,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.saved_input, self.chatbot_state],
[self.saved_input, self.chatbot],
[self.chatbot, self.saved_input],
show_api=False,
queue=False,
).then(
Expand All @@ -427,7 +428,7 @@ async def format_textbox(data: str | MultimodalData) -> str | dict:
self.clear_btn.click(
async_lambda(lambda: ([], [], None)),
None,
[self.chatbot, self.chatbot_state, self.saved_input],
[self.chatbot, self.saved_input],
queue=False,
show_api=False,
)
Expand Down Expand Up @@ -567,7 +568,7 @@ async def _display_input(
history.append([message, None]) # type: ignore
elif isinstance(message, str) and self.type == "messages":
history.append({"role": "user", "content": message}) # type: ignore
return history, history # type: ignore
return history # type: ignore

def response_as_dict(self, response: MessageDict | Message | str) -> MessageDict:
if isinstance(response, Message):
Expand Down Expand Up @@ -613,13 +614,11 @@ async def _submit_fn(
else:
new_response = response

if self.multimodal and isinstance(message, MultimodalData):
self._append_multimodal_history(message, new_response, history) # type: ignore
elif isinstance(message, str) and self.type == "tuples":
history.append([message, new_response]) # type: ignore
elif isinstance(message, str) and self.type == "messages":
history.extend([{"role": "user", "content": message}, new_response]) # type: ignore
return history, history # type: ignore
if self.type == "tuples":
history_with_input[-1][1] = new_response # type: ignore
elif self.type == "messages":
history_with_input.append(new_response) # type: ignore
return history_with_input # type: ignore

async def _stream_fn(
self,
Expand Down Expand Up @@ -657,40 +656,23 @@ async def _stream_fn(
and isinstance(message, MultimodalData)
and self.type == "tuples"
):
for x in message.files:
history.append([(x,), None]) # type: ignore
update = history + [[message.text, first_response]]
yield update, update
history_with_input[-1][1] = first_response # type: ignore
yield history_with_input
elif (
self.multimodal
and isinstance(message, MultimodalData)
and self.type == "messages"
):
for x in message.files:
history.append(
{"role": "user", "content": cast(FileDataDict, x.model_dump())} # type: ignore
)
update = history + [
{"role": "user", "content": message.text},
first_response,
]
yield update, update
history_with_input.append(first_response) # type: ignore
yield history_with_input
elif self.type == "tuples":
update = history + [[message, first_response]]
yield update, update
history_with_input[-1][1] = first_response # type: ignore
yield history_with_input
else:
update = history + [
{"role": "user", "content": message},
first_response,
]
yield update, update
history_with_input.append(first_response) # type: ignore
yield history_with_input
except StopIteration:
if self.multimodal and isinstance(message, MultimodalData):
self._append_multimodal_history(message, None, history)
yield history, history
else:
update = history + [[message, None]]
yield update, update
yield history_with_input
async for response in generator:
if self.type == "messages":
response = self.response_as_dict(response)
Expand All @@ -699,24 +681,21 @@ async def _stream_fn(
and isinstance(message, MultimodalData)
and self.type == "tuples"
):
update = history + [[message.text, response]]
yield update, update
history_with_input[-1][1] = response # type: ignore
yield history_with_input
elif (
self.multimodal
and isinstance(message, MultimodalData)
and self.type == "messages"
):
update = history + [
{"role": "user", "content": message.text},
response,
]
yield update, update
history_with_input[-1] = response # type: ignore
yield history_with_input
elif self.type == "tuples":
update = history + [[message, response]]
yield update, update
history_with_input[-1][1] = response # type: ignore
yield history_with_input
else:
update = history + [{"role": "user", "content": message}, response]
yield update, update
history_with_input[-1] = response # type: ignore
yield history_with_input

async def _api_submit_fn(
self,
Expand Down Expand Up @@ -833,4 +812,4 @@ async def _delete_prev_fn(
history = history[:-remove_input]
else:
history = history[: -(1 + extra)]
return history, message or "", history
return history, message or "" # type: ignore
17 changes: 9 additions & 8 deletions gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,13 @@ def _is_gr_no_reload(expr: ast.AST) -> bool:
return code_removed


def _find_module(source_file: Path) -> ModuleType:
def _find_module(source_file: Path) -> ModuleType | None:
for s, v in sys.modules.items():
if s not in {"__main__", "__mp_main__"} and getattr(v, "__file__", None) == str(
source_file
):
return v
raise ValueError(f"Cannot find module for source file: {source_file}")
return None


def watchfn(reloader: SourceFileReloader):
Expand Down Expand Up @@ -267,12 +267,13 @@ def iter_py_files() -> Iterator[Path]:
changed_in_copy = _remove_no_reload_codeblocks(str(changed))
if changed != reloader.demo_file:
changed_module = _find_module(changed)
exec(changed_in_copy, changed_module.__dict__)
top_level_parent = sys.modules[
changed_module.__name__.split(".")[0]
]
if top_level_parent != changed_module:
importlib.reload(top_level_parent)
if changed_module:
exec(changed_in_copy, changed_module.__dict__)
top_level_parent = sys.modules[
changed_module.__name__.split(".")[0]
]
if top_level_parent != changed_module:
importlib.reload(top_level_parent)

changed_demo_file = _remove_no_reload_codeblocks(
str(reloader.demo_file)
Expand Down
2 changes: 1 addition & 1 deletion js/app/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"test:snapshot": "pnpm exec playwright test snapshots/ --config=../../.config/playwright.config.js",
"test:browser": "pnpm exec playwright test test/ --grep-invert 'reload.spec.ts' --config=../../.config/playwright.config.js",
"test:browser:dev": "pnpm exec playwright test test/ --ui --config=../../.config/playwright.config.js",
"test:browser:reload": "pnpm exec playwright test test/ --grep 'reload.spec.ts' --config=../../.config/playwright.config.js",
"test:browser:reload": "CI=1 pnpm exec playwright test test/ --grep 'reload.spec.ts' --config=../../.config/playwright.config.js",
"test:browser:lite": "GRADIO_E2E_TEST_LITE=1 pnpm test:browser",
"test:browser:lite:dev": "GRADIO_E2E_TEST_LITE=1 pnpm test:browser:dev",
"build:css": "pollen -c pollen.config.cjs -o src/pollen-dev.css"
Expand Down
3 changes: 3 additions & 0 deletions js/app/test/chatbot_multimodal.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ for (const msg_format of ["tuples", "messages"]) {
await page.getByTestId("upload-button").click();
const fileChooser = await fileChooserPromise;
await fileChooser.setFiles("./test/files/cheetah1.jpg");
await expect(page.locator(".thumbnail-item")).toBeVisible();
await page.getByTestId("textbox").click();
await page.keyboard.press("Enter");

Expand Down Expand Up @@ -70,6 +71,7 @@ for (const msg_format of ["tuples", "messages"]) {
await page.getByTestId("upload-button").click();
const fileChooser = await fileChooserPromise;
await fileChooser.setFiles("../../test/test_files/audio_sample.wav");
await expect(page.locator(".thumbnail-item")).toBeVisible();
await page.getByTestId("textbox").click();
await page.keyboard.press("Enter");

Expand All @@ -95,6 +97,7 @@ for (const msg_format of ["tuples", "messages"]) {
await page.getByTestId("upload-button").click();
const fileChooser = await fileChooserPromise;
await fileChooser.setFiles("../../test/test_files/video_sample.mp4");
await expect(page.locator(".thumbnail-item")).toBeVisible();
await page.getByTestId("textbox").click();
await page.keyboard.press("Enter");

Expand Down
12 changes: 6 additions & 6 deletions js/app/test/hello_world.reload.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ test("gradio dev mode correctly reloads the page", async ({ page }) => {
test.setTimeout(20 * 1000);

try {
const port = 7880;
const { _process: server_process } = await launch_app_background(
`GRADIO_SERVER_PORT=${port} gradio ${join(process.cwd(), "run.py")}`,
process.cwd()
);
const { _process: server_process, port: port } =
await launch_app_background(
`gradio ${join(process.cwd(), "run.py")}`,
process.cwd()
);
_process = server_process;
console.log("Connected to port", port);
const demo = `
Expand All @@ -64,6 +64,7 @@ if __name__ == "__main__":
`;
// write contents of demo to a local 'run.py' file
await page.goto(`http://localhost:${port}`);
await page.waitForTimeout(2000);
spawnSync(`echo '${demo}' > ${join(process.cwd(), "run.py")}`, {
shell: true,
stdio: "pipe",
Expand All @@ -72,7 +73,6 @@ if __name__ == "__main__":
PYTHONUNBUFFERED: "true"
}
});
//await page.reload();

await page.getByLabel("x").fill("Maria");
await page.getByRole("button", { name: "Submit" }).click();
Expand Down
11 changes: 6 additions & 5 deletions js/app/test/test_chatinterface.reload.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ test("gradio dev mode correctly reloads a stateful ChatInterface demo", async ({
test.setTimeout(20 * 1000);

try {
const port = 7890;
const { _process: server_process } = await launch_app_background(
`GRADIO_SERVER_PORT=${port} gradio ${join(process.cwd(), demo_file)}`,
process.cwd()
);
const { _process: server_process, port: port } =
await launch_app_background(
`gradio ${join(process.cwd(), demo_file)}`,
process.cwd()
);
_process = server_process;
console.log("Connected to port", port);
const demo = `
Expand All @@ -66,6 +66,7 @@ if __name__ == "__main__":
demo.launch()
`;
await page.goto(`http://localhost:${port}`);
await page.waitForTimeout(2000);
spawnSync(`echo '${demo}' > ${join(process.cwd(), demo_file)}`, {
shell: true,
stdio: "pipe",
Expand Down
2 changes: 1 addition & 1 deletion js/app/test/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { spawn } from "node:child_process";
import type { ChildProcess } from "child_process";

export function kill_process(process: ChildProcess) {
process.kill("SIGKILL");
process.kill("SIGTERM");
}

type LaunchAppBackgroundReturn = {
Expand Down
2 changes: 1 addition & 1 deletion js/chatbot/shared/ButtonPanel.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
$: show_download =
!Array.isArray(message) &&
is_component_message(message) &&
message.content.value.url;
message.content.value?.url;
</script>

{#if show}
Expand Down

0 comments on commit 4d8a473

Please sign in to comment.