Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OpenAI] Support text completion via engine.completions.create() #534

Merged
merged 4 commits into from
Aug 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Note that all examples below run in-browser and use WebGPU as a backend.
- [get-started-web-worker](get-started-web-worker): same as get-started, but using web worker.
- [next-simple-chat](next-simple-chat): a minimal and complete chat bot app with [Next.js](https://nextjs.org/).
- [multi-round-chat](multi-round-chat): while APIs are functional, we internally optimize so that multi round chat usage can reuse KV cache
- [text-completion](text-completion): demonstrates API `engine.completions.create()`, which is pure text completion with no conversation, as opposed to `engine.chat.completions.create()`

#### Advanced OpenAI API Capabilities

Expand Down
14 changes: 14 additions & 0 deletions examples/text-completion/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# WebLLM Text Completion Example

This folder provides a minimum demo to show WebLLM API in a webapp setting.
To try it out, you can do the following steps under this folder

```bash
npm install
npm start
```

Note: if you would like to hack on the WebLLM core package, you can change the
web-llm dependency to `"file:../.."` and follow the build-from-source
instructions in the project to build WebLLM locally. This option is only
recommended if you intend to modify the WebLLM core package.
20 changes: 20 additions & 0 deletions examples/text-completion/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"name": "text-completion",
"version": "0.1.0",
"private": true,
"scripts": {
"start": "parcel src/text_completion.html --port 8888",
"build": "parcel build src/text_completion.html --dist-dir lib"
},
"devDependencies": {
"buffer": "^5.7.1",
"parcel": "^2.8.3",
"process": "^0.11.10",
"tslib": "^2.3.1",
"typescript": "^4.9.5",
"url": "^0.11.3"
},
"dependencies": {
"@mlc-ai/web-llm": "../.."
}
}
23 changes: 23 additions & 0 deletions examples/text-completion/src/text_completion.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
<!doctype html>
<html>
<!-- NOTE(review): presumably a global expected by the WebLLM runtime before the
     module script loads — TODO confirm it is still required. -->
<script>
webLLMGlobal = {};
</script>
<body>
<h2>WebLLM Test Page</h2>
Open console to see output
<br />
<br />
<!-- Model-loading progress text is written here via setLabel("init-label", ...). -->
<label id="init-label"> </label>

<h3>Prompt</h3>
<label id="prompt-label"> </label>

<h3>Response</h3>
<label id="generate-label"> </label>
<br />
<label id="stats-label"> </label>

<!-- Demo entry point; see text_completion.ts in the same folder. -->
<script type="module" src="./text_completion.ts"></script>
</body>
</html>
58 changes: 58 additions & 0 deletions examples/text-completion/src/text_completion.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import * as webllm from "@mlc-ai/web-llm";

/**
 * Write `text` into the DOM element with the given id.
 *
 * @param id The element id to look up.
 * @param text The text to display.
 * @throws Error when no element with that id exists in the document.
 */
function setLabel(id: string, text: string) {
  const target = document.getElementById(id);
  if (target === null) {
    throw Error("Cannot find label " + id);
  }
  target.innerText = text;
}

/**
 * Demo of pure text completion (no chat template) via engine.completions.create().
 * Loads a base (non-instruct) Llama model, issues one completion request, and
 * logs the choices and token usage to the console.
 */
async function main() {
  // Surface model-loading progress on the page.
  const initProgressCallback = (progress: webllm.InitProgressReport) => {
    setLabel("init-label", progress.text);
  };

  // A base model, unlike the instruction-tuned "Llama-3.1-8B-Instruct-q4f32_1-MLC".
  const selectedModel = "Llama-3.1-8B-q4f32_1-MLC";

  const appConfig: webllm.AppConfig = {
    model_list: [
      {
        // Base-model weights; the instruct model lib is reused since the
        // architecture is the same.
        model: "https://huggingface.co/mlc-ai/Llama-3.1-8B-q4f32_1-MLC", // a base model
        model_id: selectedModel,
        model_lib:
          webllm.modelLibURLPrefix +
          webllm.modelVersion +
          "/Llama-3_1-8B-Instruct-q4f32_1-ctx4k_cs1k-webgpu.wasm",
        overrides: {
          context_window_size: 2048,
        },
      },
    ],
  };

  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
    selectedModel,
    {
      appConfig: appConfig,
      initProgressCallback: initProgressCallback,
      logLevel: "INFO",
    },
  );

  const completion = await engine.completions.create({
    prompt: "List 3 US states: ",
    // below configurations are all optional
    echo: true, // include the prompt in each choice's text
    n: 2, // number of independent completions
    max_tokens: 64,
    logprobs: true,
    top_logprobs: 2,
  });
  console.log(completion);
  console.log(completion.usage);

  // To change model, either create a new engine via `CreateMLCEngine()`, or call `engine.reload(modelId)`
}

