Skip to content

Commit

Permalink
move step.ai -> step.ai.wrap and rough out step.ai.infer (#747)
Browse files Browse the repository at this point in the history
## Summary
WIP on inferf



## Checklist
<!-- Tick these items off as you progress. -->
<!-- If an item isn't applicable, ideally please strikeout the item by
wrapping it in "~~"" and suffix it with "N/A My reason for skipping
this." -->
<!-- e.g. "- [ ] ~~Added tests~~ N/A Only touches docs" -->

- [ ] ~Added a [docs PR](https://github.com/inngest/website) that
references this PR~ N/A Covered elsewhere
- [x] Added unit/integration tests
- [x] Added changesets if applicable

---------

Co-authored-by: Tony Holdstock-Brown <tonyhb@gmail.com>
Co-authored-by: Jack Williams <1736957+jpwilliams@users.noreply.github.com>
Co-authored-by: Jack Williams <jack@inngest.com>
  • Loading branch information
4 people authored Nov 21, 2024
1 parent 5e1c665 commit 871a958
Show file tree
Hide file tree
Showing 14 changed files with 1,335 additions and 48 deletions.
5 changes: 5 additions & 0 deletions .changeset/perfect-guests-eat.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"inngest": minor
---

Add `step.ai.*()` tooling, allowing users to leverage AI workflows within Inngest functions
10 changes: 10 additions & 0 deletions packages/inngest/src/api/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,16 @@ export const stepsSchemas = {
})
.strict()
)
.or(
z
.object({
type: z.literal("input").optional().default("input"),
input: z.any().refine((v) => typeof v !== "undefined", {
message: "If input is present it must not be `undefined`",
}),
})
.strict()
)

/**
* If the result isn't a distcint `data` or `error` object, then it's
Expand Down
2 changes: 1 addition & 1 deletion packages/inngest/src/cloudflare.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ export const frameworkName: SupportedFrameworkName = "cloudflare-pages";
* Expected arguments for a Cloudflare Pages Function.
*/
export type PagesHandlerArgs = [
{ request: Request; env: Record<string, string | undefined> }
{ request: Request; env: Record<string, string | undefined> },
];

/**
Expand Down
5 changes: 4 additions & 1 deletion packages/inngest/src/components/InngestCommHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1394,7 +1394,10 @@ export class InngestCommHandler<
result.type === "data"
? // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
{ id, data: result.data }
: { id, error: result.error },
: result.type === "input"
? // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
{ id, input: result.input }
: { id, error: result.error },
};
}, {});

Expand Down
284 changes: 284 additions & 0 deletions packages/inngest/src/components/InngestStepTools.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
/* eslint-disable @typescript-eslint/no-unsafe-assignment */
import { openai } from "@local/components/ai";
import { EventSchemas } from "@local/components/EventSchemas";
import { type Inngest } from "@local/components/Inngest";
import { InngestFunction } from "@local/components/InngestFunction";
Expand Down Expand Up @@ -240,6 +241,289 @@ describe("run", () => {
});
});

