feature: custom ai executor #1766

Draft · wants to merge 1 commit into base: main
10 changes: 9 additions & 1 deletion examples/09-ai/01-minimal/App.tsx
@@ -7,9 +7,9 @@ import "@blocknote/mantine/style.css";
import {
FormattingToolbar,
FormattingToolbarController,
SuggestionMenuController,
getDefaultReactSlashMenuItems,
getFormattingToolbarItems,
SuggestionMenuController,
useCreateBlockNote,
} from "@blocknote/react";
import {
@@ -64,6 +64,14 @@ export default function App() {
extensions: [
createAIExtension({
model,
/*
executor: async (opts) => {
// fetch data from your backend (yourBackendUrl is a placeholder for your own endpoint)
const resp = await fetch(yourBackendUrl);
// process the response into streamed tool calls
const streamToolCalls = await yourLogicToConvertRespToStreamToolCalls(resp);
return LLMResponse.fromArray(opts.messages, opts.streamTools, streamToolCalls);
},*/
}),
],
// We set some initial content for demo purposes
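For reference, a written-out version of the commented executor above might look roughly like the sketch below. `yourBackendUrl` and `yourLogicToConvertRespToStreamToolCalls` are placeholders for your own endpoint and parsing logic, and the example assumes `ExecuteLLMRequestOptions` and `LLMResponse` are importable from the package entry point.

import { createAIExtension, LLMResponse } from "@blocknote/xl-ai";
import type { ExecuteLLMRequestOptions } from "@blocknote/xl-ai"; // assumption: the type is re-exported

// Placeholders for your own backend and response handling:
declare const yourBackendUrl: string;
declare function yourLogicToConvertRespToStreamToolCalls(data: unknown): Promise<any[]>;

// Calls a custom backend instead of the Vercel AI SDK and converts the result
// into tool-call operations that BlockNote can apply to the document.
async function customExecutor(opts: ExecuteLLMRequestOptions): Promise<LLMResponse> {
  // send the prompt messages built by BlockNote to your backend
  const resp = await fetch(yourBackendUrl, {
    method: "POST",
    body: JSON.stringify({ messages: opts.messages }),
  });
  // convert the backend response into operations matching opts.streamTools
  const streamToolCalls = await yourLogicToConvertRespToStreamToolCalls(await resp.json());
  // signal that writing is about to start (e.g. so the empty cursor block is removed)
  opts.onStart?.();
  return LLMResponse.fromArray(opts.messages, opts.streamTools, streamToolCalls);
}

// then: createAIExtension({ model, executor: customExecutor })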
25 changes: 21 additions & 4 deletions packages/xl-ai/src/AIExtension.ts
@@ -9,15 +9,19 @@ import {
suggestChanges,
} from "@blocknote/prosemirror-suggest-changes";
import { APICallError, LanguageModel, RetryError } from "ai";
import { Fragment, Slice } from "prosemirror-model";
import { Plugin, PluginKey } from "prosemirror-state";
import { fixTablesKey } from "prosemirror-tables";
import { createStore, StoreApi } from "zustand/vanilla";
import { doLLMRequest, LLMRequestOptions } from "./api/LLMRequest.js";
import {
doLLMRequest,
ExecuteLLMRequestOptions,
LLMRequestOptions,
} from "./api/LLMRequest.js";
import { LLMResponse } from "./api/LLMResponse.js";
import { PromptBuilder } from "./api/formats/PromptBuilder.js";
import { LLMFormat, llmFormats } from "./api/index.js";
import { createAgentCursorPlugin } from "./plugins/AgentCursorPlugin.js";
import { Fragment, Slice } from "prosemirror-model";

type MakeOptional<T, K extends keyof T> = Omit<T, K> & Partial<Pick<T, K>>;

@@ -81,6 +85,13 @@ type GlobalLLMRequestOptions = {
* @default the default prompt builder for the selected {@link dataFormat}
*/
promptBuilder?: PromptBuilder;

/**
* Customize how your LLM backend is called.
* Implement this function if you want to call a backend that is not compatible with
* the Vercel AI SDK.
*/
executor?: (opts: ExecuteLLMRequestOptions) => Promise<LLMResponse>;
};

const PLUGIN_KEY = new PluginKey(`blocknote-ai-plugin`);
@@ -112,7 +123,10 @@ export class AIExtension extends BlockNoteExtension {
public readonly options: ReturnType<
ReturnType<
typeof createStore<
MakeOptional<Required<GlobalLLMRequestOptions>, "promptBuilder">
MakeOptional<
Required<GlobalLLMRequestOptions>,
"promptBuilder" | "executor"
>
>
>
>;
@@ -134,7 +148,10 @@ export class AIExtension extends BlockNoteExtension {
super();

this.options = createStore<
MakeOptional<Required<GlobalLLMRequestOptions>, "promptBuilder">
MakeOptional<
Required<GlobalLLMRequestOptions>,
"promptBuilder" | "executor"
>
>()((_set) => ({
dataFormat: llmFormats.html,
stream: true,
110 changes: 62 additions & 48 deletions packages/xl-ai/src/api/LLMRequest.ts
@@ -1,23 +1,40 @@
import { BlockNoteEditor } from "@blocknote/core";
import { CoreMessage, generateObject, LanguageModelV1, streamObject } from "ai";
import {
generateOperations,
streamOperations,
} from "../streamTool/callLLMWithStreamTools.js";
import { createAISDKLLMRequestExecutor } from "../streamTool/callLLMWithStreamTools.js";
import { StreamTool } from "../streamTool/streamTool.js";
import { isEmptyParagraph } from "../util/emptyBlock.js";
import { LLMResponse } from "./LLMResponse.js";
import type { PromptBuilder } from "./formats/PromptBuilder.js";
import { htmlBlockLLMFormat } from "./formats/html-blocks/htmlBlocks.js";
import { LLMFormat } from "./index.js";

type MakeOptional<T, K extends keyof T> = Omit<T, K> & Partial<Pick<T, K>>;

export type ExecuteLLMRequestOptions = {
messages: CoreMessage[];
streamTools: StreamTool<any>[];
llmRequestOptions: MakeOptional<LLMRequestOptions, "executor">;
onStart?: () => void;
};

export type LLMRequestOptions = {
/**
* The language model to use for the LLM call (AI SDK)
*
* (when invoking `callLLM` via the `AIExtension` this will default to the
* model set in the `AIExtension` options)
*
* Note: perhaps we want to remove this
*/
model: LanguageModelV1;
model?: LanguageModelV1;

/**
* Customize how your LLM backend is called.
* Implement this function if you want to call a backend that is not compatible with
* the Vercel AI SDK.
*/
executor?: (opts: ExecuteLLMRequestOptions) => Promise<LLMResponse>;

/**
* The user prompt to use for the LLM call
*/
@@ -43,12 +60,6 @@ export type LLMRequestOptions = {
* @default provided by the format (e.g. `llm.html.defaultPromptBuilder`)
*/
promptBuilder?: PromptBuilder;
/**
* The maximum number of retries for the LLM call
*
* @default 2
*/
maxRetries?: number;
/**
* Whether to use the editor selection for the LLM call
*
@@ -68,15 +79,6 @@ export type LLMRequestOptions = {
/** Enable the delete tool (default: true) */
delete?: boolean;
};
/**
* Whether to stream the LLM response or not
*
* When streaming, we use the AI SDK `streamObject` function,
* otherwise, we use the AI SDK `generateObject` function.
*
* @default true
*/
stream?: boolean;
/**
* If the user's cursor is in an empty paragraph, automatically delete it when the AI
* is starting to write.
@@ -102,6 +104,26 @@ export type LLMRequestOptions = {
* @default true
*/
withDelays?: boolean;

// The settings below might make more sense to be part of the executor

/**
* Whether to stream the LLM response or not
*
* When streaming, we use the AI SDK `streamObject` function,
* otherwise, we use the AI SDK `generateObject` function.
*
* @default true
*/
stream?: boolean;

/**
* The maximum number of retries for the LLM call
*
* @default 2
*/
maxRetries?: number;

/**
* Additional options to pass to the AI SDK `generateObject` function
* (only used when `stream` is `false`)
@@ -217,34 +239,26 @@ export async function doLLMRequest(
opts.onBlockUpdate,
);

let response:
| Awaited<ReturnType<typeof generateOperations<any>>>
| Awaited<ReturnType<typeof streamOperations<any>>>;

if (stream) {
response = await streamOperations(
streamTools,
{
messages,
...rest,
},
() => {
if (deleteCursorBlock) {
editor.removeBlocks([deleteCursorBlock]);
}
onStart?.();
},
);
} else {
response = await generateOperations(streamTools, {
messages,
...rest,
});
if (deleteCursorBlock) {
editor.removeBlocks([deleteCursorBlock]);
let executor = opts.executor;
if (!executor) {
if (!opts.model) {
throw new Error("model is required when no executor is provided");
}
onStart?.();
executor = createAISDKLLMRequestExecutor({ model: opts.model });
}

return new LLMResponse(messages, response, streamTools);
return executor({
onStart: () => {
if (deleteCursorBlock) {
editor.removeBlocks([deleteCursorBlock]);
}
onStart?.();
},
messages,
streamTools,
llmRequestOptions: {
...opts,
...rest,
stream,
},
});
}
43 changes: 43 additions & 0 deletions packages/xl-ai/src/api/LLMResponse.ts
@@ -1,6 +1,7 @@
import { CoreMessage } from "ai";
import { OperationsResult } from "../streamTool/callLLMWithStreamTools.js";
import { StreamTool, StreamToolCall } from "../streamTool/streamTool.js";
import { createAsyncIterableStreamFromAsyncIterable } from "../util/stream.js";

/**
* Result of an LLM call with stream tools that apply changes to a BlockNote Editor
@@ -61,4 +62,46 @@ export class LLMResponse {
console.log(JSON.stringify(toolCall, null, 2));
}
}

/**
* Create an LLMResponse from an array of operations.
*
* Note: This is a temporary helper; we'll make it easier to create this from streaming data if required.
*/
public static fromArray<T extends StreamTool<any>[]>(
messages: CoreMessage[],
streamTools: StreamTool<any>[],
operations: StreamToolCall<T>[],
): LLMResponse {
return new LLMResponse(
messages,
OperationsResultFromArray(operations),
streamTools,
);
}
}

function OperationsResultFromArray<T extends StreamTool<any>[]>(
operations: StreamToolCall<T>[],
): OperationsResult<T> {
async function* singleChunkGenerator() {
for (const op of operations) {
yield {
operation: op,
isUpdateToPreviousOperation: false,
isPossiblyPartial: false,
};
}
}

return {
streamObjectResult: undefined,
generateObjectResult: undefined,
get operationsSource() {
return createAsyncIterableStreamFromAsyncIterable(singleChunkGenerator());
},
async getGeneratedOperations() {
return { operations };
},
};
}
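Worth noting: `OperationsResultFromArray` yields every entry as a single, complete chunk (`isPossiblyPartial: false`, `isUpdateToPreviousOperation: false`). A custom executor that already has all operations up front can therefore build its response in one call; `messages`, `streamTools`, and `operations` below stand in for the values available inside the executor:

// each operation is emitted as one complete chunk, with no partial updates
const response = LLMResponse.fromArray(messages, streamTools, operations);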
1 change: 1 addition & 0 deletions packages/xl-ai/src/api/index.ts
@@ -42,4 +42,5 @@ export const llmFormats = {
};

export { doLLMRequest as callLLM } from "./LLMRequest.js";
export { LLMResponse } from "./LLMResponse.js";
export { promptHelpers } from "./promptHelpers/index.js";
40 changes: 40 additions & 0 deletions packages/xl-ai/src/streamTool/callLLMWithStreamTools.ts
@@ -2,6 +2,7 @@ import {
CoreMessage,
GenerateObjectResult,
LanguageModel,
LanguageModelV1,
ObjectStreamPart,
StreamObjectResult,
generateObject,
@@ -11,6 +12,8 @@ import {

import { createStreamToolsArraySchema } from "./jsonSchema.js";

import { ExecuteLLMRequestOptions } from "../api/LLMRequest.js";
import { LLMResponse } from "../api/LLMResponse.js";
import {
AsyncIterableStream,
createAsyncIterableStream,
@@ -350,3 +353,40 @@ function partialObjectStream<PARTIAL>(
),
);
}

export function createAISDKLLMRequestExecutor(opts: {
model: LanguageModelV1;
}) {
const { model } = opts;
return async (opts: ExecuteLLMRequestOptions) => {
const { messages, streamTools, llmRequestOptions, onStart } = opts;
const { stream, maxRetries, _generateObjectOptions, _streamObjectOptions } =
llmRequestOptions;
let response:
| Awaited<ReturnType<typeof generateOperations<any>>>
| Awaited<ReturnType<typeof streamOperations<any>>>;

if (stream) {
response = await streamOperations(
streamTools,
{
messages,
model,
maxRetries,
...(_streamObjectOptions as any),
},
onStart,
);
} else {
response = await generateOperations(streamTools, {
messages,
model,
maxRetries,
...(_generateObjectOptions as any),
});
onStart?.();
}

return new LLMResponse(messages, response, streamTools);
};
}
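`createAISDKLLMRequestExecutor` is also what `doLLMRequest` falls back to when no `executor` is configured, so the two setups below should behave the same. This is a sketch: it assumes the factory is exported from the package entry point, and the OpenAI model is only illustrative.

import { openai } from "@ai-sdk/openai"; // illustrative AI SDK provider
import { createAIExtension } from "@blocknote/xl-ai";
import { createAISDKLLMRequestExecutor } from "@blocknote/xl-ai"; // assumption: exported

const model = openai("gpt-4o");

// default behaviour: the AI SDK executor is created internally
createAIExtension({ model });

// equivalent: pass the same executor explicitly
createAIExtension({ model, executor: createAISDKLLMRequestExecutor({ model }) });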