Skip to content

Commit

Permalink
Handle gradio apps using state in the JS Client (#8439)
Browse files Browse the repository at this point in the history
* send `null` for each `state` param in space api

* add changeset

* test

* remove state value from payload from server

* tweak

* test

* test

* Revert "test"

This reverts commit 182045e.

* Revert "test"

This reverts commit 70e074d.

* fixes

* add changeset

* fixes

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: pngwn <hello@pngwn.io>
  • Loading branch information
3 people authored Jun 5, 2024
1 parent 5c8915b commit 63d36fb
Show file tree
Hide file tree
Showing 10 changed files with 304 additions and 19 deletions.
8 changes: 8 additions & 0 deletions .changeset/young-poets-change.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
"@gradio/app": patch
"@gradio/client": patch
"@gradio/preview": patch
"gradio": patch
---

fix:Handle gradio apps using `state` in the JS Client
3 changes: 2 additions & 1 deletion client/js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import type {
DuplicateOptions,
EndpointInfo,
JsApiData,
PredictReturn,
SpaceStatus,
Status,
SubmitReturn,
Expand Down Expand Up @@ -114,7 +115,7 @@ export class Client {
endpoint: string | number,
data: unknown[] | Record<string, unknown>,
event_data?: unknown
) => Promise<SubmitReturn>;
) => Promise<PredictReturn>;
open_stream: () => Promise<void>;
private resolve_config: (endpoint: string) => Promise<Config | undefined>;
private resolve_cookies: () => Promise<void>;
Expand Down
63 changes: 62 additions & 1 deletion client/js/src/helpers/data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import type {
Config,
EndpointInfo,
JsApiData,
DataType
DataType,
Dependency,
ComponentMeta
} from "../types";

export function update_object(
Expand Down Expand Up @@ -118,3 +120,62 @@ export function post_message<Res = any>(
window.parent.postMessage(message, origin, [channel.port2]);
});
}

/**
* Handles the payload by filtering out state inputs and returning an array of resolved payload values.
* We send null values for state inputs to the server, but we don't want to include them in the resolved payload.
*
* @param resolved_payload - The resolved payload values received from the client or the server
* @param dependency - The dependency object.
* @param components - The array of component metadata.
* @param with_null_state - Optional. Specifies whether to include null values for state inputs. Default is false.
* @returns An array of resolved payload values, filtered based on the dependency and component metadata.
*/
export function handle_payload(
resolved_payload: unknown[],
dependency: Dependency,
components: ComponentMeta[],
type: "input" | "output",
with_null_state = false
): unknown[] {
if (type === "input" && !with_null_state) {
throw new Error("Invalid code path. Cannot skip state inputs for input.");
}
// data comes from the server with null state values so we skip
if (type === "output" && with_null_state) {
return resolved_payload;
}

let updated_payload: unknown[] = [];
let payload_index = 0;
for (let i = 0; i < dependency.inputs.length; i++) {
const input_id = dependency.inputs[i];
const component = components.find((c) => c.id === input_id);

if (component?.type === "state") {
// input + with_null_state needs us to fill state with null values
if (with_null_state) {
if (resolved_payload.length === dependency.inputs.length) {
const value = resolved_payload[payload_index];
updated_payload.push(value);
payload_index++;
} else {
updated_payload.push(null);
}
} else {
// this is output & !with_null_state, we skip state inputs
// the server payload always comes with null state values so we move along the payload index
payload_index++;
continue;
}
// input & !with_null_state isn't a case we care about, server needs null
continue;
} else {
const value = resolved_payload[payload_index];
updated_payload.push(value);
payload_index++;
}
}

return updated_payload;
}
1 change: 0 additions & 1 deletion client/js/src/test/api_info.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import { initialise_server } from "./server";
import { transformed_api_info } from "./test_data";

const server = initialise_server();
const IS_NODE = process.env.TEST_MODE === "node";

beforeAll(() => server.listen());
afterEach(() => server.resetHandlers());
Expand Down
135 changes: 134 additions & 1 deletion client/js/src/test/data.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ import {
update_object,
walk_and_store_blobs,
skip_queue,
post_message
post_message,
handle_payload
} from "../helpers/data";
import { NodeBlob } from "../client";
import { config_response, endpoint_info } from "./test_data";
Expand Down Expand Up @@ -276,3 +277,135 @@ describe("post_message", () => {
]);
});
});

