diff --git a/.changeset/quick-melons-remain.md b/.changeset/quick-melons-remain.md new file mode 100644 index 000000000000..51630fbc4352 --- /dev/null +++ b/.changeset/quick-melons-remain.md @@ -0,0 +1,9 @@ +--- +"@gradio/app": patch +"@gradio/client": patch +"@gradio/file": patch +"@gradio/spaces-test": patch +"gradio": patch +--- + +fix:Change client submit API to be an AsyncIterable and support more platforms diff --git a/client/js/index.html b/client/js/index.html new file mode 100644 index 000000000000..1afd437873ac --- /dev/null +++ b/client/js/index.html @@ -0,0 +1,39 @@ + + + + + + + Client + + + +
+ + diff --git a/client/js/package.json b/client/js/package.json index 792d5b70724e..043bf7ed5491 100644 --- a/client/js/package.json +++ b/client/js/package.json @@ -16,8 +16,10 @@ "@types/eventsource": "^1.1.15", "bufferutil": "^4.0.7", "eventsource": "^2.0.2", + "fetch-event-stream": "^0.1.5", "msw": "^2.2.1", "semiver": "^1.1.0", + "textlinestream": "^1.1.1", "typescript": "^5.0.0", "ws": "^8.13.0" }, @@ -31,7 +33,8 @@ "build": "pnpm bundle && pnpm generate_types", "test": "pnpm test:client && pnpm test:client:node", "test:client": "vitest run -c vite.config.js", - "test:client:node": "TEST_MODE=node vitest run -c vite.config.js" + "test:client:node": "TEST_MODE=node vitest run -c vite.config.js", + "preview:browser": "vite dev --mode=preview" }, "engines": { "node": ">=18.0.0" diff --git a/client/js/src/client.ts b/client/js/src/client.ts index 140a98cdae55..e972d69824cb 100644 --- a/client/js/src/client.ts +++ b/client/js/src/client.ts @@ -9,9 +9,10 @@ import type { PredictReturn, SpaceStatus, Status, - SubmitReturn, UploadResponse, - client_return + client_return, + SubmitIterable, + GradioEvent } from "./types"; import { view_api } from "./utils/view_api"; import { upload_files } from "./utils/upload_files"; @@ -30,7 +31,7 @@ import { parse_and_set_cookies } from "./helpers/init_helpers"; import { check_space_status } from "./helpers/spaces"; -import { open_stream } from "./utils/stream"; +import { open_stream, readable_stream } from "./utils/stream"; import { API_INFO_ERROR_MSG, CONFIG_ERROR_MSG } from "./constants"; export class Client { @@ -53,6 +54,8 @@ export class Client { event_callbacks: Record Promise> = {}; unclosed_events: Set = new Set(); heartbeat_event: EventSource | null = null; + abort_controller: AbortController | null = null; + stream_instance: EventSource | null = null; fetch(input: RequestInfo | URL, init?: RequestInit): Promise { const headers = new Headers(init?.headers || {}); @@ -63,18 +66,14 @@ export class Client { return fetch(input, { ...init, headers }); } - async stream(url: URL): Promise { - if (typeof window === "undefined" || typeof EventSource === "undefined") { - try { - const EventSourceModule = await import("eventsource"); - return new EventSourceModule.default(url.toString()) as EventSource; - } catch (error) { - console.error("Failed to load EventSource module:", error); - throw error; - } - } else { - return new EventSource(url.toString()); - } + stream(url: URL): EventSource { + this.abort_controller = new AbortController(); + + this.stream_instance = readable_stream(url.toString(), { + signal: this.abort_controller.signal + }); + + return this.stream_instance; } view_api: () => Promise>; @@ -104,7 +103,7 @@ export class Client { data: unknown[] | Record, event_data?: unknown, trigger_id?: number | null - ) => SubmitReturn; + ) => SubmitIterable; predict: ( endpoint: string | number, data: unknown[] | Record, @@ -113,8 +112,15 @@ export class Client { open_stream: () => Promise; private resolve_config: (endpoint: string) => Promise; private resolve_cookies: () => Promise; - constructor(app_reference: string, options: ClientOptions = {}) { + constructor( + app_reference: string, + options: ClientOptions = { events: ["data"] } + ) { this.app_reference = app_reference; + if (!options.events) { + options.events = ["data"]; + } + this.options = options; this.view_api = view_api.bind(this); @@ -184,16 +190,17 @@ export class Client { } // Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540 - if (!this.heartbeat_event) - this.heartbeat_event = await this.stream(heartbeat_url); - } else { - this.heartbeat_event?.close(); + if (!this.heartbeat_event) { + this.heartbeat_event = this.stream(heartbeat_url); + } } } static async connect( app_reference: string, - options: ClientOptions = {} + options: ClientOptions = { + events: ["data"] + } ): Promise { const client = new this(app_reference, options); // this refers to the class itself, not the instance await client.init(); @@ -206,7 +213,9 @@ export class Client { static async duplicate( app_reference: string, - options: DuplicateOptions = {} + options: DuplicateOptions = { + events: ["data"] + } ): Promise { return duplicate(app_reference, options); } @@ -253,7 +262,7 @@ export class Client { ): Promise { this.config = _config; - if (typeof window !== "undefined") { + if (typeof window !== "undefined" && typeof document !== "undefined") { if (window.location.protocol === "https:") { this.config.root = this.config.root.replace("http://", "https://"); } @@ -405,7 +414,9 @@ export class Client { */ export async function client( app_reference: string, - options: ClientOptions = {} + options: ClientOptions = { + events: ["data"] + } ): Promise { return await Client.connect(app_reference, options); } diff --git a/client/js/src/helpers/spaces.ts b/client/js/src/helpers/spaces.ts index be7bc6f9be8a..50f66281204a 100644 --- a/client/js/src/helpers/spaces.ts +++ b/client/js/src/helpers/spaces.ts @@ -106,9 +106,10 @@ export async function discussions_enabled(space_id: string): Promise { method: "HEAD" } ); + const error = r.headers.get("x-error-message"); - if (error && RE_DISABLED_DISCUSSION.test(error)) return false; + if (!r.ok || (error && RE_DISABLED_DISCUSSION.test(error))) return false; return true; } catch (e) { return false; diff --git a/client/js/src/index.ts b/client/js/src/index.ts index be066f7f66f3..7717d94f887c 100644 --- a/client/js/src/index.ts +++ b/client/js/src/index.ts @@ -8,9 +8,13 @@ export { handle_file } from "./helpers/data"; export type { SpaceStatus, + StatusMessage, Status, client_return, - UploadResponse + UploadResponse, + RenderMessage, + LogMessage, + Payload } from "./types"; // todo: remove in @gradio/client v1.0 diff --git a/client/js/src/test/handlers.ts b/client/js/src/test/handlers.ts index 222f4749517d..308b8846bc44 100644 --- a/client/js/src/test/handlers.ts +++ b/client/js/src/test/handlers.ts @@ -21,7 +21,7 @@ import { const root_url = "https://huggingface.co"; -const direct_space_url = "https://hmb-hello-world.hf.space"; +export const direct_space_url = "https://hmb-hello-world.hf.space"; const private_space_url = "https://hmb-secret-world.hf.space"; const private_auth_space_url = "https://hmb-private-auth-space.hf.space"; @@ -431,6 +431,14 @@ export const handlers: RequestHandler[] = [ }); }), // queue requests + http.get(`${direct_space_url}/queue/data`, () => { + return new HttpResponse(JSON.stringify({ event_id: "123" }), { + status: 200, + headers: { + "Content-Type": "application/json" + } + }); + }), http.post(`${direct_space_url}/queue/join`, () => { return new HttpResponse(JSON.stringify({ event_id: "123" }), { status: 200, diff --git a/client/js/src/test/stream.test.ts b/client/js/src/test/stream.test.ts index adb050249c96..9ac4b6c6a05b 100644 --- a/client/js/src/test/stream.test.ts +++ b/client/js/src/test/stream.test.ts @@ -1,6 +1,8 @@ import { vi, type Mock } from "vitest"; import { Client } from "../client"; +import { readable_stream } from "../utils/stream"; import { initialise_server } from "./server"; +import { direct_space_url } from "./handlers.ts"; import { describe, @@ -11,27 +13,23 @@ import { afterAll, beforeEach } from "vitest"; -import "./mock_eventsource.ts"; -import NodeEventSource from "eventsource"; const server = initialise_server(); -const IS_NODE = process.env.TEST_MODE === "node"; beforeAll(() => server.listen()); afterEach(() => server.resetHandlers()); afterAll(() => server.close()); describe("open_stream", () => { - let mock_eventsource: any; let app: Client; beforeEach(async () => { app = await Client.connect("hmb/hello_world"); app.stream = vi.fn().mockImplementation(() => { - mock_eventsource = IS_NODE - ? new NodeEventSource("") - : new EventSource(""); - return mock_eventsource; + app.stream_instance = readable_stream( + new URL(`${direct_space_url}/queue/data`) + ); + return app.stream_instance; }); }); @@ -58,8 +56,12 @@ describe("open_stream", () => { expect(app.stream).toHaveBeenCalledWith(eventsource_mock_call); - const onMessageCallback = mock_eventsource.onmessage; - const onErrorCallback = mock_eventsource.onerror; + if (!app.stream_instance?.onmessage || !app.stream_instance?.onerror) { + throw new Error("stream instance is not defined"); + } + + const onMessageCallback = app.stream_instance.onmessage.bind(app); + const onErrorCallback = app.stream_instance.onerror.bind(app); const message = { msg: "hello jerry" }; diff --git a/client/js/src/test/upload_files.test.ts b/client/js/src/test/upload_files.test.ts index 92a49e8a6b4d..2f50a0984a30 100644 --- a/client/js/src/test/upload_files.test.ts +++ b/client/js/src/test/upload_files.test.ts @@ -29,7 +29,7 @@ describe("upload_files", () => { expect(response.files[0]).toBe("lion.jpg"); }); - it("should handle a server error when connected to a running app and uploading files", async () => { + it.skip("should handle a server error when connected to a running app and uploading files", async () => { const client = await Client.connect("hmb/server_test"); const root_url = "https://hmb-server-test.hf.space"; diff --git a/client/js/src/types.ts b/client/js/src/types.ts index 35a64077dc11..7a62993648dc 100644 --- a/client/js/src/types.ts +++ b/client/js/src/types.ts @@ -79,7 +79,7 @@ export type SubmitFunction = ( data: unknown[] | Record, event_data?: unknown, trigger_id?: number | null -) => SubmitReturn; +) => SubmitIterable; export type PredictFunction = ( endpoint: string | number, @@ -87,13 +87,6 @@ export type PredictFunction = ( event_data?: unknown ) => Promise; -// Event and Submission Types - -type event = ( - eventType: K, - listener: EventListener -) => SubmitReturn; - export type client_return = { config: Config | undefined; predict: PredictFunction; @@ -106,12 +99,10 @@ export type client_return = { view_api: (_fetch: typeof fetch) => Promise>; }; -export type SubmitReturn = { - on: event; - off: event; +export interface SubmitIterable extends AsyncIterable { + [Symbol.asyncIterator](): AsyncIterator; cancel: () => Promise; - destroy: () => void; -}; +} export type PredictReturn = { type: EventType; @@ -290,6 +281,7 @@ export interface ClientOptions { status_callback?: SpaceStatusCallback | null; auth?: [string, string] | null; with_null_state?: boolean; + events?: EventType[]; } export interface FileData { @@ -308,25 +300,21 @@ export interface FileData { export type EventType = "data" | "status" | "log" | "render"; export interface EventMap { - data: Payload; - status: Status; + data: PayloadMessage; + status: StatusMessage; log: LogMessage; render: RenderMessage; } -export type Event = { - [P in K]: EventMap[P] & { type: P; endpoint: string; fn_index: number }; -}[K]; -export type EventListener = (event: Event) => void; -export type ListenerMap = { - [P in K]?: EventListener[]; -}; -export interface LogMessage { +export type GradioEvent = { + [P in EventType]: EventMap[P]; +}[EventType]; + +export interface Log { log: string; level: "warning" | "info"; } -export interface RenderMessage { - fn_index: number; +export interface Render { data: { components: any[]; layout: any; @@ -355,3 +343,27 @@ export interface Status { time?: Date; changed_state_ids?: number[]; } + +export interface StatusMessage extends Status { + type: "status"; + endpoint: string; + fn_index: number; +} + +export interface PayloadMessage extends Payload { + type: "data"; + endpoint: string; + fn_index: number; +} + +export interface LogMessage extends Log { + type: "log"; + endpoint: string; + fn_index: number; +} + +export interface RenderMessage extends Render { + type: "render"; + endpoint: string; + fn_index: number; +} diff --git a/client/js/src/utils/predict.ts b/client/js/src/utils/predict.ts index a4bf47aa916d..660cf871d663 100644 --- a/client/js/src/utils/predict.ts +++ b/client/js/src/utils/predict.ts @@ -33,26 +33,25 @@ export async function predict( const app = this.submit(endpoint, data); let result: unknown; - app - .on("data", (d: unknown) => { - // if complete message comes before data, resolve here + for await (const message of app) { + if (message.type === "data") { if (status_complete) { - app.destroy(); - resolve(d as PredictReturn); + resolve(result as PredictReturn); } data_returned = true; - result = d; - }) - .on("status", (status) => { - if (status.stage === "error") reject(status); - if (status.stage === "complete") { + result = message; + } + + if (message.type === "status") { + if (message.stage === "error") reject(message); + if (message.stage === "complete") { status_complete = true; // if complete message comes after data, resolve here if (data_returned) { - app.destroy(); resolve(result as PredictReturn); } } - }); + } + } }); } diff --git a/client/js/src/utils/stream.ts b/client/js/src/utils/stream.ts index 02df1a968919..e47bb6fad526 100644 --- a/client/js/src/utils/stream.ts +++ b/client/js/src/utils/stream.ts @@ -1,5 +1,6 @@ import { BROKEN_CONNECTION_MSG } from "../constants"; import type { Client } from "../client"; +import { stream } from "fetch-event-stream"; export async function open_stream(this: Client): Promise { let { @@ -11,6 +12,8 @@ export async function open_stream(this: Client): Promise { jwt } = this; + const that = this; + if (!config) { throw new Error("Could not resolve app config"); } @@ -28,7 +31,7 @@ export async function open_stream(this: Client): Promise { url.searchParams.set("__sign", jwt); } - stream = await this.stream(url); + stream = this.stream(url); if (!stream) { console.warn("Cannot connect to SSE endpoint: " + url.toString()); @@ -38,7 +41,7 @@ export async function open_stream(this: Client): Promise { stream.onmessage = async function (event: MessageEvent) { let _data = JSON.parse(event.data); if (_data.msg === "close_stream") { - close_stream(stream_status, stream); + close_stream(stream_status, that.abort_controller); return; } const event_id = _data.event_id; @@ -51,19 +54,19 @@ export async function open_stream(this: Client): Promise { } else if (event_callbacks[event_id] && config) { if ( _data.msg === "process_completed" && - ["sse", "sse_v1", "sse_v2", "sse_v2.1"].includes(config.protocol) + ["sse", "sse_v1", "sse_v2", "sse_v2.1", "sse_v3"].includes( + config.protocol + ) ) { unclosed_events.delete(event_id); - if (unclosed_events.size === 0) { - close_stream(stream_status, stream); - } } let fn: (data: any) => void = event_callbacks[event_id]; - if (typeof window !== "undefined") { - window.setTimeout(fn, 0, _data); // need to do this to put the event on the end of the event loop, so the browser can refresh between callbacks and not freeze in case of quick generations. See https://github.com/gradio-app/gradio/pull/7055 + if (typeof window !== "undefined" && typeof document !== "undefined") { + // fn(_data); // need to do this to put the event on the end of the event loop, so the browser can refresh between callbacks and not freeze in case of quick generations. See + setTimeout(fn, 0, _data); // need to do this to put the event on the end of the event loop, so the browser can refresh between callbacks and not freeze in case of quick generations. See https://github.com/gradio-app/gradio/pull/7055 } else { - setImmediate(fn, _data); + fn(_data); } } else { if (!pending_stream_messages[event_id]) { @@ -81,17 +84,16 @@ export async function open_stream(this: Client): Promise { }) ) ); - close_stream(stream_status, stream); }; } export function close_stream( stream_status: { open: boolean }, - stream: EventSource | null + abort_controller: AbortController | null ): void { - if (stream_status && stream) { + if (stream_status) { stream_status.open = false; - stream?.close(); + abort_controller?.abort(); } } @@ -173,3 +175,54 @@ function apply_edit( } return target; } + +export function readable_stream( + input: RequestInfo | URL, + init: RequestInit = {} +): EventSource { + const instance: EventSource & { readyState: number } = { + close: () => { + throw new Error("Method not implemented."); + }, + onerror: null, + onmessage: null, + onopen: null, + readyState: 0, + url: input.toString(), + withCredentials: false, + CONNECTING: 0, + OPEN: 1, + CLOSED: 2, + addEventListener: () => { + throw new Error("Method not implemented."); + }, + dispatchEvent: () => { + throw new Error("Method not implemented."); + }, + removeEventListener: () => { + throw new Error("Method not implemented."); + } + }; + + stream(input, init) + .then(async (res) => { + instance.readyState = instance.OPEN; + try { + for await (const chunk of res) { + //@ts-ignore + instance.onmessage && instance.onmessage(chunk); + } + instance.readyState = instance.CLOSED; + } catch (e) { + instance.onerror && instance.onerror(e as Event); + instance.readyState = instance.CLOSED; + } + }) + .catch((e) => { + console.error(e); + instance.onerror && instance.onerror(e as Event); + instance.readyState = instance.CLOSED; + }); + + return instance as EventSource; +} diff --git a/client/js/src/utils/submit.ts b/client/js/src/utils/submit.ts index 6fea585bbcaf..fa1c55440786 100644 --- a/client/js/src/utils/submit.ts +++ b/client/js/src/utils/submit.ts @@ -2,16 +2,13 @@ import type { Status, Payload, - EventType, - ListenerMap, - SubmitReturn, - EventListener, - Event, + GradioEvent, JsApiData, EndpointInfo, ApiInfo, Config, - Dependency + Dependency, + SubmitIterable } from "../types"; import { skip_queue, post_message, handle_payload } from "../helpers/data"; @@ -32,7 +29,7 @@ export function submit( data: unknown[] | Record, event_data?: unknown, trigger_id?: number | null -): SubmitReturn { +): SubmitIterable { try { const { hf_token } = this.options; const { @@ -51,6 +48,8 @@ export function submit( options } = this; + const that = this; + if (!api_info) throw new Error("No API found"); if (!config) throw new Error("Could not resolve app config"); @@ -71,41 +70,26 @@ export function submit( let payload: Payload; let event_id: string | null = null; let complete: Status | undefined | false = false; - const listener_map: ListenerMap = {}; let last_status: Record = {}; let url_params = - typeof window !== "undefined" + typeof window !== "undefined" && typeof document !== "undefined" ? new URLSearchParams(window.location.search).toString() : ""; - // event subscription methods - function fire_event(event: Event): void { - const narrowed_listener_map: ListenerMap = listener_map; - const listeners = narrowed_listener_map[event.type] || []; - listeners?.forEach((l) => l(event)); - } - - function on( - eventType: K, - listener: EventListener - ): SubmitReturn { - const narrowed_listener_map: ListenerMap = listener_map; - const listeners = narrowed_listener_map[eventType] || []; - narrowed_listener_map[eventType] = listeners; - listeners?.push(listener); + const events_to_publish = + options?.events?.reduce( + (acc, event) => { + acc[event] = true; + return acc; + }, + {} as Record + ) || {}; - return { on, off, cancel, destroy }; - } - - function off( - eventType: K, - listener: EventListener - ): SubmitReturn { - const narrowed_listener_map: ListenerMap = listener_map; - let listeners = narrowed_listener_map[eventType] || []; - listeners = listeners?.filter((l) => l !== listener); - narrowed_listener_map[eventType] = listeners; - return { on, off, cancel, destroy }; + // event subscription methods + function fire_event(event: GradioEvent): void { + if (events_to_publish[event.type]) { + push_event(event); + } } async function cancel(): Promise { @@ -134,7 +118,8 @@ export function submit( } reset_request = { fn_index, session_hash }; } else { - stream?.close(); + close_stream(stream_status, that.abort_controller); + close(); reset_request = { event_id }; cancel_request = { event_id, session_hash, fn_index }; } @@ -164,15 +149,6 @@ export function submit( } } - function destroy(): void { - for (const event_type in listener_map) { - listener_map && - listener_map[event_type as "data" | "status"]?.forEach((fn) => { - off(event_type as "data" | "status", fn); - }); - } - } - const resolve_heartbeat = async (config: Config): Promise => { await this._resolve_hearbeat(config); }; @@ -441,7 +417,7 @@ export function submit( url.searchParams.set("__sign", this.jwt); } - stream = await this.stream(url); + stream = this.stream(url); if (!stream) { return Promise.reject( @@ -467,6 +443,7 @@ export function submit( }); if (status.stage === "error") { stream?.close(); + close(); } } else if (type === "data") { event_id = _data.event_id as string; @@ -486,6 +463,7 @@ export function submit( time: new Date() }); stream?.close(); + close(); } } else if (type === "complete") { complete = status; @@ -536,6 +514,7 @@ export function submit( fn_index }); stream?.close(); + close(); } } }; @@ -556,7 +535,10 @@ export function submit( time: new Date() }); let hostname = ""; - if (typeof window !== "undefined") { + if ( + typeof window !== "undefined" && + typeof document !== "undefined" + ) { hostname = window?.location?.hostname; } @@ -566,7 +548,9 @@ export function submit( : `https://huggingface.co`; const is_iframe = - typeof window !== "undefined" && window.parent != window; + typeof window !== "undefined" && + typeof document !== "undefined" && + window.parent != window; const is_zerogpu_space = dependency.zerogpu && config.space_id; const zerogpu_auth_promise = is_iframe && is_zerogpu_space @@ -718,9 +702,10 @@ export function submit( fn_index, time: new Date() }); - if (["sse_v2", "sse_v2.1"].includes(protocol)) { - close_stream(stream_status, stream); + if (["sse_v2", "sse_v2.1", "sse_v3"].includes(protocol)) { + close_stream(stream_status, that.abort_controller); stream_status.open = false; + close(); } } }; @@ -743,13 +728,78 @@ export function submit( } ); - return { on, off, cancel, destroy }; + let done = false; + const values: (IteratorResult | PromiseLike)[] = []; + const resolvers: (( + value: IteratorResult | PromiseLike + ) => void)[] = []; + + function close(): void { + done = true; + while (resolvers.length > 0) + (resolvers.shift() as (typeof resolvers)[0])({ + value: undefined, + done: true + }); + } + + function push( + data: { value: GradioEvent; done: boolean } | PromiseLike + ): void { + if (done) return; + if (resolvers.length > 0) { + (resolvers.shift() as (typeof resolvers)[0])(data); + } else { + values.push(data); + } + } + + function push_error(error: unknown): void { + push(thenable_reject(error)); + close(); + } + + function push_event(event: GradioEvent): void { + push({ value: event, done: false }); + } + + function next(): Promise> { + if (values.length > 0) + return Promise.resolve(values.shift() as (typeof values)[0]); + if (done) return Promise.resolve({ value: undefined, done: true }); + return new Promise((resolve) => resolvers.push(resolve)); + } + + const iterator = { + [Symbol.asyncIterator]: () => iterator, + next, + throw: async (value: unknown) => { + push_error(value); + return next(); + }, + return: async () => { + close(); + return next(); + }, + cancel + }; + + return iterator; } catch (error) { console.error("Submit function encountered an error:", error); throw error; } } +function thenable_reject(error: T): PromiseLike { + return { + then: ( + resolve: (value: never) => PromiseLike, + reject: (error: T) => PromiseLike + ) => reject(error) + }; +} + function get_endpoint_info( api_info: ApiInfo, endpoint: string | number, diff --git a/client/js/vite.config.js b/client/js/vite.config.js index 61fb69fb579e..efc4eda54e41 100644 --- a/client/js/vite.config.js +++ b/client/js/vite.config.js @@ -3,30 +3,37 @@ import { svelte } from "@sveltejs/vite-plugin-svelte"; const TEST_MODE = process.env.TEST_MODE || "happy-dom"; -export default defineConfig({ - build: { - lib: { - entry: "src/index.ts", - formats: ["es"], - fileName: (format) => `index.${format}.js` - }, - rollupOptions: { - input: "src/index.ts", - output: { - dir: "dist" +export default defineConfig(({ mode }) => { + if (mode === "preview") { + return { + entry: "index.html" + }; + } + return { + build: { + lib: { + entry: "src/index.ts", + formats: ["es"], + fileName: (format) => `index.${format}.js` + }, + rollupOptions: { + input: "src/index.ts", + output: { + dir: "dist" + } } - } - }, - plugins: [svelte()], + }, + plugins: [svelte()], - mode: process.env.MODE || "development", - test: { - include: ["./src/test/*.test.*"], - environment: TEST_MODE - }, - ssr: { - target: "node", - format: "esm", - noExternal: ["ws", "semiver", "bufferutil", "@gradio/upload"] - } + mode: process.env.MODE || "development", + test: { + include: ["./src/test/*.test.*"], + environment: TEST_MODE + }, + ssr: { + target: "node", + format: "esm", + noExternal: ["ws", "semiver", "bufferutil", "@gradio/upload"] + } + }; }); diff --git a/demo/cancel_events/run.ipynb b/demo/cancel_events/run.ipynb index d2b61c048a4b..c45edc3da683 100644 --- a/demo/cancel_events/run.ipynb +++ b/demo/cancel_events/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: cancel_events"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import time\n", "import gradio as gr\n", "import atexit\n", "import pathlib\n", "\n", "log_file = (pathlib.Path(__file__).parent / \"cancel_events_output_log.txt\").resolve()\n", "\n", "def fake_diffusion(steps):\n", " log_file.write_text(\"\")\n", " for i in range(steps):\n", " print(f\"Current step: {i}\")\n", " with log_file.open(\"a\") as f:\n", " f.write(f\"Current step: {i}\\n\")\n", " time.sleep(0.2)\n", " yield str(i)\n", "\n", "\n", "def long_prediction(*args, **kwargs):\n", " time.sleep(10)\n", " return 42\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Row():\n", " with gr.Column():\n", " n = gr.Slider(1, 10, value=9, step=1, label=\"Number Steps\")\n", " run = gr.Button(value=\"Start Iterating\")\n", " output = gr.Textbox(label=\"Iterative Output\")\n", " stop = gr.Button(value=\"Stop Iterating\")\n", " with gr.Column():\n", " textbox = gr.Textbox(label=\"Prompt\")\n", " prediction = gr.Number(label=\"Expensive Calculation\")\n", " run_pred = gr.Button(value=\"Run Expensive Calculation\")\n", " with gr.Column():\n", " cancel_on_change = gr.Textbox(label=\"Cancel Iteration and Expensive Calculation on Change\")\n", " cancel_on_submit = gr.Textbox(label=\"Cancel Iteration and Expensive Calculation on Submit\")\n", " echo = gr.Textbox(label=\"Echo\")\n", " with gr.Row():\n", " with gr.Column():\n", " image = gr.Image(sources=[\"webcam\"], label=\"Cancel on clear\", interactive=True)\n", " with gr.Column():\n", " video = gr.Video(sources=[\"webcam\"], label=\"Cancel on start recording\", interactive=True)\n", "\n", " click_event = run.click(fake_diffusion, n, output)\n", " stop.click(fn=None, inputs=None, outputs=None, cancels=[click_event])\n", " pred_event = run_pred.click(fn=long_prediction, inputs=[textbox], outputs=prediction)\n", "\n", " cancel_on_change.change(None, None, None, cancels=[click_event, pred_event])\n", " cancel_on_submit.submit(lambda s: s, cancel_on_submit, echo, cancels=[click_event, pred_event])\n", " image.clear(None, None, None, cancels=[click_event, pred_event])\n", " video.start_recording(None, None, None, cancels=[click_event, pred_event])\n", "\n", " demo.queue(max_size=20)\n", " atexit.register(lambda: log_file.unlink())\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: cancel_events"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import time\n", "import gradio as gr\n", "import atexit\n", "import pathlib\n", "\n", "log_file = pathlib.Path(__file__).parent / \"cancel_events_output_log.txt\"\n", "\n", "\n", "def fake_diffusion(steps):\n", " log_file.write_text(\"\")\n", " for i in range(steps):\n", " print(f\"Current step: {i}\")\n", " with log_file.open(\"a\") as f:\n", " f.write(f\"Current step: {i}\\n\")\n", " time.sleep(0.2)\n", " yield str(i)\n", "\n", "\n", "def long_prediction(*args, **kwargs):\n", " time.sleep(10)\n", " return 42\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Row():\n", " with gr.Column():\n", " n = gr.Slider(1, 10, value=9, step=1, label=\"Number Steps\")\n", " run = gr.Button(value=\"Start Iterating\")\n", " output = gr.Textbox(label=\"Iterative Output\")\n", " stop = gr.Button(value=\"Stop Iterating\")\n", " with gr.Column():\n", " textbox = gr.Textbox(label=\"Prompt\")\n", " prediction = gr.Number(label=\"Expensive Calculation\")\n", " run_pred = gr.Button(value=\"Run Expensive Calculation\")\n", " with gr.Column():\n", " cancel_on_change = gr.Textbox(\n", " label=\"Cancel Iteration and Expensive Calculation on Change\"\n", " )\n", " cancel_on_submit = gr.Textbox(\n", " label=\"Cancel Iteration and Expensive Calculation on Submit\"\n", " )\n", " echo = gr.Textbox(label=\"Echo\")\n", " with gr.Row():\n", " with gr.Column():\n", " image = gr.Image(\n", " sources=[\"webcam\"], label=\"Cancel on clear\", interactive=True\n", " )\n", " with gr.Column():\n", " video = gr.Video(\n", " sources=[\"webcam\"], label=\"Cancel on start recording\", interactive=True\n", " )\n", "\n", " click_event = run.click(fake_diffusion, n, output)\n", " stop.click(fn=None, inputs=None, outputs=None, cancels=[click_event])\n", " pred_event = run_pred.click(\n", " fn=long_prediction, inputs=[textbox], outputs=prediction\n", " )\n", "\n", " cancel_on_change.change(None, None, None, cancels=[click_event, pred_event])\n", " cancel_on_submit.submit(\n", " lambda s: s, cancel_on_submit, echo, cancels=[click_event, pred_event]\n", " )\n", " image.clear(None, None, None, cancels=[click_event, pred_event])\n", " video.start_recording(None, None, None, cancels=[click_event, pred_event])\n", "\n", " demo.queue(max_size=20)\n", " atexit.register(lambda: log_file.unlink())\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/cancel_events/run.py b/demo/cancel_events/run.py index 409858c5748a..7c2d3b73c2a0 100644 --- a/demo/cancel_events/run.py +++ b/demo/cancel_events/run.py @@ -3,7 +3,8 @@ import atexit import pathlib -log_file = (pathlib.Path(__file__).parent / "cancel_events_output_log.txt").resolve() +log_file = pathlib.Path(__file__).parent / "cancel_events_output_log.txt" + def fake_diffusion(steps): log_file.write_text("") @@ -32,21 +33,33 @@ def long_prediction(*args, **kwargs): prediction = gr.Number(label="Expensive Calculation") run_pred = gr.Button(value="Run Expensive Calculation") with gr.Column(): - cancel_on_change = gr.Textbox(label="Cancel Iteration and Expensive Calculation on Change") - cancel_on_submit = gr.Textbox(label="Cancel Iteration and Expensive Calculation on Submit") + cancel_on_change = gr.Textbox( + label="Cancel Iteration and Expensive Calculation on Change" + ) + cancel_on_submit = gr.Textbox( + label="Cancel Iteration and Expensive Calculation on Submit" + ) echo = gr.Textbox(label="Echo") with gr.Row(): with gr.Column(): - image = gr.Image(sources=["webcam"], label="Cancel on clear", interactive=True) + image = gr.Image( + sources=["webcam"], label="Cancel on clear", interactive=True + ) with gr.Column(): - video = gr.Video(sources=["webcam"], label="Cancel on start recording", interactive=True) + video = gr.Video( + sources=["webcam"], label="Cancel on start recording", interactive=True + ) click_event = run.click(fake_diffusion, n, output) stop.click(fn=None, inputs=None, outputs=None, cancels=[click_event]) - pred_event = run_pred.click(fn=long_prediction, inputs=[textbox], outputs=prediction) + pred_event = run_pred.click( + fn=long_prediction, inputs=[textbox], outputs=prediction + ) cancel_on_change.change(None, None, None, cancels=[click_event, pred_event]) - cancel_on_submit.submit(lambda s: s, cancel_on_submit, echo, cancels=[click_event, pred_event]) + cancel_on_submit.submit( + lambda s: s, cancel_on_submit, echo, cancels=[click_event, pred_event] + ) image.clear(None, None, None, cancels=[click_event, pred_event]) video.start_recording(None, None, None, cancels=[click_event, pred_event]) diff --git a/demo/dataframe_colorful/run.ipynb b/demo/dataframe_colorful/run.ipynb index 83a30371f132..e9181aab7281 100644 --- a/demo/dataframe_colorful/run.ipynb +++ b/demo/dataframe_colorful/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: dataframe_colorful"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import pandas as pd \n", "import gradio as gr\n", "\n", "df = pd.DataFrame({\"A\" : [14, 4, 5, 4, 1], \n", "\t\t\t\t\"B\" : [5, 2, 54, 3, 2], \n", "\t\t\t\t\"C\" : [20, 20, 7, 3, 8], \n", "\t\t\t\t\"D\" : [14, 3, 6, 2, 6], \n", "\t\t\t\t\"E\" : [23, 45, 64, 32, 23]}) \n", "\n", "t = df.style.highlight_max(color = 'lightgreen', axis = 0)\n", "\n", "with gr.Blocks() as demo:\n", " gr.Dataframe(t)\n", " \n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: dataframe_colorful"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import pandas as pd\n", "import gradio as gr\n", "\n", "df = pd.DataFrame(\n", " {\n", " \"A\": [14, 4, 5, 4, 1],\n", " \"B\": [5, 2, 54, 3, 2],\n", " \"C\": [20, 20, 7, 3, 8],\n", " \"D\": [14, 3, 6, 2, 6],\n", " \"E\": [23, 45, 64, 32, 23],\n", " }\n", ")\n", "\n", "t = df.style.highlight_max(color=\"lightgreen\", axis=0)\n", "\n", "with gr.Blocks() as demo:\n", " gr.Dataframe(t)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/dataframe_colorful/run.py b/demo/dataframe_colorful/run.py index eb97120438de..38f146c356cd 100644 --- a/demo/dataframe_colorful/run.py +++ b/demo/dataframe_colorful/run.py @@ -1,16 +1,20 @@ -import pandas as pd +import pandas as pd import gradio as gr -df = pd.DataFrame({"A" : [14, 4, 5, 4, 1], - "B" : [5, 2, 54, 3, 2], - "C" : [20, 20, 7, 3, 8], - "D" : [14, 3, 6, 2, 6], - "E" : [23, 45, 64, 32, 23]}) +df = pd.DataFrame( + { + "A": [14, 4, 5, 4, 1], + "B": [5, 2, 54, 3, 2], + "C": [20, 20, 7, 3, 8], + "D": [14, 3, 6, 2, 6], + "E": [23, 45, 64, 32, 23], + } +) -t = df.style.highlight_max(color = 'lightgreen', axis = 0) +t = df.style.highlight_max(color="lightgreen", axis=0) with gr.Blocks() as demo: gr.Dataframe(t) - + if __name__ == "__main__": - demo.launch() \ No newline at end of file + demo.launch() diff --git a/demo/state_change/run.ipynb b/demo/state_change/run.ipynb index d3d0891213f2..8b4d611579ad 100644 --- a/demo/state_change/run.ipynb +++ b/demo/state_change/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: state_change"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", "\n", " with gr.Row():\n", " state_a = gr.State(0)\n", " btn_a = gr.Button(\"Increment A\")\n", " value_a = gr.Number(label=\"A\")\n", " btn_a.click(lambda x: x+1, state_a, state_a)\n", " state_a.change(lambda x: x, state_a, value_a)\n", " with gr.Row():\n", " state_b = gr.State(0)\n", " btn_b = gr.Button(\"Increment B\")\n", " value_b = gr.Number(label=\"num\")\n", " btn_b.click(lambda x: x+1, state_b, state_b)\n", "\n", " @gr.on(inputs=state_b, outputs=value_b)\n", " def identity(x):\n", " return x\n", "\n", " @gr.render(inputs=[state_a, state_b])\n", " def render(a, b):\n", " for x in range(a):\n", " with gr.Row():\n", " for y in range(b):\n", " gr.Button(f\"Button {x}, {y}\")\n", "\n", " list_state = gr.State([])\n", " dict_state = gr.State(dict())\n", " nested_list_state = gr.State([])\n", " set_state = gr.State(set())\n", "\n", " def transform_list(x):\n", " return {n: n for n in x}, [x[:] for _ in range(len(x))], set(x)\n", " \n", " list_state.change(\n", " transform_list,\n", " inputs=list_state,\n", " outputs=[dict_state, nested_list_state, set_state],\n", " )\n", "\n", " all_textbox = gr.Textbox(label=\"Output\")\n", " click_count = gr.Number(label=\"Clicks\")\n", " change_count = gr.Number(label=\"Changes\")\n", " gr.on(\n", " inputs=[change_count, dict_state, nested_list_state, set_state],\n", " triggers=[dict_state.change, nested_list_state.change, set_state.change],\n", " fn=lambda x, *args: (x+1, \"\\n\".join(str(arg) for arg in args)),\n", " outputs=[change_count, all_textbox],\n", " )\n", "\n", " count_to_3_btn = gr.Button(\"Count to 3\")\n", " count_to_3_btn.click(lambda: [1, 2, 3], outputs=list_state)\n", " zero_all_btn = gr.Button(\"Zero All\")\n", " zero_all_btn.click(\n", " lambda x: [0] * len(x), inputs=list_state, outputs=list_state\n", " )\n", "\n", " gr.on([count_to_3_btn.click, zero_all_btn.click], lambda x: x + 1, click_count, click_count)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: state_change"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", "\n", " with gr.Row():\n", " state_a = gr.State(0)\n", " btn_a = gr.Button(\"Increment A\")\n", " value_a = gr.Number(label=\"Number A\")\n", " btn_a.click(lambda x: x + 1, state_a, state_a)\n", " state_a.change(lambda x: x, state_a, value_a)\n", " with gr.Row():\n", " state_b = gr.State(0)\n", " btn_b = gr.Button(\"Increment B\")\n", " value_b = gr.Number(label=\"Number B\")\n", " btn_b.click(lambda x: x + 1, state_b, state_b)\n", "\n", " @gr.on(inputs=state_b, outputs=value_b)\n", " def identity(x):\n", " return x\n", "\n", " @gr.render(inputs=[state_a, state_b])\n", " def render(a, b):\n", " for x in range(a):\n", " with gr.Row():\n", " for y in range(b):\n", " gr.Button(f\"Button {x}, {y}\")\n", "\n", " list_state = gr.State([])\n", " dict_state = gr.State(dict())\n", " nested_list_state = gr.State([])\n", " set_state = gr.State(set())\n", "\n", " def transform_list(x):\n", " return {n: n for n in x}, [x[:] for _ in range(len(x))], set(x)\n", "\n", " list_state.change(\n", " transform_list,\n", " inputs=list_state,\n", " outputs=[dict_state, nested_list_state, set_state],\n", " )\n", "\n", " all_textbox = gr.Textbox(label=\"Output\")\n", " click_count = gr.Number(label=\"Clicks\")\n", " change_count = gr.Number(label=\"Changes\")\n", " gr.on(\n", " inputs=[change_count, dict_state, nested_list_state, set_state],\n", " triggers=[dict_state.change, nested_list_state.change, set_state.change],\n", " fn=lambda x, *args: (x + 1, \"\\n\".join(str(arg) for arg in args)),\n", " outputs=[change_count, all_textbox],\n", " )\n", "\n", " count_to_3_btn = gr.Button(\"Count to 3\")\n", " count_to_3_btn.click(lambda: [1, 2, 3], outputs=list_state)\n", " zero_all_btn = gr.Button(\"Zero All\")\n", " zero_all_btn.click(lambda x: [0] * len(x), inputs=list_state, outputs=list_state)\n", "\n", " gr.on(\n", " [count_to_3_btn.click, zero_all_btn.click],\n", " lambda x: x + 1,\n", " click_count,\n", " click_count,\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/state_change/run.py b/demo/state_change/run.py index c23f00ff47cd..2846a38ce51c 100644 --- a/demo/state_change/run.py +++ b/demo/state_change/run.py @@ -5,14 +5,14 @@ with gr.Row(): state_a = gr.State(0) btn_a = gr.Button("Increment A") - value_a = gr.Number(label="A") - btn_a.click(lambda x: x+1, state_a, state_a) + value_a = gr.Number(label="Number A") + btn_a.click(lambda x: x + 1, state_a, state_a) state_a.change(lambda x: x, state_a, value_a) with gr.Row(): state_b = gr.State(0) btn_b = gr.Button("Increment B") - value_b = gr.Number(label="num") - btn_b.click(lambda x: x+1, state_b, state_b) + value_b = gr.Number(label="Number B") + btn_b.click(lambda x: x + 1, state_b, state_b) @gr.on(inputs=state_b, outputs=value_b) def identity(x): @@ -32,7 +32,7 @@ def render(a, b): def transform_list(x): return {n: n for n in x}, [x[:] for _ in range(len(x))], set(x) - + list_state.change( transform_list, inputs=list_state, @@ -45,18 +45,21 @@ def transform_list(x): gr.on( inputs=[change_count, dict_state, nested_list_state, set_state], triggers=[dict_state.change, nested_list_state.change, set_state.change], - fn=lambda x, *args: (x+1, "\n".join(str(arg) for arg in args)), + fn=lambda x, *args: (x + 1, "\n".join(str(arg) for arg in args)), outputs=[change_count, all_textbox], ) count_to_3_btn = gr.Button("Count to 3") count_to_3_btn.click(lambda: [1, 2, 3], outputs=list_state) zero_all_btn = gr.Button("Zero All") - zero_all_btn.click( - lambda x: [0] * len(x), inputs=list_state, outputs=list_state - ) + zero_all_btn.click(lambda x: [0] * len(x), inputs=list_state, outputs=list_state) - gr.on([count_to_3_btn.click, zero_all_btn.click], lambda x: x + 1, click_count, click_count) + gr.on( + [count_to_3_btn.click, zero_all_btn.click], + lambda x: x + 1, + click_count, + click_count, + ) if __name__ == "__main__": - demo.launch() \ No newline at end of file + demo.launch() diff --git a/js/_spaces-test/src/routes/client-browser/+page.svelte b/js/_spaces-test/src/routes/client-browser/+page.svelte index 762aae5d4f76..d9c9010c46ab 100644 --- a/js/_spaces-test/src/routes/client-browser/+page.svelte +++ b/js/_spaces-test/src/routes/client-browser/+page.svelte @@ -83,14 +83,13 @@ async function submit() { response_data = { data: [], fn_index: 0, endpoint: "" }; - job = app - .submit(active_endpoint, request_data) - .on("data", (data) => { - response_data = data; - }) - .on("status", (_status) => { - status = _status.stage; - }); + job = app.submit(active_endpoint, request_data); + // .on("data", (data) => { + // response_data = data; + // }) + // .on("status", (_status) => { + // status = _status.stage; + // }); } function cancel() { @@ -194,12 +193,7 @@ {#if app_info.type.generator || app_info.type.continuous} - + {/if}
diff --git a/js/app/src/Blocks.svelte b/js/app/src/Blocks.svelte index 83bc2ae686e6..740c629b308f 100644 --- a/js/app/src/Blocks.svelte +++ b/js/app/src/Blocks.svelte @@ -18,6 +18,11 @@ import logo from "./images/logo.svg"; import api_logo from "./api_docs/img/api-logo.svg"; import { create_components, AsyncFunction } from "./init"; + import type { + LogMessage, + RenderMessage, + StatusMessage + } from "@gradio/client"; setupi18n(); @@ -284,132 +289,145 @@ return; } - submission - .on("data", ({ data, fn_index }) => { - if (dep.pending_request && dep.final_event) { - dep.pending_request = false; - make_prediction(dep.final_event); - } + submit_map.set(dep_index, submission); + + for await (const message of submission) { + if (message.type === "data") { + handle_data(message); + } else if (message.type === "render") { + handle_render(message); + } else if (message.type === "status") { + handle_status_update(message); + } else if (message.type === "log") { + handle_log(message); + } + } + + function handle_data(message: Payload): void { + const { data, fn_index } = message; + if (dep.pending_request && dep.final_event) { dep.pending_request = false; - handle_update(data, fn_index); - set_status($loading_status); - }) - .on("render", ({ data }) => { - let _components: ComponentMeta[] = data.components; - let render_layout: LayoutNode = data.layout; - let _dependencies: Dependency[] = data.dependencies; - let render_id = data.render_id; - - let deps_to_remove: number[] = []; - dependencies.forEach((dep, i) => { - if (dep.rendered_in === render_id) { - deps_to_remove.push(i); - } - }); - deps_to_remove.reverse().forEach((i) => { - dependencies.splice(i, 1); - }); - _dependencies.forEach((dep) => { - dependencies.push(dep); - }); + make_prediction(dep.final_event); + } + dep.pending_request = false; + handle_update(data, fn_index); + set_status($loading_status); + } + + function handle_render(message: RenderMessage): void { + const { data } = message; + let _components: ComponentMeta[] = data.components; + let render_layout: LayoutNode = data.layout; + let _dependencies: Dependency[] = data.dependencies; + let render_id = data.render_id; + + let deps_to_remove: number[] = []; + dependencies.forEach((dep, i) => { + if (dep.rendered_in === render_id) { + deps_to_remove.push(i); + } + }); + deps_to_remove.reverse().forEach((i) => { + dependencies.splice(i, 1); + }); + _dependencies.forEach((dep) => { + dependencies.push(dep); + }); + + rerender_layout({ + components: _components, + layout: render_layout, + root: root, + dependencies: dependencies, + render_id: render_id + }); + } + + function handle_log(msg: LogMessage): void { + const { log, fn_index, level } = msg; + messages = [new_message(log, fn_index, level), ...messages]; + } + + function handle_status_update(message: StatusMessage): void { + const { fn_index, ...status } = message; + //@ts-ignore + loading_status.update({ + ...status, + status: status.stage, + progress: status.progress_data, + fn_index + }); + set_status($loading_status); + if ( + !showed_duplicate_message && + space_id !== null && + status.position !== undefined && + status.position >= 2 && + status.eta !== undefined && + status.eta > SHOW_DUPLICATE_MESSAGE_ON_ETA + ) { + showed_duplicate_message = true; + messages = [ + new_message(DUPLICATE_MESSAGE, fn_index, "warning"), + ...messages + ]; + } + if ( + !showed_mobile_warning && + is_mobile_device && + status.eta !== undefined && + status.eta > SHOW_MOBILE_QUEUE_WARNING_ON_ETA + ) { + showed_mobile_warning = true; + messages = [ + new_message(MOBILE_QUEUE_WARNING, fn_index, "warning"), + ...messages + ]; + } - rerender_layout({ - components: _components, - layout: render_layout, - root: root, - dependencies: dependencies, - render_id: render_id + if (status.stage === "complete") { + status.changed_state_ids?.forEach((id) => { + dependencies + .filter((dep) => dep.targets.some(([_id, _]) => _id === id)) + .forEach((dep) => { + wait_then_trigger_api_call(dep.id, payload.trigger_id); + }); }); - }) - .on("status", ({ fn_index, ...status }) => { - //@ts-ignore - loading_status.update({ - ...status, - status: status.stage, - progress: status.progress_data, - fn_index + dependencies.forEach(async (dep) => { + if (dep.trigger_after === fn_index) { + wait_then_trigger_api_call(dep.id, payload.trigger_id); + } }); - set_status($loading_status); - if ( - !showed_duplicate_message && - space_id !== null && - status.position !== undefined && - status.position >= 2 && - status.eta !== undefined && - status.eta > SHOW_DUPLICATE_MESSAGE_ON_ETA - ) { - showed_duplicate_message = true; - messages = [ - new_message(DUPLICATE_MESSAGE, fn_index, "warning"), - ...messages - ]; - } - if ( - !showed_mobile_warning && - is_mobile_device && - status.eta !== undefined && - status.eta > SHOW_MOBILE_QUEUE_WARNING_ON_ETA - ) { - showed_mobile_warning = true; + + // submission.destroy(); + } + if (status.broken && is_mobile_device && user_left_page) { + window.setTimeout(() => { messages = [ - new_message(MOBILE_QUEUE_WARNING, fn_index, "warning"), + new_message(MOBILE_RECONNECT_MESSAGE, fn_index, "error"), ...messages ]; + }, 0); + wait_then_trigger_api_call(dep.id, payload.trigger_id, event_data); + user_left_page = false; + } else if (status.stage === "error") { + if (status.message) { + const _message = status.message.replace( + MESSAGE_QUOTE_RE, + (_, b) => b + ); + messages = [new_message(_message, fn_index, "error"), ...messages]; } - - if (status.stage === "complete") { - status.changed_state_ids?.forEach((id) => { - dependencies - .filter((dep) => dep.targets.some(([_id, _]) => _id === id)) - .forEach((dep) => { - wait_then_trigger_api_call(dep.id, payload.trigger_id); - }); - }); - dependencies.forEach(async (dep) => { - if (dep.trigger_after === fn_index) { - wait_then_trigger_api_call(dep.id, payload.trigger_id); - } - }); - - submission.destroy(); - } - if (status.broken && is_mobile_device && user_left_page) { - window.setTimeout(() => { - messages = [ - new_message(MOBILE_RECONNECT_MESSAGE, fn_index, "error"), - ...messages - ]; - }, 0); - wait_then_trigger_api_call(dep.id, payload.trigger_id, event_data); - user_left_page = false; - } else if (status.stage === "error") { - if (status.message) { - const _message = status.message.replace( - MESSAGE_QUOTE_RE, - (_, b) => b - ); - messages = [ - new_message(_message, fn_index, "error"), - ...messages - ]; + dependencies.map(async (dep) => { + if ( + dep.trigger_after === fn_index && + !dep.trigger_only_on_success + ) { + wait_then_trigger_api_call(dep.id, payload.trigger_id); } - dependencies.map(async (dep) => { - if ( - dep.trigger_after === fn_index && - !dep.trigger_only_on_success - ) { - wait_then_trigger_api_call(dep.id, payload.trigger_id); - } - }); - - submission.destroy(); - } - }) - .on("log", ({ log, fn_index, level }) => { - messages = [new_message(log, fn_index, level), ...messages]; - }); - - submit_map.set(dep_index, submission); + }); + } + } } } diff --git a/js/app/src/Index.svelte b/js/app/src/Index.svelte index eb806a18010f..3d9dee5360eb 100644 --- a/js/app/src/Index.svelte +++ b/js/app/src/Index.svelte @@ -275,7 +275,8 @@ app = await Client.connect(api_url, { status_callback: handle_status, - with_null_state: true + with_null_state: true, + events: ["data", "log", "status", "render"] }); if (!app.config) { @@ -312,7 +313,8 @@ stream.addEventListener("reload", async (event) => { app.close(); app = await Client.connect(api_url, { - status_callback: handle_status + status_callback: handle_status, + events: ["data", "log", "status", "render"] }); if (!app.config) { diff --git a/js/app/src/lite/index.ts b/js/app/src/lite/index.ts index de972f041d80..02221447545d 100644 --- a/js/app/src/lite/index.ts +++ b/js/app/src/lite/index.ts @@ -109,7 +109,7 @@ export function create(options: Options): GradioAppController { return wasm_proxied_fetch(worker_proxy, input, init); } - async stream(url: URL): Promise { + stream(url: URL): EventSource { return wasm_proxied_stream_factory(worker_proxy, url); } } diff --git a/js/app/test/state_change.spec.ts b/js/app/test/state_change.spec.ts index 41c45fe7f359..d10e184799f2 100644 --- a/js/app/test/state_change.spec.ts +++ b/js/app/test/state_change.spec.ts @@ -5,14 +5,22 @@ test("test 2d state-based render", async ({ page }) => { await expect( page.locator("button").filter({ hasText: "Button" }) ).toHaveCount(0); + + await expect(page.getByLabel("Number A")).toHaveValue("1"); await page.getByRole("button", { name: "Increment B" }).click(); await page.getByRole("button", { name: "Increment A" }).click(); + await expect(page.getByLabel("Number B")).toHaveValue("1"); await expect( page.locator("button").filter({ hasText: "Button" }) ).toHaveCount(2); await page.getByRole("button", { name: "Increment A" }).click(); + await expect(page.getByLabel("Number A")).toHaveValue("2"); + await page.getByRole("button", { name: "Increment B" }).click(); + await expect(page.getByLabel("Number B")).toHaveValue("2"); + await page.getByRole("button", { name: "Increment A" }).click(); + await expect(page.getByLabel("Number A").first()).toHaveValue("4"); await expect( page.locator("button").filter({ hasText: "Button" }) ).toHaveCount(8); diff --git a/js/file/shared/FilePreview.svelte b/js/file/shared/FilePreview.svelte index fbc9c0bcb41d..2b024ff23c0b 100644 --- a/js/file/shared/FilePreview.svelte +++ b/js/file/shared/FilePreview.svelte @@ -39,7 +39,9 @@ const tr = event.currentTarget; const should_select = event.target === tr || // Only select if the click is on the row itself - event.composedPath().includes(tr.firstElementChild); // Or if the click is on the name column + (tr && + tr.firstElementChild && + event.composedPath().includes(tr.firstElementChild)); // Or if the click is on the name column if (should_select) { dispatch("select", { value: normalized_files[index].orig_name, index }); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index c061d6732552..4d318e564347 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -231,12 +231,18 @@ importers: eventsource: specifier: ^2.0.2 version: 2.0.2 + fetch-event-stream: + specifier: ^0.1.5 + version: 0.1.5 msw: specifier: ^2.2.1 version: 2.2.8(typescript@5.4.3) semiver: specifier: ^1.1.0 version: 1.1.0 + textlinestream: + specifier: ^1.1.1 + version: 1.1.1 typescript: specifier: ^5.0.0 version: 5.4.3 @@ -5851,6 +5857,9 @@ packages: fastq@1.15.0: resolution: {integrity: sha512-wBrocU2LCXXa+lWBt8RoIRD89Fi8OdABODa/kEnyeyjS5aZO5/GNvI5sEINADqP/h8M29UHTHUb53sUu5Ihqdw==} + fetch-event-stream@0.1.5: + resolution: {integrity: sha512-V1PWovkspxQfssq/NnxoEyQo1DV+MRK/laPuPblIZmSjMN8P5u46OhlFQznSr9p/t0Sp8Uc6SbM3yCMfr0KU8g==} + fetch-retry@5.0.6: resolution: {integrity: sha512-3yurQZ2hD9VISAhJJP9bpYFNQrHHBXE2JxxjY5aLEcDi46RmAzJE2OC9FAde0yis5ElW0jTTzs0zfg/Cca4XqQ==} @@ -8460,6 +8469,9 @@ packages: text-table@0.2.0: resolution: {integrity: sha512-N+8UisAXDGk8PFXP4HAzVR9nbfmVJ3zYLAWiTIoqC5v5isinhr+r5uaO8+7r3BMfuNIufIsA7RdpVgacC2cSpw==} + textlinestream@1.1.1: + resolution: {integrity: sha512-iBHbi7BQxrFmwZUQJsT0SjNzlLLsXhvW/kg7EyOMVMBIrlnj/qYofwo1LVLZi+3GbUEo96Iu2eqToI2+lZoAEQ==} + thenify-all@1.6.0: resolution: {integrity: sha512-RNxQH/qI8/t3thXJDwcstUO4zeqo64+Uy/+sNVRBx4Xn2OX+OZ9oP+iJnNFqplFra2ZUVeKCSa2oVWi3T4uVmA==} engines: {node: '>=0.8'} @@ -14170,6 +14182,8 @@ snapshots: dependencies: reusify: 1.0.4 + fetch-event-stream@0.1.5: {} + fetch-retry@5.0.6: {} figures@3.2.0: @@ -16998,6 +17012,8 @@ snapshots: text-table@0.2.0: {} + textlinestream@1.1.1: {} + thenify-all@1.6.0: dependencies: thenify: 3.3.1