Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ It's not a hard requirement, but please consider using an icon from [Gitmoji](ht

## Tests

If you want to run only specific tests, you can do `pnpm test -- -t "test name"`
If you want to run only specific tests, you can do `pnpm test "test name"`

## Adding a package

Expand Down
7 changes: 5 additions & 2 deletions packages/inference/src/lib/InferenceOutputError.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
export class InferenceOutputError extends TypeError {
constructor(message: string) {
constructor(err: unknown) {
super(
`Invalid inference output: ${message}. Use the 'request' method with the same parameters to do a custom call with no type checking.`
`Invalid inference output: ${
err instanceof Error ? err.message : String(err)
}. Use the 'request' method with the same parameters to do a custom call with no type checking.`,
err instanceof Error ? { cause: err } : undefined
);
this.name = "InferenceOutputError";
}
Expand Down
84 changes: 84 additions & 0 deletions packages/inference/src/lib/validateOutput.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import { describe, expect, it } from "vitest";
import { z } from "./validateOutput";

describe("validateOutput", () => {
it("validates simple types", () => {
expect(z.string().parse("foo")).toBe("foo");
expect(z.number().parse(42)).toBe(42);
expect(z.blob().parse(new Blob())).toBeInstanceOf(Blob);
});

it("errors on simple types", () => {
expect(() => z.string().parse(42)).toThrow(/Expected string/);
expect(() => z.number().parse("foo")).toThrow(/Expected number/);
expect(() => z.blob().parse(42)).toThrow(/Expected Blob/);
});
it("validates arrays", () => {
expect(z.array(z.string()).parse(["foo"])).toEqual(["foo"]);
expect(z.array(z.number()).parse([42])).toEqual([42]);
expect(z.array(z.blob()).parse([new Blob()])).toEqual([new Blob()]);
});
it("errors on arrays", () => {
expect(() => z.array(z.string()).parse([42])).toThrow(/Expected Array<string>/);
expect(() => z.array(z.number()).parse(["foo"])).toThrow(/Expected Array<number>/);
expect(() => z.array(z.blob()).parse([42])).toThrow(/Expected Array<Blob>/);
});
it("validates objects", () => {
expect(z.object({ foo: z.string() }).parse({ foo: "foo" })).toEqual({ foo: "foo" });
expect(z.object({ foo: z.number() }).parse({ foo: 42 })).toEqual({ foo: 42 });
expect(z.object({ foo: z.blob() }).parse({ foo: new Blob() })).toEqual({ foo: new Blob() });
expect(z.object({ foo: z.string(), bar: z.number() }).parse({ foo: "foo", bar: 42 })).toEqual({
foo: "foo",
bar: 42,
});
});
it("errors on objects", () => {
expect(() => z.object({ foo: z.string() }).parse({ foo: 42 })).toThrow(/Expected { foo: string }/);
expect(() => z.object({ foo: z.number() }).parse({ foo: "foo" })).toThrow(/Expected { foo: number }/);
expect(() => z.object({ foo: z.blob() }).parse({ foo: 42 })).toThrow(/Expected { foo: Blob }/);
expect(() => z.object({ foo: z.string(), bar: z.number() }).parse({ foo: "foo", bar: "bar" })).toThrow(
/Expected { foo: string, bar: number }/
);
});
it("validates unions", () => {
expect(z.or(z.string(), z.number()).parse("foo")).toBe("foo");
expect(z.or(z.blob(), z.string(), z.number()).parse(42)).toBe(42);
});
it("errors on unions", () => {
expect(() => z.or(z.string(), z.number()).parse(new Blob())).toThrow(/Expected string | number/);
});
it("validates a complex object", () => {
expect(
z
.object({
foo: z.string(),
bar: z.array(z.number()),
baz: z.object({ a: z.string(), b: z.number() }),
})
.parse({
foo: "foo",
bar: [42],
baz: { a: "a", b: 42 },
})
).toEqual({
foo: "foo",
bar: [42],
baz: { a: "a", b: 42 },
});
});
it("errors on a complex object", () => {
expect(() =>
z
.object({
foo: z.string(),
bar: z.array(z.number()),
baz: z.object({ a: z.string(), b: z.number() }),
})
.parse({
foo: "foo",
bar: [42],
baz: { a: "a", b: "b" },
})
).toThrow(/Expected { foo: string, bar: Array<number>, baz: { a: string, b: number } }/);
});
});
155 changes: 155 additions & 0 deletions packages/inference/src/lib/validateOutput.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/**
* Heavily inspired by zod
*
* Re-created to avoid extra dependencies
*/

import { InferenceOutputError } from "./InferenceOutputError";

interface Parser<T = any> {
parse: (value: any) => T;
toString: () => string;
}

export type Infer<T extends Parser> = ReturnType<T["parse"]>;

export const z = {
array<T extends Parser>(items: T): Parser<Infer<T>[]> {
return {
parse: (value: unknown) => {
if (!Array.isArray(value)) {
throw new Error("Expected " + z.array(items).toString());
}
try {
return value.map((val) => items.parse(val));
} catch (err) {
throw new Error("Expected " + z.array(items).toString(), { cause: err });
}
},
toString(): string {
return `Array<${items.toString()}>`;
},
};
},
first<T extends Parser>(items: T): Parser<Infer<T>> {
return {
parse: (value: unknown) => {
if (!Array.isArray(value) || value.length === 0) {
throw new Error("Expected " + z.first(items).toString());
}
try {
return items.parse(value[0]);
} catch (err) {
throw new Error("Expected " + z.first(items).toString(), { cause: err });
}
},
toString(): string {
return `[${items.toString()}]`;
},
};
},
or: <T extends Parser[]>(...items: T): Parser<Infer<T[number]>> => ({
parse: (value: unknown): ReturnType<T[number]["parse"]> => {
const errors: Error[] = [];
for (const item of items) {
try {
return item.parse(value);
} catch (err) {
errors.push(err as Error);
}
}
throw new Error("Expected " + z.or(...items).toString(), { cause: errors });
},
toString(): string {
return items.map((item) => item.toString()).join(" | ");
},
}),
object<T extends Record<string, Parser>>(item: T): Parser<{ [key in keyof T]: Infer<T[key]> }> {
return {
parse: (value: unknown) => {
if (typeof value !== "object" || value === null || Array.isArray(value)) {
throw new Error("Expected " + z.object(item).toString());
}
try {
return Object.fromEntries(
Object.entries(item).map(([key, val]) => [key, val.parse((value as any)[key])])
) as {
[key in keyof T]: Infer<T[key]>;
};
} catch (err) {
throw new Error("Expected " + z.object(item).toString(), { cause: err });
}
},
toString(): string {
return `{ ${Object.entries(item)
.map(([key, val]) => `${key}: ${val.toString()}`)
.join(", ")} }`;
},
};
},
string(): Parser<string> {
return {
parse: (value: unknown): string => {
if (typeof value !== "string") {
throw new Error("Expected " + z.string().toString());
}
return value;
},
toString(): string {
return "string";
},
};
},
number(): Parser<number> {
return {
parse: (value: unknown): number => {
if (typeof value !== "number") {
throw new Error("Expected " + z.number().toString());
}
return value;
},
toString(): string {
return "number";
},
};
},
blob(): Parser<Blob> {
return {
parse: (value: unknown): Blob => {
if (!(value instanceof Blob)) {
throw new Error("Expected " + z.blob().toString());
}
return value;
},
toString(): string {
return "Blob";
},
};
},
optional<T extends Parser>(item: T): Parser<Infer<T> | undefined> {
return {
parse: (value: unknown): ReturnType<T["parse"]> | undefined => {
if (value === undefined) {
return undefined;
}
try {
return item.parse(value);
} catch (err) {
throw new Error("Expected " + z.optional(item).toString(), { cause: err });
}
},
toString(): string {
return `${item.toString()} | undefined`;
},
};
},
};

export function validateOutput<T>(value: unknown, schema: { parse: (value: any) => T }): T {
try {
return schema.parse(value);
} catch (err) {
throw new InferenceOutputError(err);
}
}
9 changes: 2 additions & 7 deletions packages/inference/src/tasks/audio/audioClassification.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { validateOutput, z } from "../../lib/validateOutput";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";

Expand Down Expand Up @@ -35,10 +35,5 @@ export async function audioClassification(
...options,
taskHint: "audio-classification",
});
const isValidOutput =
Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
if (!isValidOutput) {
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
}
return res;
return validateOutput(res, z.array(z.object({ label: z.string(), score: z.number() })));
}
21 changes: 11 additions & 10 deletions packages/inference/src/tasks/audio/audioToAudio.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { validateOutput, z } from "../../lib/validateOutput";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";