describe("ai", () => {
describe("infer", () => {
let step: StepTools;

beforeEach(() => {
step = getStepTools();
});

test("return Step step op code", async () => {
await expect(
step.ai.infer("step", {
provider: openai({ model: "gpt-3.5-turbo" }),
body: {
messages: [],
},
})
).resolves.toMatchObject({
op: StepOpCode.AiGateway,
});
});

test("returns `id` as ID", async () => {
await expect(
step.ai.infer("id", {
provider: openai({ model: "gpt-3.5-turbo" }),
body: {
messages: [],
},
})
).resolves.toMatchObject({
id: "id",
});
});

test("return ID by default", async () => {
await expect(
step.ai.infer("id", {
provider: openai({ model: "gpt-3.5-turbo" }),
body: {
messages: [],
},
})
).resolves.toMatchObject({
displayName: "id",
});
});

test("return specific name if given", async () => {
await expect(
step.ai.infer(
{ id: "id", name: "name" },
{
provider: openai({ model: "gpt-3.5-turbo" }),
body: {
messages: [],
},
}
)
).resolves.toMatchObject({
displayName: "name",
});
});

test("requires a provider", () => {
// @ts-expect-error Missing provider
() => step.ai.infer("id", { body: { messages: [] } });
});

test("requires a body", () => {
() =>
// @ts-expect-error Missing body
step.ai.infer("id", {
provider: openai({ model: "gpt-3.5-turbo" }),
});
});

test("provider requires the correct body", () => {
() =>
step.ai.infer("id", {
provider: openai({ model: "gpt-3.5-turbo" }),
// @ts-expect-error Invalid body
body: {},
});
});

test("accepts the correct body", () => {
() =>
step.ai.infer("id", {
provider: openai({ model: "gpt-3.5-turbo" }),
body: {
messages: [],
},
});
});

test("uses default model if none given", async () => {
await expect(
step.ai.infer("id", {
provider: openai({ model: "gpt-3.5-turbo" }),
body: {
messages: [],
},
})
).resolves.toMatchObject({
opts: {
body: {
model: "gpt-3.5-turbo",
},
},
});
});

test("can overwrite model", async () => {
await expect(
step.ai.infer("id", {
provider: openai({ model: "gpt-3.5-turbo" }),
body: {
model: "gpt-3.5-something-else",
messages: [],
},
})
).resolves.toMatchObject({
opts: {
body: {
model: "gpt-3.5-something-else",
},
},
});
});
});

describe("wrap", () => {
let step: StepTools;

beforeEach(() => {
step = getStepTools();
});

test("return Step step op code", async () => {
await expect(
step.ai.wrap("step", () => undefined)
).resolves.toMatchObject({
op: StepOpCode.StepPlanned,
});
});

test("returns `id` as ID", async () => {
await expect(step.ai.wrap("id", () => undefined)).resolves.toMatchObject({
id: "id",
});
});

test("return ID by default", async () => {
await expect(step.ai.wrap("id", () => undefined)).resolves.toMatchObject({
displayName: "id",
});
});

test("return specific name if given", async () => {
await expect(
step.ai.wrap({ id: "id", name: "name" }, () => undefined)
).resolves.toMatchObject({
displayName: "name",
});
});

test("no input", async () => {
await expect(step.ai.wrap("", () => {})).resolves.toMatchObject({});
});

test("single input", async () => {
await expect(
// eslint-disable-next-line @typescript-eslint/no-unused-vars
step.ai.wrap("", (flag: boolean) => {}, true)
).resolves.toMatchObject({});
});

test("multiple input", async () => {
await expect(
// eslint-disable-next-line @typescript-eslint/no-unused-vars
step.ai.wrap("", (flag: boolean, value: number) => {}, true, 10)
).resolves.toMatchObject({});
});

test("disallow missing step inputs when function expects them", () => {
// @ts-expect-error Invalid data
// eslint-disable-next-line @typescript-eslint/no-unused-vars
void step.ai.wrap("", (flag: boolean, value: number) => {});
});

test("disallow step inputs when function does not expect them", () => {
// @ts-expect-error Invalid data
void step.ai.wrap("", () => {}, true);
});

test("disallow step inputs that don't match what function expects", () => {
// @ts-expect-error Invalid data
// eslint-disable-next-line @typescript-eslint/no-unused-vars
void step.ai.wrap("", (flag: boolean, value: number) => {}, 10, true);
});

test("optional input", async () => {
await expect(
step.run(
"",
// eslint-disable-next-line @typescript-eslint/no-unused-vars
(flag: boolean, value?: number) => {
// valid - enough arguments given - missing arg is optional
},
true
)
).resolves.toMatchObject({});
});

test("types returned from ai are the result of (de)serialization", () => {
const input = {
str: "",
num: 0,
bool: false,
date: new Date(),
fn: () => undefined,
obj: {
str: "",
num: 0,
},
arr: [0, 1, 2, () => undefined, true],
infinity: Infinity,
nan: NaN,
undef: undefined,
null: null,
symbol: Symbol("foo"),
map: new Map(),
set: new Set(),
bigint: BigInt(123),
typedArray: new Int8Array(2),
promise: Promise.resolve(),
weakMap: new WeakMap([[{}, "test"]]),
weakSet: new WeakSet([{}]),
};

const output = step.ai.wrap("step", () => input);

type Expected = {
str: string;
num: number;
bool: boolean;
date: string;
obj: {
str: string;
num: number;
};
arr: (number | null | boolean)[];
infinity: number;
nan: number;
null: null;
map: Record<string, never>;
set: Record<string, never>;
bigint: never;
typedArray: Record<string, number>;
// eslint-disable-next-line @typescript-eslint/ban-types
promise: {};
// eslint-disable-next-line @typescript-eslint/ban-types
weakMap: {};
// eslint-disable-next-line @typescript-eslint/ban-types
weakSet: {};
};

assertType<Promise<Expected>>(output);

/**
* Used to ensure that stripped base properties are also adhered to.
*/
type KeysMatchExactly<T, U> = keyof T extends keyof U
? keyof U extends keyof T
? true
: false
: false;

assertType<KeysMatchExactly<Expected, Awaited<typeof output>>>(true);
});
});
});

describe("sleep", () => {
let step: StepTools;

Expand Down
Loading

0 comments on commit 871a958

Please sign in to comment.