main();
183 changes: 178 additions & 5 deletions src/conversation.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,25 @@
import { ConvTemplateConfig, MessagePlaceholders, Role } from "./config";
import {
ChatConfig,
ConvTemplateConfig,
MessagePlaceholders,
Role,
} from "./config";
import {
ChatCompletionMessageParam,
ChatCompletionRequest,
} from "./openai_api_protocols/index";
import {
ContentTypeError,
FunctionNotFoundError,
InvalidToolChoiceError,
MessageOrderError,
SystemMessageOrderError,
TextCompletionConversationError,
TextCompletionConversationExpectsPrompt,
UnsupportedRoleError,
UnsupportedToolChoiceTypeError,
UnsupportedToolTypeError,
} from "./error";

/**
* Helper to keep track of history conversations.
Expand All @@ -8,15 +29,21 @@ export class Conversation {
public messages: Array<[Role, string, string | undefined]> = [];
readonly config: ConvTemplateConfig;

/** Whether the Conversation object is for text completion with no conversation-style formatting */
public isTextCompletion: boolean;
/** Used when isTextCompletion is true */
public prompt: string | undefined;

public function_string = "";
public use_function_calling = false;
public override_system_message?: string = undefined;

// TODO(tvm-team) confirm and remove
// private contextWindowStart = 0;

// isTextCompletion selects raw text completion (prompt-only, no chat-template
// formatting); defaults to false so existing chat callers are unaffected.
constructor(config: ConvTemplateConfig, isTextCompletion = false) {
this.config = config;
this.isTextCompletion = isTextCompletion;
}

