diff --git a/.changeset/thin-glasses-serve.md b/.changeset/thin-glasses-serve.md new file mode 100644 index 0000000000000..46bf355cf2103 --- /dev/null +++ b/.changeset/thin-glasses-serve.md @@ -0,0 +1,10 @@ +--- +"@gradio/client": minor +"@gradio/core": minor +"@gradio/statustracker": minor +"@self/app": minor +"@self/spa": minor +"gradio": minor +--- + +feat:Allow setting title in gr.Info/Warning/Error diff --git a/client/js/src/helpers/api_info.ts b/client/js/src/helpers/api_info.ts index b3e5a655a2dc5..6e932a66c6525 100644 --- a/client/js/src/helpers/api_info.ts +++ b/client/js/src/helpers/api_info.ts @@ -342,6 +342,7 @@ export function handle_message( type: "update", status: { queue, + title: data.output.title as string, message: data.output.error as string, visible: data.output.visible as boolean, duration: data.output.duration as number, diff --git a/client/js/src/types.ts b/client/js/src/types.ts index ac4f2b54721db..fa7d5c7f198bb 100644 --- a/client/js/src/types.ts +++ b/client/js/src/types.ts @@ -329,6 +329,7 @@ export type GradioEvent = { export interface Log { log: string; + title: string; level: "warning" | "info"; } export interface Render { @@ -351,6 +352,7 @@ export interface Status { size?: number; position?: number; eta?: number; + title?: string; message?: string; progress_data?: { progress: number | null; diff --git a/client/js/src/utils/submit.ts b/client/js/src/utils/submit.ts index f5863b478e45f..8b976ba016fe9 100644 --- a/client/js/src/utils/submit.ts +++ b/client/js/src/utils/submit.ts @@ -350,6 +350,7 @@ export function submit( } else if (type === "log") { fire_event({ type: "log", + title: data.title, log: data.log, level: data.level, endpoint: _endpoint, @@ -485,6 +486,7 @@ export function submit( } else if (type === "log") { fire_event({ type: "log", + title: data.title, log: data.log, level: data.level, endpoint: _endpoint, @@ -645,6 +647,7 @@ export function submit( } else if (type === "log") { fire_event({ type: "log", + title: data.title, log: data.log, level: data.level, endpoint: _endpoint, diff --git a/gradio/exceptions.py b/gradio/exceptions.py index c750bce3eb2dd..7c015592e5e1e 100644 --- a/gradio/exceptions.py +++ b/gradio/exceptions.py @@ -80,13 +80,16 @@ def __init__( message: str = "Error raised.", duration: float | None = 10, visible: bool = True, + title: str = "Error", ): """ Parameters: message: The error message to be displayed to the user. Can be HTML, which will be rendered in the modal. duration: The duration in seconds to display the error message. If None or 0, the error message will be displayed until the user closes it. visible: Whether the error message should be displayed in the UI. + title: The title to be displayed to the user at the top of the error modal. """ + self.title = title self.message = message self.duration = duration self.visible = visible diff --git a/gradio/helpers.py b/gradio/helpers.py index 95ae1b2b40b5c..37991c7dc144a 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -1036,6 +1036,7 @@ def skip() -> dict: def log_message( message: str, + title: str, level: Literal["info", "warning"] = "info", duration: float | None = 10, visible: bool = True, @@ -1053,13 +1054,21 @@ def log_message( warnings.warn(message) return blocks._queue.log_message( - event_id=event_id, log=message, level=level, duration=duration, visible=visible + event_id=event_id, + log=message, + title=title, + level=level, + duration=duration, + visible=visible, ) @document(documentation_group="modals") def Warning( # noqa: N802 - message: str = "Warning issued.", duration: float | None = 10, visible: bool = True + message: str = "Warning issued.", + duration: float | None = 10, + visible: bool = True, + title: str = "Warning", ): """ This function allows you to pass custom warning messages to the user. You can do so simply by writing `gr.Warning('message here')` in your function, and when that line is executed the custom message will appear in a modal on the demo. The modal is yellow by default and has the heading: "Warning." Queue must be enabled for this behavior; otherwise, the warning will be printed to the console using the `warnings` library. @@ -1068,6 +1077,7 @@ def Warning( # noqa: N802 message: The warning message to be displayed to the user. Can be HTML, which will be rendered in the modal. duration: The duration in seconds that the warning message should be displayed for. If None or 0, the message will be displayed indefinitely until the user closes it. visible: Whether the error message should be displayed in the UI. + title: The title to be displayed to the user at the top of the modal. Example: import gradio as gr def hello_world(): @@ -1078,7 +1088,9 @@ def hello_world(): demo.load(hello_world, inputs=None, outputs=[md]) demo.queue().launch() """ - log_message(message, level="warning", duration=duration, visible=visible) + log_message( + message, title=title, level="warning", duration=duration, visible=visible + ) @document(documentation_group="modals") @@ -1086,6 +1098,7 @@ def Info( # noqa: N802 message: str = "Info issued.", duration: float | None = 10, visible: bool = True, + title: str = "Info", ): """ This function allows you to pass custom info messages to the user. You can do so simply by writing `gr.Info('message here')` in your function, and when that line is executed the custom message will appear in a modal on the demo. The modal is gray by default and has the heading: "Info." Queue must be enabled for this behavior; otherwise, the message will be printed to the console. @@ -1094,6 +1107,7 @@ def Info( # noqa: N802 message: The info message to be displayed to the user. Can be HTML, which will be rendered in the modal. duration: The duration in seconds that the info message should be displayed for. If None or 0, the message will be displayed indefinitely until the user closes it. visible: Whether the error message should be displayed in the UI. + title: The title to be displayed to the user at the top of the modal. Example: import gradio as gr def hello_world(): @@ -1104,4 +1118,4 @@ def hello_world(): demo.load(hello_world, inputs=None, outputs=[md]) demo.queue().launch() """ - log_message(message, level="info", duration=duration, visible=visible) + log_message(message, title=title, level="info", duration=duration, visible=visible) diff --git a/gradio/queueing.py b/gradio/queueing.py index 1a71fd8e8bf25..b98f85a6801d3 100644 --- a/gradio/queueing.py +++ b/gradio/queueing.py @@ -394,6 +394,7 @@ def log_message( self, event_id: str, log: str, + title: str, level: Literal["info", "warning"], duration: float | None = 10, visible: bool = True, @@ -406,6 +407,7 @@ def log_message( level=level, duration=duration, visible=visible, + title=title, ) self.send_message(event, log_message) @@ -644,6 +646,7 @@ async def process_events( event, ProcessCompletedMessage( output=content, + title=content["title"], # type: ignore success=False, ), ) diff --git a/gradio/server_messages.py b/gradio/server_messages.py index 0c746bef886c0..7eadc4437d2a5 100644 --- a/gradio/server_messages.py +++ b/gradio/server_messages.py @@ -28,6 +28,7 @@ class LogMessage(BaseMessage): level: Literal["info", "warning"] duration: Optional[float] = 10 visible: bool = True + title: str class EstimationMessage(BaseMessage): @@ -46,6 +47,7 @@ class ProcessCompletedMessage(BaseMessage): msg: Literal[ServerMessage.process_completed] = ServerMessage.process_completed # type: ignore output: dict success: bool + title: Optional[str] = None class ProcessGeneratingMessage(BaseMessage): diff --git a/gradio/utils.py b/gradio/utils.py index 977695c6425fb..7e9a75b1a9620 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -1409,10 +1409,13 @@ def error_payload( content: dict[str, bool | str | float | None] = {"error": None} show_error = show_error or isinstance(error, Error) if show_error: - content["error"] = str(error) - if isinstance(error, Error): - content["duration"] = error.duration - content["visible"] = error.visible + if isinstance(error, Error): + content["error"] = error.message + content["duration"] = error.duration + content["visible"] = error.visible + content["title"] = error.title + else: + content["error"] = str(error) return content diff --git a/js/app/src/routes/[...catchall]/+page.svelte b/js/app/src/routes/[...catchall]/+page.svelte index 813eb3617e734..c83be0b09389c 100644 --- a/js/app/src/routes/[...catchall]/+page.svelte +++ b/js/app/src/routes/[...catchall]/+page.svelte @@ -172,7 +172,7 @@ let url = new URL(`http://${host}${app.api_prefix}/dev/reload`); stream = new EventSource(url); stream.addEventListener("error", async (e) => { - new_message_fn("Error reloading app", "error"); + new_message_fn("Error", "Error reloading app", "error"); // @ts-ignore console.error(JSON.parse(e.data)); }); @@ -195,7 +195,7 @@ } }); - let new_message_fn: (message: string, type: string) => void; + let new_message_fn: (title: string, message: string, type: string) => void; onMount(async () => { intersecting = create_intersection_store(); diff --git a/js/core/src/Blocks.svelte b/js/core/src/Blocks.svelte index 13f0cb8d9e1bf..9d5872c8aec96 100644 --- a/js/core/src/Blocks.svelte +++ b/js/core/src/Blocks.svelte @@ -152,6 +152,7 @@ let messages: (ToastMessage & { fn_index: number })[] = []; function new_message( + title: string, message: string, fn_index: number, type: ToastMessage["type"], @@ -159,6 +160,7 @@ visible = true ): ToastMessage & { fn_index: number } { return { + title, message, fn_index, type, @@ -169,10 +171,11 @@ } export function add_new_message( + title: string, message: string, type: ToastMessage["type"] ): void { - messages = [new_message(message, -1, type), ...messages]; + messages = [new_message(title, message, -1, type), ...messages]; } let _error_id = -1; @@ -241,7 +244,7 @@ if (inputs_waiting.length > 0) { for (const input of inputs_waiting) { if (dep.inputs.includes(input)) { - add_new_message(WAITING_FOR_INPUTS_MESSAGE, "warning"); + add_new_message("Warning", WAITING_FOR_INPUTS_MESSAGE, "warning"); return; } } @@ -346,7 +349,10 @@ ); } catch (e) { const fn_index = 0; // Mock value for fn_index - messages = [new_message(String(e), fn_index, "error"), ...messages]; + messages = [ + new_message("Error", String(e), fn_index, "error"), + ...messages + ]; loading_status.update({ status: "error", fn_index, @@ -413,9 +419,9 @@ } function handle_log(msg: LogMessage): void { - const { log, fn_index, level, duration, visible } = msg; + const { title, log, fn_index, level, duration, visible } = msg; messages = [ - new_message(log, fn_index, level, duration, visible), + new_message(title, log, fn_index, level, duration, visible), ...messages ]; } @@ -463,7 +469,7 @@ ) { showed_duplicate_message = true; messages = [ - new_message(DUPLICATE_MESSAGE, fn_index, "warning"), + new_message("Warning", DUPLICATE_MESSAGE, fn_index, "warning"), ...messages ]; } @@ -475,7 +481,7 @@ ) { showed_mobile_warning = true; messages = [ - new_message(MOBILE_QUEUE_WARNING, fn_index, "warning"), + new_message("Warning", MOBILE_QUEUE_WARNING, fn_index, "warning"), ...messages ]; } @@ -503,7 +509,7 @@ if (status.broken && is_mobile_device && user_left_page) { window.setTimeout(() => { messages = [ - new_message(MOBILE_RECONNECT_MESSAGE, fn_index, "error"), + new_message("Error", MOBILE_RECONNECT_MESSAGE, fn_index, "error"), ...messages ]; }, 0); @@ -515,8 +521,10 @@ MESSAGE_QUOTE_RE, (_, b) => b ); + const _title = status.title ?? "Error"; messages = [ new_message( + _title, _message, fn_index, "error", @@ -612,8 +620,10 @@ if (event === "share") { const { title, description } = data as ShareData; trigger_share(title, description); - } else if (event === "error" || event === "warning") { - messages = [new_message(data, -1, event), ...messages]; + } else if (event === "error") { + messages = [new_message("Error", data, -1, event), ...messages]; + } else if (event === "warning") { + messages = [new_message("Warning", data, -1, event), ...messages]; } else if (event == "clear_status") { update_status(id, "complete", data); } else if (event == "close_stream") { diff --git a/js/spa/src/Index.svelte b/js/spa/src/Index.svelte index 9c0620ed3d4de..1414068b8b0fd 100644 --- a/js/spa/src/Index.svelte +++ b/js/spa/src/Index.svelte @@ -319,7 +319,7 @@ let url = new URL(`http://${host}${app.api_prefix}/dev/reload`); stream = new EventSource(url); stream.addEventListener("error", async (e) => { - new_message_fn("Error reloading app", "error"); + new_message_fn("Error", "Error reloading app", "error"); // @ts-ignore console.error(JSON.parse(e.data)); }); @@ -400,7 +400,7 @@ } }; - let new_message_fn: (message: string, type: string) => void; + let new_message_fn: (title: string, message: string, type: string) => void; onMount(async () => { intersecting.register(_id, wrapper); diff --git a/js/statustracker/static/Toast.svelte b/js/statustracker/static/Toast.svelte index 7139ef5dbe661..8cf2cf79a00ad 100644 --- a/js/statustracker/static/Toast.svelte +++ b/js/statustracker/static/Toast.svelte @@ -17,9 +17,17 @@