describe("handle_payload", () => {
it("should return an input payload with null in place of `state` when with_null_state is true", () => {
const resolved_payload = [2];
const dependency = {
inputs: [1, 2]
};
const components = [
{ id: 1, type: "number" },
{ id: 2, type: "state" }
];
const with_null_state = true;
const result = handle_payload(
resolved_payload,
// @ts-ignore
dependency,
components,
"input",
with_null_state
);
expect(result).toEqual([2, null]);
});
it("should return an input payload with null in place of two `state` components when with_null_state is true", () => {
const resolved_payload = ["hello", "goodbye"];
const dependency = {
inputs: [1, 2, 3, 4]
};
const components = [
{ id: 1, type: "textbox" },
{ id: 2, type: "state" },
{ id: 3, type: "textbox" },
{ id: 4, type: "state" }
];
const with_null_state = true;
const result = handle_payload(
resolved_payload,
// @ts-ignore
dependency,
components,
"input",
with_null_state
);
expect(result).toEqual(["hello", null, "goodbye", null]);
});

it("should return an output payload without the state component value when with_null_state is false", () => {
const resolved_payload = ["hello", null];
const dependency = {
inputs: [2, 3]
};
const components = [
{ id: 2, type: "textbox" },
{ id: 3, type: "state" }
];
const with_null_state = false;
const result = handle_payload(
resolved_payload,
// @ts-ignore
dependency,
components,
"output",
with_null_state
);
expect(result).toEqual(["hello"]);
});

it("should return an ouput payload without the two state component values when with_null_state is false", () => {
const resolved_payload = ["hello", null, "world", null];
const dependency = {
inputs: [2, 3, 4, 5]
};
const components = [
{ id: 2, type: "textbox" },
{ id: 3, type: "state" },
{ id: 4, type: "textbox" },
{ id: 5, type: "state" }
];
const with_null_state = false;
const result = handle_payload(
resolved_payload,
// @ts-ignore
dependency,
components,
"output",
with_null_state
);
expect(result).toEqual(["hello", "world"]);
});

it("should return an ouput payload with the two state component values when with_null_state is true", () => {
const resolved_payload = ["hello", null, "world", null];
const dependency = {
inputs: [2, 3, 4, 5]
};
const components = [
{ id: 2, type: "textbox" },
{ id: 3, type: "state" },
{ id: 4, type: "textbox" },
{ id: 5, type: "state" }
];
const with_null_state = true;
const result = handle_payload(
resolved_payload,
// @ts-ignore
dependency,
components,
"output",
with_null_state
);
expect(result).toEqual(["hello", null, "world", null]);
});

it("should return the same payload where no state components are defined", () => {
const resolved_payload = ["hello", "world"];
const dependency = {
inputs: [2, 3]
};
const components = [
{ id: 2, type: "textbox" },
{ id: 3, type: "textbox" }
];
const with_null_state = true;
const result = handle_payload(
resolved_payload,
// @ts-ignore
dependency,
components,
with_null_state
);
expect(result).toEqual(["hello", "world"]);
});
});
54 changes: 52 additions & 2 deletions client/js/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// API Data Types

import { hardware_types } from "./helpers/spaces";
import type { SvelteComponent } from "svelte";
import type { ComponentType } from "svelte";

export interface ApiData {
label: string;
Expand Down Expand Up @@ -62,7 +64,7 @@ export type PredictFunction = (
endpoint: string | number,
data: unknown[] | Record<string, unknown>,
event_data?: unknown
) => Promise<SubmitReturn>;
) => Promise<PredictReturn>;

// Event and Submission Types

Expand Down Expand Up @@ -90,6 +92,14 @@ export type SubmitReturn = {
destroy: () => void;
};

export type PredictReturn = {
type: EventType;
time: Date;
data: unknown;
endpoint: string;
fn_index: number;
};

// Space Status Types

export type SpaceStatus = SpaceStatusNormal | SpaceStatusError;
Expand Down Expand Up @@ -128,7 +138,7 @@ export interface Config {
analytics_enabled: boolean;
connect_heartbeat: boolean;
auth_message: string;
components: any[];
components: ComponentMeta[];
css: string | null;
js: string | null;
head: string | null;
Expand All @@ -153,6 +163,45 @@ export interface Config {
max_file_size?: number;
}

// todo: DRY up types
export interface ComponentMeta {
type: string;
id: number;
has_modes: boolean;
props: SharedProps;
instance: SvelteComponent;
component: ComponentType<SvelteComponent>;
documentation?: Documentation;
children?: ComponentMeta[];
parent?: ComponentMeta;
value?: any;
component_class_id: string;
key: string | number | null;
rendered_in?: number;
}

interface SharedProps {
elem_id?: string;
elem_classes?: string[];
components?: string[];
server_fns?: string[];
interactive: boolean;
[key: string]: unknown;
root_url?: string;
}

export interface Documentation {
type?: TypeDescription;
description?: TypeDescription;
example_data?: string;
}

interface TypeDescription {
input_payload?: string;
response_object?: string;
payload?: string;
}

export interface Dependency {
id: number;
targets: [number, string][];
Expand Down Expand Up @@ -218,6 +267,7 @@ export interface ClientOptions {
hf_token?: `hf_${string}`;
status_callback?: SpaceStatusCallback | null;
auth?: [string, string] | null;
with_null_state?: boolean;
}

export interface FileData {
Expand Down
8 changes: 4 additions & 4 deletions client/js/src/utils/predict.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import { Client } from "../client";
import type { Dependency, SubmitReturn } from "../types";
import type { Dependency, PredictReturn } from "../types";

export async function predict(
this: Client,
endpoint: string | number,
data: unknown[] | Record<string, unknown>
): Promise<SubmitReturn> {
): Promise<PredictReturn> {
let data_returned = false;
let status_complete = false;
let dependency: Dependency;
Expand Down Expand Up @@ -38,7 +38,7 @@ export async function predict(
// if complete message comes before data, resolve here
if (status_complete) {
app.destroy();
resolve(d as SubmitReturn);
resolve(d as PredictReturn);
}
data_returned = true;
result = d;
Expand All @@ -50,7 +50,7 @@ export async function predict(
// if complete message comes after data, resolve here
if (data_returned) {
app.destroy();
resolve(result as SubmitReturn);
resolve(result as PredictReturn);
}
}
});
Expand Down
Loading

0 comments on commit 63d36fb

Please sign in to comment.