diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 73ad040d64..e10d21bb41 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -18,7 +18,7 @@ It's not a hard requirement, but please consider using an icon from [Gitmoji](ht ## Tests -If you want to run only specific tests, you can do `pnpm test -- -t "test name"` +If you want to run only specific tests, you can do `pnpm test "test name"` ## Adding a package diff --git a/packages/inference/src/lib/InferenceOutputError.ts b/packages/inference/src/lib/InferenceOutputError.ts index 0765b99944..07090c5144 100644 --- a/packages/inference/src/lib/InferenceOutputError.ts +++ b/packages/inference/src/lib/InferenceOutputError.ts @@ -1,7 +1,10 @@ export class InferenceOutputError extends TypeError { - constructor(message: string) { + constructor(err: unknown) { super( - `Invalid inference output: ${message}. Use the 'request' method with the same parameters to do a custom call with no type checking.` + `Invalid inference output: ${ + err instanceof Error ? err.message : String(err) + }. Use the 'request' method with the same parameters to do a custom call with no type checking.`, + err instanceof Error ? { cause: err } : undefined ); this.name = "InferenceOutputError"; } diff --git a/packages/inference/src/lib/validateOutput.spec.ts b/packages/inference/src/lib/validateOutput.spec.ts new file mode 100644 index 0000000000..50e004655b --- /dev/null +++ b/packages/inference/src/lib/validateOutput.spec.ts @@ -0,0 +1,84 @@ +import { describe, expect, it } from "vitest"; +import { z } from "./validateOutput"; + +describe("validateOutput", () => { + it("validates simple types", () => { + expect(z.string().parse("foo")).toBe("foo"); + expect(z.number().parse(42)).toBe(42); + expect(z.blob().parse(new Blob())).toBeInstanceOf(Blob); + }); + + it("errors on simple types", () => { + expect(() => z.string().parse(42)).toThrow(/Expected string/); + expect(() => z.number().parse("foo")).toThrow(/Expected number/); + expect(() => z.blob().parse(42)).toThrow(/Expected Blob/); + }); + it("validates arrays", () => { + expect(z.array(z.string()).parse(["foo"])).toEqual(["foo"]); + expect(z.array(z.number()).parse([42])).toEqual([42]); + expect(z.array(z.blob()).parse([new Blob()])).toEqual([new Blob()]); + }); + it("errors on arrays", () => { + expect(() => z.array(z.string()).parse([42])).toThrow(/Expected Array/); + expect(() => z.array(z.number()).parse(["foo"])).toThrow(/Expected Array/); + expect(() => z.array(z.blob()).parse([42])).toThrow(/Expected Array/); + }); + it("validates objects", () => { + expect(z.object({ foo: z.string() }).parse({ foo: "foo" })).toEqual({ foo: "foo" }); + expect(z.object({ foo: z.number() }).parse({ foo: 42 })).toEqual({ foo: 42 }); + expect(z.object({ foo: z.blob() }).parse({ foo: new Blob() })).toEqual({ foo: new Blob() }); + expect(z.object({ foo: z.string(), bar: z.number() }).parse({ foo: "foo", bar: 42 })).toEqual({ + foo: "foo", + bar: 42, + }); + }); + it("errors on objects", () => { + expect(() => z.object({ foo: z.string() }).parse({ foo: 42 })).toThrow(/Expected { foo: string }/); + expect(() => z.object({ foo: z.number() }).parse({ foo: "foo" })).toThrow(/Expected { foo: number }/); + expect(() => z.object({ foo: z.blob() }).parse({ foo: 42 })).toThrow(/Expected { foo: Blob }/); + expect(() => z.object({ foo: z.string(), bar: z.number() }).parse({ foo: "foo", bar: "bar" })).toThrow( + /Expected { foo: string, bar: number }/ + ); + }); + it("validates unions", () => { + expect(z.or(z.string(), z.number()).parse("foo")).toBe("foo"); + expect(z.or(z.blob(), z.string(), z.number()).parse(42)).toBe(42); + }); + it("errors on unions", () => { + expect(() => z.or(z.string(), z.number()).parse(new Blob())).toThrow(/Expected string | number/); + }); + it("validates a complex object", () => { + expect( + z + .object({ + foo: z.string(), + bar: z.array(z.number()), + baz: z.object({ a: z.string(), b: z.number() }), + }) + .parse({ + foo: "foo", + bar: [42], + baz: { a: "a", b: 42 }, + }) + ).toEqual({ + foo: "foo", + bar: [42], + baz: { a: "a", b: 42 }, + }); + }); + it("errors on a complex object", () => { + expect(() => + z + .object({ + foo: z.string(), + bar: z.array(z.number()), + baz: z.object({ a: z.string(), b: z.number() }), + }) + .parse({ + foo: "foo", + bar: [42], + baz: { a: "a", b: "b" }, + }) + ).toThrow(/Expected { foo: string, bar: Array, baz: { a: string, b: number } }/); + }); +}); diff --git a/packages/inference/src/lib/validateOutput.ts b/packages/inference/src/lib/validateOutput.ts new file mode 100644 index 0000000000..3e07024c4c --- /dev/null +++ b/packages/inference/src/lib/validateOutput.ts @@ -0,0 +1,155 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +/** + * Heavily inspired by zod + * + * Re-created to avoid extra dependencies + */ + +import { InferenceOutputError } from "./InferenceOutputError"; + +interface Parser { + parse: (value: any) => T; + toString: () => string; +} + +export type Infer = ReturnType; + +export const z = { + array(items: T): Parser[]> { + return { + parse: (value: unknown) => { + if (!Array.isArray(value)) { + throw new Error("Expected " + z.array(items).toString()); + } + try { + return value.map((val) => items.parse(val)); + } catch (err) { + throw new Error("Expected " + z.array(items).toString(), { cause: err }); + } + }, + toString(): string { + return `Array<${items.toString()}>`; + }, + }; + }, + first(items: T): Parser> { + return { + parse: (value: unknown) => { + if (!Array.isArray(value) || value.length === 0) { + throw new Error("Expected " + z.first(items).toString()); + } + try { + return items.parse(value[0]); + } catch (err) { + throw new Error("Expected " + z.first(items).toString(), { cause: err }); + } + }, + toString(): string { + return `[${items.toString()}]`; + }, + }; + }, + or: (...items: T): Parser> => ({ + parse: (value: unknown): ReturnType => { + const errors: Error[] = []; + for (const item of items) { + try { + return item.parse(value); + } catch (err) { + errors.push(err as Error); + } + } + throw new Error("Expected " + z.or(...items).toString(), { cause: errors }); + }, + toString(): string { + return items.map((item) => item.toString()).join(" | "); + }, + }), + object>(item: T): Parser<{ [key in keyof T]: Infer }> { + return { + parse: (value: unknown) => { + if (typeof value !== "object" || value === null || Array.isArray(value)) { + throw new Error("Expected " + z.object(item).toString()); + } + try { + return Object.fromEntries( + Object.entries(item).map(([key, val]) => [key, val.parse((value as any)[key])]) + ) as { + [key in keyof T]: Infer; + }; + } catch (err) { + throw new Error("Expected " + z.object(item).toString(), { cause: err }); + } + }, + toString(): string { + return `{ ${Object.entries(item) + .map(([key, val]) => `${key}: ${val.toString()}`) + .join(", ")} }`; + }, + }; + }, + string(): Parser { + return { + parse: (value: unknown): string => { + if (typeof value !== "string") { + throw new Error("Expected " + z.string().toString()); + } + return value; + }, + toString(): string { + return "string"; + }, + }; + }, + number(): Parser { + return { + parse: (value: unknown): number => { + if (typeof value !== "number") { + throw new Error("Expected " + z.number().toString()); + } + return value; + }, + toString(): string { + return "number"; + }, + }; + }, + blob(): Parser { + return { + parse: (value: unknown): Blob => { + if (!(value instanceof Blob)) { + throw new Error("Expected " + z.blob().toString()); + } + return value; + }, + toString(): string { + return "Blob"; + }, + }; + }, + optional(item: T): Parser | undefined> { + return { + parse: (value: unknown): ReturnType | undefined => { + if (value === undefined) { + return undefined; + } + try { + return item.parse(value); + } catch (err) { + throw new Error("Expected " + z.optional(item).toString(), { cause: err }); + } + }, + toString(): string { + return `${item.toString()} | undefined`; + }, + }; + }, +}; + +export function validateOutput(value: unknown, schema: { parse: (value: any) => T }): T { + try { + return schema.parse(value); + } catch (err) { + throw new InferenceOutputError(err); + } +} diff --git a/packages/inference/src/tasks/audio/audioClassification.ts b/packages/inference/src/tasks/audio/audioClassification.ts index 5d7e274e5a..9f2caa7190 100644 --- a/packages/inference/src/tasks/audio/audioClassification.ts +++ b/packages/inference/src/tasks/audio/audioClassification.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -35,10 +35,5 @@ export async function audioClassification( ...options, taskHint: "audio-classification", }); - const isValidOutput = - Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{label: string, score: number}>"); - } - return res; + return validateOutput(res, z.array(z.object({ label: z.string(), score: z.number() }))); } diff --git a/packages/inference/src/tasks/audio/audioToAudio.ts b/packages/inference/src/tasks/audio/audioToAudio.ts index c339cdf61a..a6efb0bc8c 100644 --- a/packages/inference/src/tasks/audio/audioToAudio.ts +++ b/packages/inference/src/tasks/audio/audioToAudio.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -37,13 +37,14 @@ export async function audioToAudio(args: AudioToAudioArgs, options?: Options): P ...options, taskHint: "audio-to-audio", }); - const isValidOutput = - Array.isArray(res) && - res.every( - (x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string" - ); - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{label: string, blob: string, content-type: string}>"); - } - return res; + return validateOutput( + res, + z.array( + z.object({ + label: z.string(), + blob: z.string(), + "content-type": z.string(), + }) + ) + ); } diff --git a/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts b/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts index 600d5b6c74..f29d354ebf 100644 --- a/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts +++ b/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -28,9 +28,5 @@ export async function automaticSpeechRecognition( ...options, taskHint: "automatic-speech-recognition", }); - const isValidOutput = typeof res?.text === "string"; - if (!isValidOutput) { - throw new InferenceOutputError("Expected {text: string}"); - } - return res; + return validateOutput(res, z.object({ text: z.string() })); } diff --git a/packages/inference/src/tasks/audio/textToSpeech.ts b/packages/inference/src/tasks/audio/textToSpeech.ts index 3c466110f5..797b99f283 100644 --- a/packages/inference/src/tasks/audio/textToSpeech.ts +++ b/packages/inference/src/tasks/audio/textToSpeech.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -20,9 +20,5 @@ export async function textToSpeech(args: TextToSpeechArgs, options?: Options): P ...options, taskHint: "text-to-speech", }); - const isValidOutput = res && res instanceof Blob; - if (!isValidOutput) { - throw new InferenceOutputError("Expected Blob"); - } - return res; + return validateOutput(res, z.blob()); } diff --git a/packages/inference/src/tasks/cv/imageClassification.ts b/packages/inference/src/tasks/cv/imageClassification.ts index 2ae7258704..1af10e6337 100644 --- a/packages/inference/src/tasks/cv/imageClassification.ts +++ b/packages/inference/src/tasks/cv/imageClassification.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -34,10 +34,5 @@ export async function imageClassification( ...options, taskHint: "image-classification", }); - const isValidOutput = - Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{label: string, score: number}>"); - } - return res; + return validateOutput(res, z.array(z.object({ label: z.string(), score: z.number() }))); } diff --git a/packages/inference/src/tasks/cv/imageSegmentation.ts b/packages/inference/src/tasks/cv/imageSegmentation.ts index 171f065260..744fe67891 100644 --- a/packages/inference/src/tasks/cv/imageSegmentation.ts +++ b/packages/inference/src/tasks/cv/imageSegmentation.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -38,11 +38,5 @@ export async function imageSegmentation( ...options, taskHint: "image-segmentation", }); - const isValidOutput = - Array.isArray(res) && - res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>"); - } - return res; + return validateOutput(res, z.array(z.object({ label: z.string(), mask: z.string(), score: z.number() }))); } diff --git a/packages/inference/src/tasks/cv/imageToImage.ts b/packages/inference/src/tasks/cv/imageToImage.ts index c96dce363f..b5973d60fe 100644 --- a/packages/inference/src/tasks/cv/imageToImage.ts +++ b/packages/inference/src/tasks/cv/imageToImage.ts @@ -1,7 +1,7 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options, RequestArgs } from "../../types"; import { request } from "../custom/request"; import { base64FromBytes } from "../../../../shared"; +import { validateOutput, z } from "../../lib/validateOutput"; export type ImageToImageArgs = BaseArgs & { /** @@ -78,9 +78,5 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P ...options, taskHint: "image-to-image", }); - const isValidOutput = res && res instanceof Blob; - if (!isValidOutput) { - throw new InferenceOutputError("Expected Blob"); - } - return res; + return validateOutput(res, z.blob()); } diff --git a/packages/inference/src/tasks/cv/imageToText.ts b/packages/inference/src/tasks/cv/imageToText.ts index 9dd3ae8c20..60c34faca4 100644 --- a/packages/inference/src/tasks/cv/imageToText.ts +++ b/packages/inference/src/tasks/cv/imageToText.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -20,16 +20,10 @@ export interface ImageToTextOutput { * This task reads some image input and outputs the text caption. */ export async function imageToText(args: ImageToTextArgs, options?: Options): Promise { - const res = ( - await request<[ImageToTextOutput]>(args, { - ...options, - taskHint: "image-to-text", - }) - )?.[0]; + const res = await request<[ImageToTextOutput]>(args, { + ...options, + taskHint: "image-to-text", + }); - if (typeof res?.generated_text !== "string") { - throw new InferenceOutputError("Expected {generated_text: string}"); - } - - return res; + return validateOutput(res, z.first(z.object({ generated_text: z.string() }))); } diff --git a/packages/inference/src/tasks/cv/objectDetection.ts b/packages/inference/src/tasks/cv/objectDetection.ts index 5bec721156..b40f346280 100644 --- a/packages/inference/src/tasks/cv/objectDetection.ts +++ b/packages/inference/src/tasks/cv/objectDetection.ts @@ -1,6 +1,6 @@ import { request } from "../custom/request"; import type { BaseArgs, Options } from "../../types"; -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; export type ObjectDetectionArgs = BaseArgs & { /** @@ -41,21 +41,15 @@ export async function objectDetection(args: ObjectDetectionArgs, options?: Optio ...options, taskHint: "object-detection", }); - const isValidOutput = - Array.isArray(res) && - res.every( - (x) => - typeof x.label === "string" && - typeof x.score === "number" && - typeof x.box.xmin === "number" && - typeof x.box.ymin === "number" && - typeof x.box.xmax === "number" && - typeof x.box.ymax === "number" - ); - if (!isValidOutput) { - throw new InferenceOutputError( - "Expected Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>" - ); - } - return res; + + return validateOutput( + res, + z.array( + z.object({ + label: z.string(), + score: z.number(), + box: z.object({ xmin: z.number(), ymin: z.number(), xmax: z.number(), ymax: z.number() }), + }) + ) + ); } diff --git a/packages/inference/src/tasks/cv/textToImage.ts b/packages/inference/src/tasks/cv/textToImage.ts index 677b3bc5c7..2aba5dd45c 100644 --- a/packages/inference/src/tasks/cv/textToImage.ts +++ b/packages/inference/src/tasks/cv/textToImage.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -43,9 +43,5 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro ...options, taskHint: "text-to-image", }); - const isValidOutput = res && res instanceof Blob; - if (!isValidOutput) { - throw new InferenceOutputError("Expected Blob"); - } - return res; + return validateOutput(res, z.blob()); } diff --git a/packages/inference/src/tasks/cv/zeroShotImageClassification.ts b/packages/inference/src/tasks/cv/zeroShotImageClassification.ts index cb2b8f0d6a..ab8aadd051 100644 --- a/packages/inference/src/tasks/cv/zeroShotImageClassification.ts +++ b/packages/inference/src/tasks/cv/zeroShotImageClassification.ts @@ -1,8 +1,8 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; import type { RequestArgs } from "../../types"; import { base64FromBytes } from "../../../../shared"; +import { validateOutput, z } from "../../lib/validateOutput"; export type ZeroShotImageClassificationArgs = BaseArgs & { inputs: { @@ -49,10 +49,5 @@ export async function zeroShotImageClassification( ...options, taskHint: "zero-shot-image-classification", }); - const isValidOutput = - Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{label: string, score: number}>"); - } - return res; + return validateOutput(res, z.array(z.object({ label: z.string(), score: z.number() }))); } diff --git a/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts b/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts index f5f993137f..3ddab8c4c5 100644 --- a/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts +++ b/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts @@ -1,9 +1,9 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; import type { RequestArgs } from "../../types"; import { base64FromBytes } from "../../../../shared"; import { toArray } from "../../utils/toArray"; +import { validateOutput, z } from "../../lib/validateOutput"; export type DocumentQuestionAnsweringArgs = BaseArgs & { inputs: { @@ -60,14 +60,17 @@ export async function documentQuestionAnswering( ...options, taskHint: "document-question-answering", }) - )?.[0]; - const isValidOutput = - typeof res?.answer === "string" && - (typeof res.end === "number" || typeof res.end === "undefined") && - (typeof res.score === "number" || typeof res.score === "undefined") && - (typeof res.start === "number" || typeof res.start === "undefined"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>"); - } - return res; + ); + + return validateOutput( + res, + z.first( + z.object({ + answer: z.string(), + end: z.optional(z.number()), + score: z.optional(z.number()), + start: z.optional(z.number()), + }) + ) + ); } diff --git a/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts b/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts index ec45a0fa04..e88249fc22 100644 --- a/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts +++ b/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts @@ -1,7 +1,7 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options, RequestArgs } from "../../types"; import { request } from "../custom/request"; import { base64FromBytes } from "../../../../shared"; +import { validateOutput, z } from "../../lib/validateOutput"; export type VisualQuestionAnsweringArgs = BaseArgs & { inputs: { @@ -45,15 +45,9 @@ export async function visualQuestionAnswering( ), }, } as RequestArgs; - const res = ( - await request<[VisualQuestionAnsweringOutput]>(reqArgs, { - ...options, - taskHint: "visual-question-answering", - }) - )?.[0]; - const isValidOutput = typeof res?.answer === "string" && typeof res.score === "number"; - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{answer: string, score: number}>"); - } - return res; + const res = await request<[VisualQuestionAnsweringOutput]>(reqArgs, { + ...options, + taskHint: "visual-question-answering", + }); + return validateOutput(res, z.first(z.object({ answer: z.string(), score: z.number() }))); } diff --git a/packages/inference/src/tasks/nlp/conversational.ts b/packages/inference/src/tasks/nlp/conversational.ts index e426e7ff90..dc233230e0 100644 --- a/packages/inference/src/tasks/nlp/conversational.ts +++ b/packages/inference/src/tasks/nlp/conversational.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -55,7 +55,7 @@ export interface ConversationalOutput { past_user_inputs: string[]; }; generated_text: string; - warnings: string[]; + warnings?: string[]; } /** @@ -64,18 +64,13 @@ export interface ConversationalOutput { */ export async function conversational(args: ConversationalArgs, options?: Options): Promise { const res = await request(args, { ...options, taskHint: "conversational" }); - const isValidOutput = - Array.isArray(res.conversation.generated_responses) && - res.conversation.generated_responses.every((x) => typeof x === "string") && - Array.isArray(res.conversation.past_user_inputs) && - res.conversation.past_user_inputs.every((x) => typeof x === "string") && - typeof res.generated_text === "string" && - (typeof res.warnings === "undefined" || - (Array.isArray(res.warnings) && res.warnings.every((x) => typeof x === "string"))); - if (!isValidOutput) { - throw new InferenceOutputError( - "Expected {conversation: {generated_responses: string[], past_user_inputs: string[]}, generated_text: string, warnings: string[]}" - ); - } - return res; + + return validateOutput( + res, + z.object({ + conversation: z.object({ generated_responses: z.array(z.string()), past_user_inputs: z.array(z.string()) }), + generated_text: z.string(), + warnings: z.optional(z.array(z.string())), + }) + ); } diff --git a/packages/inference/src/tasks/nlp/featureExtraction.ts b/packages/inference/src/tasks/nlp/featureExtraction.ts index fef6ccc614..4381f61703 100644 --- a/packages/inference/src/tasks/nlp/featureExtraction.ts +++ b/packages/inference/src/tasks/nlp/featureExtraction.ts @@ -1,5 +1,5 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; import { getDefaultTask } from "../../lib/getDefaultTask"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -16,7 +16,7 @@ export type FeatureExtractionArgs = BaseArgs & { /** * Returned values are a multidimensional array of floats (dimension depending on if you sent a string or a list of string, and if the automatic reduction, usually mean_pooling for instance was applied for you or not. This should be explained on the model's README). */ -export type FeatureExtractionOutput = (number | number[] | number[][])[]; +export type FeatureExtractionOutput = (number | number[] | number[][] | number[][][])[]; /** * This task reads some text and outputs raw float values, that are usually consumed as part of a semantic database/semantic search. @@ -32,21 +32,8 @@ export async function featureExtraction( taskHint: "feature-extraction", ...(defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" }), }); - let isValidOutput = true; - - const isNumArrayRec = (arr: unknown[], maxDepth: number, curDepth = 0): boolean => { - if (curDepth > maxDepth) return false; - if (arr.every((x) => Array.isArray(x))) { - return arr.every((x) => isNumArrayRec(x as unknown[], maxDepth, curDepth + 1)); - } else { - return arr.every((x) => typeof x === "number"); - } - }; - - isValidOutput = Array.isArray(res) && isNumArrayRec(res, 3, 0); - - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array"); - } - return res; + return validateOutput( + res, + z.array(z.or(z.number(), z.array(z.number()), z.array(z.array(z.number())), z.array(z.array(z.array(z.number()))))) + ); } diff --git a/packages/inference/src/tasks/nlp/fillMask.ts b/packages/inference/src/tasks/nlp/fillMask.ts index b8a2af1286..f74c7cc57f 100644 --- a/packages/inference/src/tasks/nlp/fillMask.ts +++ b/packages/inference/src/tasks/nlp/fillMask.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -33,19 +33,8 @@ export async function fillMask(args: FillMaskArgs, options?: Options): Promise - typeof x.score === "number" && - typeof x.sequence === "string" && - typeof x.token === "number" && - typeof x.token_str === "string" - ); - if (!isValidOutput) { - throw new InferenceOutputError( - "Expected Array<{score: number, sequence: string, token: number, token_str: string}>" - ); - } - return res; + return validateOutput( + res, + z.array(z.object({ score: z.number(), sequence: z.string(), token: z.number(), token_str: z.string() })) + ); } diff --git a/packages/inference/src/tasks/nlp/questionAnswering.ts b/packages/inference/src/tasks/nlp/questionAnswering.ts index 58074eb9c9..3cbe783eb2 100644 --- a/packages/inference/src/tasks/nlp/questionAnswering.ts +++ b/packages/inference/src/tasks/nlp/questionAnswering.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -39,15 +39,5 @@ export async function questionAnswering( ...options, taskHint: "question-answering", }); - const isValidOutput = - typeof res === "object" && - !!res && - typeof res.answer === "string" && - typeof res.end === "number" && - typeof res.score === "number" && - typeof res.start === "number"; - if (!isValidOutput) { - throw new InferenceOutputError("Expected {answer: string, end: number, score: number, start: number}"); - } - return res; + return validateOutput(res, z.object({ answer: z.string(), end: z.number(), score: z.number(), start: z.number() })); } diff --git a/packages/inference/src/tasks/nlp/sentenceSimilarity.ts b/packages/inference/src/tasks/nlp/sentenceSimilarity.ts index ec5c173ca2..078d0d8511 100644 --- a/packages/inference/src/tasks/nlp/sentenceSimilarity.ts +++ b/packages/inference/src/tasks/nlp/sentenceSimilarity.ts @@ -1,5 +1,5 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; import { getDefaultTask } from "../../lib/getDefaultTask"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -31,10 +31,5 @@ export async function sentenceSimilarity( taskHint: "sentence-similarity", ...(defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }), }); - - const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected number[]"); - } - return res; + return validateOutput(res, z.array(z.number())); } diff --git a/packages/inference/src/tasks/nlp/summarization.ts b/packages/inference/src/tasks/nlp/summarization.ts index 71efd1c3b9..5328d1e309 100644 --- a/packages/inference/src/tasks/nlp/summarization.ts +++ b/packages/inference/src/tasks/nlp/summarization.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -54,9 +54,5 @@ export async function summarization(args: SummarizationArgs, options?: Options): ...options, taskHint: "summarization", }); - const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{summary_text: string}>"); - } - return res?.[0]; + return validateOutput(res, z.first(z.object({ summary_text: z.string() }))); } diff --git a/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts b/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts index a0cf692512..b1d852320a 100644 --- a/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts +++ b/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -45,17 +45,13 @@ export async function tableQuestionAnswering( ...options, taskHint: "table-question-answering", }); - const isValidOutput = - typeof res?.aggregator === "string" && - typeof res.answer === "string" && - Array.isArray(res.cells) && - res.cells.every((x) => typeof x === "string") && - Array.isArray(res.coordinates) && - res.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")); - if (!isValidOutput) { - throw new InferenceOutputError( - "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}" - ); - } - return res; + return validateOutput( + res, + z.object({ + aggregator: z.string(), + answer: z.string(), + cells: z.array(z.string()), + coordinates: z.array(z.array(z.number())), + }) + ); } diff --git a/packages/inference/src/tasks/nlp/textClassification.ts b/packages/inference/src/tasks/nlp/textClassification.ts index 41ced40571..bf87993b3d 100644 --- a/packages/inference/src/tasks/nlp/textClassification.ts +++ b/packages/inference/src/tasks/nlp/textClassification.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -27,16 +27,9 @@ export async function textClassification( args: TextClassificationArgs, options?: Options ): Promise { - const res = ( - await request(args, { - ...options, - taskHint: "text-classification", - }) - )?.[0]; - const isValidOutput = - Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{label: string, score: number}>"); - } - return res; + const res = await request(args, { + ...options, + taskHint: "text-classification", + }); + return validateOutput(res, z.first(z.array(z.object({ label: z.string(), score: z.number() })))); } diff --git a/packages/inference/src/tasks/nlp/textGeneration.ts b/packages/inference/src/tasks/nlp/textGeneration.ts index 6b550a8846..d32e053872 100644 --- a/packages/inference/src/tasks/nlp/textGeneration.ts +++ b/packages/inference/src/tasks/nlp/textGeneration.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -70,9 +70,5 @@ export async function textGeneration(args: TextGenerationArgs, options?: Options ...options, taskHint: "text-generation", }); - const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{generated_text: string}>"); - } - return res?.[0]; + return validateOutput(res, z.first(z.object({ generated_text: z.string() }))); } diff --git a/packages/inference/src/tasks/nlp/tokenClassification.ts b/packages/inference/src/tasks/nlp/tokenClassification.ts index eeee58d4c6..7b9593f81e 100644 --- a/packages/inference/src/tasks/nlp/tokenClassification.ts +++ b/packages/inference/src/tasks/nlp/tokenClassification.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { toArray } from "../../utils/toArray"; import { request } from "../custom/request"; @@ -64,20 +64,10 @@ export async function tokenClassification( taskHint: "token-classification", }) ); - const isValidOutput = - Array.isArray(res) && - res.every( - (x) => - typeof x.end === "number" && - typeof x.entity_group === "string" && - typeof x.score === "number" && - typeof x.start === "number" && - typeof x.word === "string" - ); - if (!isValidOutput) { - throw new InferenceOutputError( - "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>" - ); - } - return res; + return validateOutput( + res, + z.array( + z.object({ end: z.number(), entity_group: z.string(), score: z.number(), start: z.number(), word: z.string() }) + ) + ); } diff --git a/packages/inference/src/tasks/nlp/translation.ts b/packages/inference/src/tasks/nlp/translation.ts index ea7a3054c0..25dbae1bfa 100644 --- a/packages/inference/src/tasks/nlp/translation.ts +++ b/packages/inference/src/tasks/nlp/translation.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -26,9 +26,6 @@ export async function translation(args: TranslationArgs, options?: Options): Pro ...options, taskHint: "translation", }); - const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected type Array<{translation_text: string}>"); - } - return res?.length === 1 ? res?.[0] : res; + const output = validateOutput(res, z.array(z.object({ translation_text: z.string() }))); + return output.length === 1 ? output[0] : output; } diff --git a/packages/inference/src/tasks/nlp/zeroShotClassification.ts b/packages/inference/src/tasks/nlp/zeroShotClassification.ts index 2552489c36..47f3e7587c 100644 --- a/packages/inference/src/tasks/nlp/zeroShotClassification.ts +++ b/packages/inference/src/tasks/nlp/zeroShotClassification.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { toArray } from "../../utils/toArray"; import { request } from "../custom/request"; @@ -41,18 +41,8 @@ export async function zeroShotClassification( taskHint: "zero-shot-classification", }) ); - const isValidOutput = - Array.isArray(res) && - res.every( - (x) => - Array.isArray(x.labels) && - x.labels.every((_label) => typeof _label === "string") && - Array.isArray(x.scores) && - x.scores.every((_score) => typeof _score === "number") && - typeof x.sequence === "string" - ); - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>"); - } - return res; + return validateOutput( + res, + z.array(z.object({ labels: z.array(z.string()), scores: z.array(z.number()), sequence: z.string() })) + ); } diff --git a/packages/inference/src/tasks/tabular/tabularClassification.ts b/packages/inference/src/tasks/tabular/tabularClassification.ts index f53e926e94..c18af97650 100644 --- a/packages/inference/src/tasks/tabular/tabularClassification.ts +++ b/packages/inference/src/tasks/tabular/tabularClassification.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -29,9 +29,5 @@ export async function tabularClassification( ...options, taskHint: "tabular-classification", }); - const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected number[]"); - } - return res; + return validateOutput(res, z.array(z.number())); } diff --git a/packages/inference/src/tasks/tabular/tabularRegression.ts b/packages/inference/src/tasks/tabular/tabularRegression.ts index e6bd9e3de1..9526a43cc2 100644 --- a/packages/inference/src/tasks/tabular/tabularRegression.ts +++ b/packages/inference/src/tasks/tabular/tabularRegression.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { validateOutput, z } from "../../lib/validateOutput"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -29,9 +29,5 @@ export async function tabularRegression( ...options, taskHint: "tabular-regression", }); - const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected number[]"); - } - return res; + return validateOutput(res, z.array(z.number())); } diff --git a/packages/inference/test/HfInference.spec.ts b/packages/inference/test/HfInference.spec.ts index 280dd049aa..db33da8da4 100644 --- a/packages/inference/test/HfInference.spec.ts +++ b/packages/inference/test/HfInference.spec.ts @@ -519,7 +519,7 @@ describe.concurrent( it("objectDetection", async () => { expect( - await hf.imageClassification({ + await hf.objectDetection({ data: new Blob([readTestFile("cats.png")], { type: "image/png" }), model: "facebook/detr-resnet-50", }) @@ -540,7 +540,7 @@ describe.concurrent( }); it("imageSegmentation", async () => { expect( - await hf.imageClassification({ + await hf.imageSegmentation({ data: new Blob([readTestFile("cats.png")], { type: "image/png" }), model: "facebook/detr-resnet-50-panoptic", })