private getPromptArrayInternal(addSystem: boolean, startPos: number) {
Expand Down Expand Up @@ -96,6 +123,9 @@ export class Conversation {
* @returns The prompt array.
*/
getPromptArray(): Array<string> {
  // A text-completion conversation has no chat-style prompt array.
  if (this.isTextCompletion) {
    throw new TextCompletionConversationError("getPromptArray");
  }
  // Include the system prompt and start from the first message.
  return this.getPromptArrayInternal(true, 0);
}

Expand All @@ -107,13 +137,26 @@ export class Conversation {
*
* @returns The prompt array.
*/
getPrompArrayLastRound() {
getPromptArrayLastRound() {
  if (this.isTextCompletion) {
    // Fix: error previously reported the misspelled name "getPromptyArrayLastRound".
    throw new TextCompletionConversationError("getPromptArrayLastRound");
  }
  // A "last round" only exists once there is at least one full prior exchange.
  if (this.messages.length < 3) {
    throw Error("needs to call getPromptArray for the first message");
  }
  // Skip the system prompt and return only the most recent user/assistant pair.
  return this.getPromptArrayInternal(false, this.messages.length - 2);
}

/**
* Return prompt in an array for non-conversation text completion.
*/
getPromptArrayTextCompletion(): Array<string> {
if (!this.isTextCompletion || this.prompt === undefined) {
throw new TextCompletionConversationExpectsPrompt();
}
return [this.prompt];
}

/**
* Resets all states for this.conversation.
*/
Expand All @@ -123,6 +166,8 @@ export class Conversation {
this.override_system_message = undefined;
this.function_string = "";
this.use_function_calling = false;
this.isTextCompletion = false;
this.prompt = undefined;
}

getStopStr(): string[] {
Expand All @@ -137,6 +182,9 @@ export class Conversation {
}

appendMessage(role: Role, message: string, role_name?: string) {
if (this.isTextCompletion) {
throw new TextCompletionConversationError("appendMessage");
}
if (
this.messages.length != 0 &&
this.messages[this.messages.length - 1][2] == undefined
Expand All @@ -151,13 +199,19 @@ export class Conversation {
}

appendReplyHeader(role: Role) {
  // Chat-style message building is invalid in text-completion mode.
  if (this.isTextCompletion) {
    throw new TextCompletionConversationError("appendReplyHeader");
  }
  const roleIsKnown = role in this.config.roles;
  if (!roleIsKnown) {
    throw Error("Role is not supported: " + role);
  }
  // Open a reply slot: message content stays undefined until finishReply().
  this.messages.push([role, this.config.roles[role], undefined]);
}

finishReply(message: string) {
if (this.isTextCompletion) {
throw new TextCompletionConversationError("finishReply");
}
if (this.messages.length == 0) {
throw Error("Message error should not be 0");
}
Expand All @@ -171,9 +225,13 @@ export class Conversation {
export function getConversation(
  conv_template: ConvTemplateConfig,
  conv_config?: Partial<ConvTemplateConfig>,
  isTextCompletion = false,
): Conversation {
  // Overlay any per-request overrides from conv_config onto the template.
  const mergedTemplate = { ...conv_template, ...conv_config };
  return new Conversation(mergedTemplate, isTextCompletion);
}

/**
Expand All @@ -194,7 +252,8 @@ export function compareConversationObject(
convA.function_string !== convB.function_string ||
convA.use_function_calling !== convB.use_function_calling ||
convA.override_system_message !== convB.override_system_message ||
convA.messages.length !== convB.messages.length
convA.messages.length !== convB.messages.length ||
convA.isTextCompletion !== convB.isTextCompletion
) {
return false;
}
Expand All @@ -216,3 +275,117 @@ export function compareConversationObject(
}
return true;
}

/**
 * Get a new Conversation object based on the chat completion request.
 *
 * @param request The incoming ChatCompletionRequest
 * @param config The chat config supplying the conversation template.
 * @note `request.messages[-1]` is not included as it would be treated as a normal input to
 * `prefill()`.
 * @throws MessageOrderError when the final message is not a string from `user`/`tool`;
 *   SystemMessageOrderError when a system message is not first;
 *   ContentTypeError for non-string user/assistant content;
 *   UnsupportedRoleError for unrecognized roles.
 */
export function getConversationFromChatCompletionRequest(
  request: ChatCompletionRequest,
  config: ChatConfig,
): Conversation {
  // 0. Instantiate a new Conversation object
  const conversation = getConversation(
    config.conv_template,
    config.conv_config,
  );

  // 1. Populate function-calling-related fields
  // TODO: either remove these or support gorilla-like function calling models.
  // These commented code was used to support gorilla, but we could not use grammar to
  // guarantee its output, nor make it conform to OpenAI's function calling output. Kept for now.
  // const functionCallUsage = this.getFunctionCallUsage(request);
  // conversation.function_string = functionCallUsage;
  // conversation.use_function_calling = functionCallUsage !== "";

  // 2. Populate conversation.messages
  const allMessages = request.messages;
  const lastMessage = allMessages[allMessages.length - 1];
  if (
    (lastMessage.role !== "user" && lastMessage.role !== "tool") ||
    typeof lastMessage.content !== "string"
  ) {
    // TODO(Charlie): modify condition after we support multimodal inputs
    throw new MessageOrderError(
      "The last message should be a string from the `user` or `tool`.",
    );
  }
  // Replay every message except the last (the last is fed to prefill() directly).
  const history: ChatCompletionMessageParam[] = allMessages.slice(0, -1);
  for (const [i, message] of history.entries()) {
    if (message.role === "system") {
      if (i !== 0) {
        throw new SystemMessageOrderError();
      }
      conversation.override_system_message = message.content;
    } else if (message.role === "user") {
      if (typeof message.content !== "string") {
        // TODO(Charlie): modify condition after we support multimodal inputs
        throw new ContentTypeError(message.role + "'s message");
      }
      conversation.appendMessage(Role.user, message.content, message.name);
    } else if (message.role === "assistant") {
      if (typeof message.content !== "string") {
        throw new ContentTypeError(message.role + "'s message");
      }
      conversation.appendMessage(Role.assistant, message.content, message.name);
    } else if (message.role === "tool") {
      conversation.appendMessage(Role.tool, message.content);
    } else {
      // Use `["role"]` instead of `.role` to suppress "Property does not exist on type 'never'"
      throw new UnsupportedRoleError(message["role"]);
    }
  }
  return conversation;
}

/**
 * Returns the function string based on request.tools and request.tool_choice; raises errors
 * if the request is invalid.
 *
 * @param request The chatCompletionRequest we are about to prefill for.
 * @returns The string used to set Conversation.function_string
 * @throws InvalidToolChoiceError / UnsupportedToolChoiceTypeError /
 *   FunctionNotFoundError / UnsupportedToolTypeError on malformed requests.
 */
export function getFunctionCallUsage(request: ChatCompletionRequest): string {
  const toolChoice = request.tool_choice;
  // No tools at all, or function calling explicitly disabled.
  if (
    request.tools === undefined ||
    (typeof toolChoice === "string" && toolChoice === "none")
  ) {
    return "";
  }
  // The only string value accepted besides "none" is "auto".
  if (typeof toolChoice === "string" && toolChoice !== "auto") {
    throw new InvalidToolChoiceError(toolChoice);
  }
  // Object-form tool_choice must target a function.
  if (typeof toolChoice !== "string" && toolChoice?.type !== "function") {
    throw new UnsupportedToolChoiceTypeError();
  }

  // When tool_choice names a single function, return just that one.
  const requestedFunctionName =
    typeof toolChoice !== "string" && toolChoice?.function?.name;
  if (requestedFunctionName) {
    for (const tool of request.tools) {
      if (requestedFunctionName == tool.function.name) {
        return JSON.stringify([tool.function]);
      }
    }
    throw new FunctionNotFoundError(requestedFunctionName);
  }

  // Otherwise expose every declared function to the model.
  const functionList = [];
  for (const tool of request.tools) {
    if (tool.type !== "function") {
      throw new UnsupportedToolTypeError();
    }
    functionList.push(tool.function);
  }
  return JSON.stringify(functionList);
}
Loading
Loading