Expand Down Expand Up @@ -37,13 +37,14 @@ export async function audioToAudio(args: AudioToAudioArgs, options?: Options): P
...options,
taskHint: "audio-to-audio",
});
const isValidOutput =
Array.isArray(res) &&
res.every(
(x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string"
);
if (!isValidOutput) {
throw new InferenceOutputError("Expected Array<{label: string, blob: string, content-type: string}>");
}
return res;
return validateOutput(
res,
z.array(
z.object({
label: z.string(),
blob: z.string(),
"content-type": z.string(),
})
)
);
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { validateOutput, z } from "../../lib/validateOutput";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";

Expand Down Expand Up @@ -28,9 +28,5 @@ export async function automaticSpeechRecognition(
...options,
taskHint: "automatic-speech-recognition",
});
const isValidOutput = typeof res?.text === "string";
if (!isValidOutput) {
throw new InferenceOutputError("Expected {text: string}");
}
return res;
return validateOutput(res, z.object({ text: z.string() }));
}
8 changes: 2 additions & 6 deletions packages/inference/src/tasks/audio/textToSpeech.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { validateOutput, z } from "../../lib/validateOutput";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";

Expand All @@ -20,9 +20,5 @@ export async function textToSpeech(args: TextToSpeechArgs, options?: Options): P
...options,
taskHint: "text-to-speech",
});
const isValidOutput = res && res instanceof Blob;
if (!isValidOutput) {
throw new InferenceOutputError("Expected Blob");
}
return res;
return validateOutput(res, z.blob());
}
9 changes: 2 additions & 7 deletions packages/inference/src/tasks/cv/imageClassification.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { validateOutput, z } from "../../lib/validateOutput";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";

Expand Down Expand Up @@ -34,10 +34,5 @@ export async function imageClassification(
...options,
taskHint: "image-classification",
});
const isValidOutput =
Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
if (!isValidOutput) {
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
}
return res;
return validateOutput(res, z.array(z.object({ label: z.string(), score: z.number() })));
}
10 changes: 2 additions & 8 deletions packages/inference/src/tasks/cv/imageSegmentation.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { validateOutput, z } from "../../lib/validateOutput";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";

Expand Down Expand Up @@ -38,11 +38,5 @@ export async function imageSegmentation(
...options,
taskHint: "image-segmentation",
});
const isValidOutput =
Array.isArray(res) &&
res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
if (!isValidOutput) {
throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
}
return res;
return validateOutput(res, z.array(z.object({ label: z.string(), mask: z.string(), score: z.number() })));
}
Loading