Skip to content

Commit

Permalink
Improve file handling in JS Client (#8462)
Browse files Browse the repository at this point in the history
* add handler for URLs, Blobs and Files

* add changeset

* remove NodeBlob

* add local file handling

* handle buffers

* add test

* type tweaks

* fix node test with file

* test

* fix test

* handle nested files

* env tweaks

* tweak

* fix test

* use file instead of blob

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
hannahblair and gradio-pr-bot authored Jun 6, 2024
1 parent 8c18114 commit 6447dfa
Show file tree
Hide file tree
Showing 9 changed files with 256 additions and 47 deletions.
6 changes: 6 additions & 0 deletions .changeset/violet-pans-doubt.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@gradio/client": patch
"gradio": patch
---

fix:Improve file handling in JS Client
8 changes: 0 additions & 8 deletions client/js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,6 @@ import { check_space_status } from "./helpers/spaces";
import { open_stream } from "./utils/stream";
import { API_INFO_ERROR_MSG, CONFIG_ERROR_MSG } from "./constants";

export class NodeBlob extends Blob {
constructor(blobParts?: BlobPart[], options?: BlobPropertyBag) {
super(blobParts, options);
}
}

export class Client {
app_reference: string;
options: ClientOptions;
Expand Down Expand Up @@ -141,8 +135,6 @@ export class Client {
!global.WebSocket
) {
const ws = await import("ws");
// @ts-ignore
NodeBlob = (await import("node:buffer")).Blob;
global.WebSocket = ws.WebSocket as unknown as typeof WebSocket;
}

Expand Down
4 changes: 4 additions & 0 deletions client/js/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,7 @@ export const UNAUTHORIZED_MSG = "Not authorized to access this space. ";
export const INVALID_CREDENTIALS_MSG = "Invalid credentials. Could not login. ";
export const MISSING_CREDENTIALS_MSG =
"Login credentials are required to access this space.";
export const NODEJS_FS_ERROR_MSG =
"File system access is only available in Node.js environments";
export const ROOT_URL_ERROR_MSG = "Root URL not found in client config";
export const FILE_PROCESSING_ERROR_MSG = "Error uploading file";
77 changes: 65 additions & 12 deletions client/js/src/helpers/data.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import { NodeBlob } from "../client";
import type {
ApiData,
BlobRef,
Config,
EndpointInfo,
JsApiData,
DataType,
Dependency,
ComponentMeta
import {
type ApiData,
type BlobRef,
type Config,
type EndpointInfo,
type JsApiData,
type DataType,
Command,
type Dependency,
type ComponentMeta
} from "../types";
import { FileData } from "../upload";

const is_node =
typeof process !== "undefined" && process.versions && process.versions.node;

export function update_object(
object: { [x: string]: any },
Expand Down Expand Up @@ -66,11 +70,10 @@ export async function walk_and_store_blobs(
(globalThis.Buffer && data instanceof globalThis.Buffer) ||
data instanceof Blob
) {
const is_image = type === "Image";
return [
{
path: path,
blob: is_image ? false : new NodeBlob([data]),
blob: new Blob([data]),
type
}
];
Expand Down Expand Up @@ -121,6 +124,56 @@ export function post_message<Res = any>(
});
}

export function handle_file(
file_or_url: File | string | Blob | Buffer
): FileData | Blob | Command {
if (typeof file_or_url === "string") {
if (
file_or_url.startsWith("http://") ||
file_or_url.startsWith("https://")
) {
return {
path: file_or_url,
url: file_or_url,
orig_name: file_or_url.split("/").pop() ?? "unknown",
meta: { _type: "gradio.FileData" }
};
}

if (is_node) {
// Handle local file paths
return new Command("upload_file", {
path: file_or_url,
name: file_or_url,
orig_path: file_or_url
});
}
} else if (typeof File !== "undefined" && file_or_url instanceof File) {
return {
path: file_or_url instanceof File ? file_or_url.name : "blob",
orig_name: file_or_url instanceof File ? file_or_url.name : "unknown",
// @ts-ignore
blob: file_or_url instanceof File ? file_or_url : new Blob([file_or_url]),
size:
file_or_url instanceof Blob
? file_or_url.size
: Buffer.byteLength(file_or_url as Buffer),
mime_type:
file_or_url instanceof File
? file_or_url.type
: "application/octet-stream", // Default MIME type for buffers
meta: { _type: "gradio.FileData" }
};
} else if (file_or_url instanceof Buffer) {
return new Blob([file_or_url]);
} else if (file_or_url instanceof Blob) {
return file_or_url;
}
throw new Error(
"Invalid input: must be a URL, File, Blob, or Buffer object."
);
}

/**
* Handles the payload by filtering out state inputs and returning an array of resolved payload values.
* We send null values for state inputs to the server, but we don't want to include them in the resolved payload.
Expand Down
1 change: 1 addition & 0 deletions client/js/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export { predict } from "./utils/predict";
export { submit } from "./utils/submit";
export { upload_files } from "./utils/upload_files";
export { FileData, upload, prepare_files } from "./upload";
export { handle_file } from "./helpers/data";

export type {
SpaceStatus,
Expand Down
92 changes: 67 additions & 25 deletions client/js/src/test/data.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,22 @@ import {
walk_and_store_blobs,
skip_queue,
post_message,
handle_file,
handle_payload
} from "../helpers/data";
import { NodeBlob } from "../client";
import { config_response, endpoint_info } from "./test_data";
import { BlobRef } from "../types";
import { BlobRef, Command } from "../types";
import { FileData } from "../upload";

const IS_NODE = process.env.TEST_MODE === "node";

describe("walk_and_store_blobs", () => {
it("should convert a Buffer to a Blob", async () => {
const buffer = Buffer.from("test data");
const parts = await walk_and_store_blobs(buffer, "text");

expect(parts).toHaveLength(1);
expect(parts[0].blob).toBeInstanceOf(NodeBlob);
expect(parts[0].blob).toBeInstanceOf(Blob);
});

it("should return a Blob when passed a Blob", async () => {
Expand All @@ -29,27 +32,15 @@ describe("walk_and_store_blobs", () => {
endpoint_info
);

expect(parts[0].blob).toBeInstanceOf(NodeBlob);
});

it("should return blob: false when passed an image", async () => {
const blob = new Blob([]);
const parts = await walk_and_store_blobs(
blob,
"Image",
[],
true,
endpoint_info
);
expect(parts[0].blob).toBe(false);
expect(parts[0].blob).toBeInstanceOf(Blob);
});

it("should handle arrays", async () => {
const image = new Blob([]);
const parts = await walk_and_store_blobs([image]);

expect(parts).toHaveLength(1);
expect(parts[0].blob).toBeInstanceOf(NodeBlob);
expect(parts[0].blob).toBeInstanceOf(Blob);
expect(parts[0].path).toEqual(["0"]);
});

Expand All @@ -58,7 +49,7 @@ describe("walk_and_store_blobs", () => {
const parts = await walk_and_store_blobs({ a: { b: { data: { image } } } });

expect(parts).toHaveLength(1);
expect(parts[0].blob).toBeInstanceOf(NodeBlob);
expect(parts[0].blob).toBeInstanceOf(Blob);
expect(parts[0].path).toEqual(["a", "b", "data", "image"]);
});

Expand All @@ -80,7 +71,7 @@ describe("walk_and_store_blobs", () => {
]
});

expect(parts[0].blob).toBeInstanceOf(NodeBlob);
expect(parts[0].blob).toBeInstanceOf(Blob);
});

it("should handle deep structures with arrays (with equality check)", async () => {
Expand All @@ -104,8 +95,8 @@ describe("walk_and_store_blobs", () => {
let ref = obj;
path.forEach((p) => (ref = ref[p]));

// since ref is a Blob and blob is a NodeBlob, we deep equal check the two buffers instead
if (ref instanceof Blob && blob instanceof NodeBlob) {
// since ref is a Blob and blob is a Blob, we deep equal check the two buffers instead
if (ref instanceof Blob && blob instanceof Blob) {
const refBuffer = Buffer.from(await ref.arrayBuffer());
const blobBuffer = Buffer.from(await blob.arrayBuffer());
return refBuffer.equals(blobBuffer);
Expand All @@ -114,7 +105,7 @@ describe("walk_and_store_blobs", () => {
return ref === blob;
}

expect(parts[0].blob).toBeInstanceOf(NodeBlob);
expect(parts[0].blob).toBeInstanceOf(Blob);
expect(map_path(obj, parts)).toBeTruthy();
});

Expand All @@ -123,7 +114,7 @@ describe("walk_and_store_blobs", () => {
const parts = await walk_and_store_blobs(buffer, undefined, ["blob"]);

expect(parts).toHaveLength(1);
expect(parts[0].blob).toBeInstanceOf(NodeBlob);
expect(parts[0].blob).toBeInstanceOf(Blob);
expect(parts[0].path).toEqual(["blob"]);
});

Expand All @@ -133,7 +124,7 @@ describe("walk_and_store_blobs", () => {

expect(parts).toHaveLength(1);
expect(parts[0].path).toEqual([]);
expect(parts[0].blob).toBeInstanceOf(NodeBlob);
expect(parts[0].blob).toBeInstanceOf(Blob);
});

it("should convert an object with deep structures to BlobRefs", async () => {
Expand All @@ -150,7 +141,7 @@ describe("walk_and_store_blobs", () => {

expect(parts).toHaveLength(1);
expect(parts[0].path).toEqual(["a", "b", "data", "image"]);
expect(parts[0].blob).toBeInstanceOf(NodeBlob);
expect(parts[0].blob).toBeInstanceOf(Blob);
});
});
describe("update_object", () => {
Expand Down Expand Up @@ -278,6 +269,57 @@ describe("post_message", () => {
});
});

describe("handle_file", () => {
it("should handle a Blob object and return the blob", () => {
const blob = new Blob(["test data"], { type: "image/png" });
const result = handle_file(blob) as FileData;

expect(result).toBe(blob);
});

it("should handle a Buffer object and return it as a blob", () => {
const buffer = Buffer.from("test data");
const result = handle_file(buffer) as FileData;
expect(result).toBeInstanceOf(Blob);
});
it("should handle a local file path and return a Command object", () => {
const file_path = "./owl.png";
const result = handle_file(file_path) as Command;
expect(result).toBeInstanceOf(Command);
expect(result).toEqual({
type: "command",
command: "upload_file",
meta: { path: "./owl.png", name: "./owl.png", orig_path: "./owl.png" },
fileData: undefined
});
});

it("should handle a File object and return it as FileData", () => {
if (IS_NODE) {
return;
}
const file = new File(["test image"], "test.png", { type: "image/png" });
const result = handle_file(file) as FileData;
expect(result.path).toBe("test.png");
expect(result.orig_name).toBe("test.png");
expect(result.blob).toBeInstanceOf(Blob);
expect(result.size).toBe(file.size);
expect(result.mime_type).toBe("image/png");
expect(result.meta).toEqual({ _type: "gradio.FileData" });
});

it("should throw an error for invalid input", () => {
const invalid_input = 123;

expect(() => {
// @ts-ignore
handle_file(invalid_input);
}).toThrowError(
"Invalid input: must be a URL, File, Blob, or Buffer object."
);
});
});

describe("handle_payload", () => {
it("should return an input payload with null in place of `state` when with_null_state is true", () => {
const resolved_payload = [2];
Expand Down
21 changes: 21 additions & 0 deletions client/js/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,27 @@ export interface BlobRef {

export type DataType = string | Buffer | Record<string, any> | any[];

// custom class used for uploading local files
export class Command {
type: string;
command: string;
meta: {
path: string;
name: string;
orig_path: string;
};
fileData?: FileData;

constructor(
command: string,
meta: { path: string; name: string; orig_path: string }
) {
this.type = "command";
this.command = command;
this.meta = meta;
}
}

// Function Signature Types

export type SubmitFunction = (
Expand Down
Loading

0 comments on commit 6447dfa

Please sign in to comment.