Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow setting title in gr.Info/Warning/Error #9681

Merged
merged 9 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .changeset/thin-glasses-serve.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
"@gradio/client": minor
"@gradio/core": minor
"@gradio/statustracker": minor
"@self/app": minor
"@self/spa": minor
"gradio": minor
---

feat:Allow setting title in gr.Info/Warning/Error
1 change: 1 addition & 0 deletions client/js/src/helpers/api_info.ts
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ export function handle_message(
type: "update",
status: {
queue,
title: data.output.title as string,
message: data.output.error as string,
visible: data.output.visible as boolean,
duration: data.output.duration as number,
Expand Down
2 changes: 2 additions & 0 deletions client/js/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ export type GradioEvent = {

export interface Log {
log: string;
title: string;
level: "warning" | "info";
}
export interface Render {
Expand All @@ -351,6 +352,7 @@ export interface Status {
size?: number;
position?: number;
eta?: number;
title?: string;
message?: string;
progress_data?: {
progress: number | null;
Expand Down
3 changes: 3 additions & 0 deletions client/js/src/utils/submit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ export function submit(
} else if (type === "log") {
fire_event({
type: "log",
title: data.title,
log: data.log,
level: data.level,
endpoint: _endpoint,
Expand Down Expand Up @@ -485,6 +486,7 @@ export function submit(
} else if (type === "log") {
fire_event({
type: "log",
title: data.title,
log: data.log,
level: data.level,
endpoint: _endpoint,
Expand Down Expand Up @@ -645,6 +647,7 @@ export function submit(
} else if (type === "log") {
fire_event({
type: "log",
title: data.title,
log: data.log,
level: data.level,
endpoint: _endpoint,
Expand Down
2 changes: 2 additions & 0 deletions gradio/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,15 @@ def __init__(
message: str = "Error raised.",
duration: float | None = 10,
visible: bool = True,
title: str = "Error",
):
"""
Parameters:
message: The error message to be displayed to the user. Can be HTML, which will be rendered in the modal.
duration: The duration in seconds to display the error message. If None or 0, the error message will be displayed until the user closes it.
visible: Whether the error message should be displayed in the UI.
"""
self.title = title
self.message = message
self.duration = duration
self.visible = visible
Expand Down
19 changes: 15 additions & 4 deletions gradio/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,7 @@ def skip() -> dict:

def log_message(
message: str,
title: str,
level: Literal["info", "warning"] = "info",
duration: float | None = 10,
visible: bool = True,
Expand All @@ -1053,13 +1054,21 @@ def log_message(
warnings.warn(message)
return
blocks._queue.log_message(
event_id=event_id, log=message, level=level, duration=duration, visible=visible
event_id=event_id,
title=title,
log=message,
level=level,
duration=duration,
visible=visible,
)


@document(documentation_group="modals")
def Warning( # noqa: N802
message: str = "Warning issued.", duration: float | None = 10, visible: bool = True
title: str = "Warning",
message: str = "Warning issued.",
duration: float | None = 10,
visible: bool = True,
):
"""
This function allows you to pass custom warning messages to the user. You can do so simply by writing `gr.Warning('message here')` in your function, and when that line is executed the custom message will appear in a modal on the demo. The modal is yellow by default and has the heading: "Warning." Queue must be enabled for this behavior; otherwise, the warning will be printed to the console using the `warnings` library.
Expand All @@ -1078,12 +1087,13 @@ def hello_world():
demo.load(hello_world, inputs=None, outputs=[md])
demo.queue().launch()
"""
log_message(message, level="warning", duration=duration, visible=visible)
log_message(title, message, level="warning", duration=duration, visible=visible)


@document(documentation_group="modals")
def Info( # noqa: N802
message: str = "Info issued.",
title: str = "Info",
duration: float | None = 10,
visible: bool = True,
):
Expand All @@ -1092,6 +1102,7 @@ def Info( # noqa: N802
Demos: blocks_chained_events
Parameters:
message: The info message to be displayed to the user. Can be HTML, which will be rendered in the modal.
title: The title to be displayed to the user. Can be HTML, which will be rendered in the modal.
duration: The duration in seconds that the info message should be displayed for. If None or 0, the message will be displayed indefinitely until the user closes it.
visible: Whether the error message should be displayed in the UI.
Example:
Expand All @@ -1104,4 +1115,4 @@ def hello_world():
demo.load(hello_world, inputs=None, outputs=[md])
demo.queue().launch()
"""
log_message(message, level="info", duration=duration, visible=visible)
log_message(title, message, level="info", duration=duration, visible=visible)
3 changes: 3 additions & 0 deletions gradio/queueing.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ def log_message(
self,
event_id: str,
log: str,
title: str,
level: Literal["info", "warning"],
duration: float | None = 10,
visible: bool = True,
Expand All @@ -402,6 +403,7 @@ def log_message(
for event in events:
if event._id == event_id:
log_message = LogMessage(
title=title,
log=log,
level=level,
duration=duration,
Expand Down Expand Up @@ -644,6 +646,7 @@ async def process_events(
event,
ProcessCompletedMessage(
output=content,
title=content["title"], # type: ignore
success=False,
),
)
Expand Down
2 changes: 2 additions & 0 deletions gradio/server_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class ProgressMessage(BaseMessage):

class LogMessage(BaseMessage):
msg: Literal[ServerMessage.log] = ServerMessage.log # type: ignore
title: str
log: str
level: Literal["info", "warning"]
duration: Optional[float] = 10
Expand All @@ -44,6 +45,7 @@ class ProcessStartsMessage(BaseMessage):

class ProcessCompletedMessage(BaseMessage):
msg: Literal[ServerMessage.process_completed] = ServerMessage.process_completed # type: ignore
title: Optional[str] = None
output: dict
success: bool

Expand Down
11 changes: 7 additions & 4 deletions gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,10 +1409,13 @@ def error_payload(
content: dict[str, bool | str | float | None] = {"error": None}
show_error = show_error or isinstance(error, Error)
if show_error:
content["error"] = str(error)
if isinstance(error, Error):
content["duration"] = error.duration
content["visible"] = error.visible
if isinstance(error, Error):
content["error"] = error.message
content["duration"] = error.duration
content["visible"] = error.visible
content["title"] = error.title
else:
content["error"] = str(error)
return content


Expand Down
4 changes: 2 additions & 2 deletions js/app/src/routes/[...catchall]/+page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@
let url = new URL(`http://${host}${app.api_prefix}/dev/reload`);
stream = new EventSource(url);
stream.addEventListener("error", async (e) => {
new_message_fn("Error reloading app", "error");
new_message_fn("Error", "Error reloading app", "error");
// @ts-ignore
console.error(JSON.parse(e.data));
});
Expand All @@ -195,7 +195,7 @@
}
});

let new_message_fn: (message: string, type: string) => void;
let new_message_fn: (title: string, message: string, type: string) => void;

onMount(async () => {
intersecting = create_intersection_store();
Expand Down
30 changes: 20 additions & 10 deletions js/core/src/Blocks.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,15 @@

let messages: (ToastMessage & { fn_index: number })[] = [];
function new_message(
title: string,
message: string,
fn_index: number,
type: ToastMessage["type"],
duration: number | null = 10,
visible = true
): ToastMessage & { fn_index: number } {
return {
title,
message,
fn_index,
type,
Expand All @@ -169,10 +171,11 @@
}

export function add_new_message(
title: string,
message: string,
type: ToastMessage["type"]
): void {
messages = [new_message(message, -1, type), ...messages];
messages = [new_message(title, message, -1, type), ...messages];
}

let _error_id = -1;
Expand Down Expand Up @@ -241,7 +244,7 @@
if (inputs_waiting.length > 0) {
for (const input of inputs_waiting) {
if (dep.inputs.includes(input)) {
add_new_message(WAITING_FOR_INPUTS_MESSAGE, "warning");
add_new_message("Warning", WAITING_FOR_INPUTS_MESSAGE, "warning");
return;
}
}
Expand Down Expand Up @@ -346,7 +349,10 @@
);
} catch (e) {
const fn_index = 0; // Mock value for fn_index
messages = [new_message(String(e), fn_index, "error"), ...messages];
messages = [
new_message("Error", String(e), fn_index, "error"),
...messages
];
loading_status.update({
status: "error",
fn_index,
Expand Down Expand Up @@ -413,9 +419,9 @@
}

function handle_log(msg: LogMessage): void {
const { log, fn_index, level, duration, visible } = msg;
const { title, log, fn_index, level, duration, visible } = msg;
messages = [
new_message(log, fn_index, level, duration, visible),
new_message(title, log, fn_index, level, duration, visible),
...messages
];
}
Expand Down Expand Up @@ -463,7 +469,7 @@
) {
showed_duplicate_message = true;
messages = [
new_message(DUPLICATE_MESSAGE, fn_index, "warning"),
new_message("Warning", DUPLICATE_MESSAGE, fn_index, "warning"),
...messages
];
}
Expand All @@ -475,7 +481,7 @@
) {
showed_mobile_warning = true;
messages = [
new_message(MOBILE_QUEUE_WARNING, fn_index, "warning"),
new_message("Warning", MOBILE_QUEUE_WARNING, fn_index, "warning"),
...messages
];
}
Expand Down Expand Up @@ -503,7 +509,7 @@
if (status.broken && is_mobile_device && user_left_page) {
window.setTimeout(() => {
messages = [
new_message(MOBILE_RECONNECT_MESSAGE, fn_index, "error"),
new_message("Error", MOBILE_RECONNECT_MESSAGE, fn_index, "error"),
...messages
];
}, 0);
Expand All @@ -515,8 +521,10 @@
MESSAGE_QUOTE_RE,
(_, b) => b
);
const _title = status.title ?? "Error";
messages = [
new_message(
_title,
_message,
fn_index,
"error",
Expand Down Expand Up @@ -612,8 +620,10 @@
if (event === "share") {
const { title, description } = data as ShareData;
trigger_share(title, description);
} else if (event === "error" || event === "warning") {
messages = [new_message(data, -1, event), ...messages];
} else if (event === "error") {
messages = [new_message("Error", data, -1, event), ...messages];
} else if (event === "warning") {
messages = [new_message("Warning", data, -1, event), ...messages];
} else if (event == "clear_status") {
update_status(id, "complete", data);
} else if (event == "close_stream") {
Expand Down
4 changes: 2 additions & 2 deletions js/spa/src/Index.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@
let url = new URL(`http://${host}${app.api_prefix}/dev/reload`);
stream = new EventSource(url);
stream.addEventListener("error", async (e) => {
new_message_fn("Error reloading app", "error");
new_message_fn("Error", "Error reloading app", "error");
// @ts-ignore
console.error(JSON.parse(e.data));
});
Expand Down Expand Up @@ -400,7 +400,7 @@
}
};

let new_message_fn: (message: string, type: string) => void;
let new_message_fn: (title: string, message: string, type: string) => void;

onMount(async () => {
intersecting.register(_id, wrapper);
Expand Down
12 changes: 10 additions & 2 deletions js/statustracker/static/Toast.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@
</script>

<div class="toast-wrap">
{#each messages as { type, message, id, duration, visible } (id)}
{#each messages as { type, title, message, id, duration, visible } (id)}
<div animate:flip={{ duration: 300 }} style:width="100%">
<ToastContent {type} {message} {duration} {visible} on:close {id} />
<ToastContent
{type}
{title}
{message}
{duration}
{visible}
on:close
{id}
/>
</div>
{/each}
</div>
Expand Down
5 changes: 3 additions & 2 deletions js/statustracker/static/ToastContent.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import { fade } from "svelte/transition";
import type { ToastMessage } from "./types";

export let title = "";
export let message = "";
export let type: ToastMessage["type"];
export let id: number;
Expand All @@ -27,7 +28,7 @@
}
}
});

$: title = DOMPurify.sanitize(title);
$: message = DOMPurify.sanitize(message);

$: display = visible;
Expand Down Expand Up @@ -73,7 +74,7 @@
</div>

<div class="toast-details {type}">
<div class="toast-title {type}">{type}</div>
<div class="toast-title {type}">{@html title}</div>
<div class="toast-text {type}">
{@html message}
</div>
Expand Down
1 change: 1 addition & 0 deletions js/statustracker/static/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export interface LoadingStatus {

export interface ToastMessage {
type: "error" | "warning" | "info";
title: string;
message: string;
id: number;
duration: number | null;
Expand Down