[OpenAI] Support text completion via engine.completions.create() (#534)
`completions.create()`, as opposed to `chat.completions.create()`, is
something we had not supported prior to this PR. Compared to
`chat.completions`, `completions` is pure text completion with no
conversation: given the user's input prompt, the model generates
autoregressively without applying any chat template. For more, see
`examples/text-completion` and
https://platform.openai.com/docs/api-reference/completions/object
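
For a concrete contrast, here is a minimal sketch, assuming an `engine` already created via `CreateMLCEngine()`:

```typescript
// Chat completion: messages are wrapped in the model's chat template.
const chatReply = await engine.chat.completions.create({
  messages: [{ role: "user", content: "Tell me about the Mississippi River." }],
});
console.log(chatReply.choices[0].message.content);

// Text completion: the raw prompt is fed to the model with no template applied.
const textReply = await engine.completions.create({
  prompt: "The Mississippi River is",
});
console.log(textReply.choices[0].text);
```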
CharlieFRuan authored Aug 10, 2024
1 parent 5472977 commit 41e786b
Showing 20 changed files with 1,388 additions and 358 deletions.
1 change: 1 addition & 0 deletions examples/README.md
@@ -23,6 +23,7 @@ Note that all examples below run in-browser and use WebGPU as a backend.
- [get-started-web-worker](get-started-web-worker): same as get-started, but using web worker.
- [next-simple-chat](next-simple-chat): a minimal and complete chatbot app with [Next.js](https://nextjs.org/).
- [multi-round-chat](multi-round-chat): while the APIs are functional, we optimize internally so that multi-round chat usage can reuse the KV cache
- [text-completion](text-completion): demonstrates API `engine.completions.create()`, which is pure text completion with no conversation, as opposed to `engine.chat.completions.create()`

#### Advanced OpenAI API Capabilities

14 changes: 14 additions & 0 deletions examples/text-completion/README.md
@@ -0,0 +1,14 @@
# WebLLM Text Completion Example

This folder provides a minimal demo showing the WebLLM API in a webapp setting.
To try it out, run the following steps under this folder:

```bash
npm install
npm start
```

Note: if you would like to hack the WebLLM core package, you can change the
web-llm dependency to `"file:../.."` and follow the build-from-source
instructions in the project to build WebLLM locally.
20 changes: 20 additions & 0 deletions examples/text-completion/package.json
@@ -0,0 +1,20 @@
{
"name": "text-completion",
"version": "0.1.0",
"private": true,
"scripts": {
"start": "parcel src/text_completion.html --port 8888",
"build": "parcel build src/text_completion.html --dist-dir lib"
},
"devDependencies": {
"buffer": "^5.7.1",
"parcel": "^2.8.3",
"process": "^0.11.10",
"tslib": "^2.3.1",
"typescript": "^4.9.5",
"url": "^0.11.3"
},
"dependencies": {
"@mlc-ai/web-llm": "../.."
}
}
23 changes: 23 additions & 0 deletions examples/text-completion/src/text_completion.html
@@ -0,0 +1,23 @@
<!doctype html>
<html>
<script>
webLLMGlobal = {};
</script>
<body>
<h2>WebLLM Test Page</h2>
Open console to see output
<br />
<br />
<label id="init-label"> </label>

<h3>Prompt</h3>
<label id="prompt-label"> </label>

<h3>Response</h3>
<label id="generate-label"> </label>
<br />
<label id="stats-label"> </label>

<script type="module" src="./text_completion.ts"></script>
</body>
</html>
58 changes: 58 additions & 0 deletions examples/text-completion/src/text_completion.ts
@@ -0,0 +1,58 @@
import * as webllm from "@mlc-ai/web-llm";

function setLabel(id: string, text: string) {
const label = document.getElementById(id);
if (label == null) {
throw Error("Cannot find label " + id);
}
label.innerText = text;
}

async function main() {
const initProgressCallback = (report: webllm.InitProgressReport) => {
setLabel("init-label", report.text);
};

// Unlike "Llama-3.1-8B-Instruct-q4f32_1-MLC", this is a base model
const selectedModel = "Llama-3.1-8B-q4f32_1-MLC";

const appConfig: webllm.AppConfig = {
model_list: [
{
model: "https://huggingface.co/mlc-ai/Llama-3.1-8B-q4f32_1-MLC", // a base model
model_id: selectedModel,
model_lib:
webllm.modelLibURLPrefix +
webllm.modelVersion +
"/Llama-3_1-8B-Instruct-q4f32_1-ctx4k_cs1k-webgpu.wasm",
overrides: {
context_window_size: 2048,
},
},
],
};
const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
selectedModel,
{
appConfig: appConfig,
initProgressCallback: initProgressCallback,
logLevel: "INFO",
},
);

const reply0 = await engine.completions.create({
prompt: "List 3 US states: ",
// below configurations are all optional
echo: true, // include the prompt at the start of each returned choice
n: 2, // generate two completion choices
max_tokens: 64, // cap the number of generated tokens
logprobs: true, // return log probabilities of the sampled tokens
top_logprobs: 2, // also return the top-2 candidate tokens at each position
});
console.log(reply0);
console.log(reply0.usage);
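// Note: the response mirrors OpenAI's Completion object: with `n: 2`,
// `reply0.choices` should contain two entries, and with `echo: true` each
// choice's `text` begins with the prompt itself.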

// To change model, either create a new engine via `CreateMLCEngine()`, or call `engine.reload(modelId)`
}

main();
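
Not shown in the example above: a streaming variant. Below is a minimal sketch, assuming `completions.create()` accepts `stream: true` and yields chunks the same way `chat.completions.create()` does:

```typescript
const stream = await engine.completions.create({
  prompt: "The capital of France is",
  max_tokens: 32,
  stream: true, // assumption: streaming mirrors the chat completion API
});
let generated = "";
for await (const chunk of stream) {
  // Each chunk should carry only the newly generated fragment of text.
  generated += chunk.choices[0]?.text ?? "";
}
console.log(generated);
```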
183 changes: 178 additions & 5 deletions src/conversation.ts
@@ -1,4 +1,25 @@
import { ConvTemplateConfig, MessagePlaceholders, Role } from "./config";
import {
ChatConfig,
ConvTemplateConfig,
MessagePlaceholders,
Role,
} from "./config";
import {
ChatCompletionMessageParam,
ChatCompletionRequest,
} from "./openai_api_protocols/index";
import {
ContentTypeError,
FunctionNotFoundError,
InvalidToolChoiceError,
MessageOrderError,
SystemMessageOrderError,
TextCompletionConversationError,
TextCompletionConversationExpectsPrompt,
UnsupportedRoleError,
UnsupportedToolChoiceTypeError,
UnsupportedToolTypeError,
} from "./error";

/**
* Helper to keep track of history conversations.
@@ -8,15 +29,21 @@ export class Conversation {
public messages: Array<[Role, string, string | undefined]> = [];
readonly config: ConvTemplateConfig;

/** Whether the Conversation object is for text completion with no conversation-style formatting */
public isTextCompletion: boolean;
/** Used when isTextCompletion is true */
public prompt: string | undefined;

public function_string = "";
public use_function_calling = false;
public override_system_message?: string = undefined;

// TODO(tvm-team) confirm and remove
// private contextWindowStart = 0;

constructor(config: ConvTemplateConfig) {
constructor(config: ConvTemplateConfig, isTextCompletion = false) {
this.config = config;
this.isTextCompletion = isTextCompletion;
}

private getPromptArrayInternal(addSystem: boolean, startPos: number) {
@@ -96,6 +123,9 @@ export class Conversation {
* @returns The prompt array.
*/
getPromptArray(): Array<string> {
if (this.isTextCompletion) {
throw new TextCompletionConversationError("getPromptArray");
}
return this.getPromptArrayInternal(true, 0);
}

@@ -107,13 +137,26 @@
*
* @returns The prompt array.
*/
getPrompArrayLastRound() {
getPromptArrayLastRound() {
if (this.isTextCompletion) {
throw new TextCompletionConversationError("getPromptArrayLastRound");
}
if (this.messages.length < 3) {
throw Error("needs to call getPromptArray for the first message");
}
return this.getPromptArrayInternal(false, this.messages.length - 2);
}

/**
* Return prompt in an array for non-conversation text completion.
*/
getPromptArrayTextCompletion(): Array<string> {
if (!this.isTextCompletion || this.prompt === undefined) {
throw new TextCompletionConversationExpectsPrompt();
}
return [this.prompt];
}

/**
* Resets all states for this.conversation.
*/
@@ -123,6 +166,8 @@
this.override_system_message = undefined;
this.function_string = "";
this.use_function_calling = false;
this.isTextCompletion = false;
this.prompt = undefined;
}

getStopStr(): string[] {
@@ -137,6 +182,9 @@
}

appendMessage(role: Role, message: string, role_name?: string) {
if (this.isTextCompletion) {
throw new TextCompletionConversationError("appendMessage");
}
if (
this.messages.length != 0 &&
this.messages[this.messages.length - 1][2] == undefined
@@ -151,13 +199,19 @@
}

appendReplyHeader(role: Role) {
if (this.isTextCompletion) {
throw new TextCompletionConversationError("appendReplyHeader");
}
if (!(role in this.config.roles)) {
throw Error("Role is not supported: " + role);
}
this.messages.push([role, this.config.roles[role], undefined]);
}

finishReply(message: string) {
if (this.isTextCompletion) {
throw new TextCompletionConversationError("finishReply");
}
if (this.messages.length == 0) {
throw Error("Message error should not be 0");
}
@@ -171,9 +225,13 @@
export function getConversation(
conv_template: ConvTemplateConfig,
conv_config?: Partial<ConvTemplateConfig>,
isTextCompletion = false,
): Conversation {
// Update with conv_config
return new Conversation({ ...conv_template, ...conv_config });
return new Conversation(
{ ...conv_template, ...conv_config },
isTextCompletion,
);
}

/**
@@ -194,7 +252,8 @@ export function compareConversationObject(
convA.function_string !== convB.function_string ||
convA.use_function_calling !== convB.use_function_calling ||
convA.override_system_message !== convB.override_system_message ||
convA.messages.length !== convB.messages.length
convA.messages.length !== convB.messages.length ||
convA.isTextCompletion !== convB.isTextCompletion
) {
return false;
}
@@ -216,3 +275,117 @@
}
return true;
}

/**
* Get a new Conversation object based on the chat completion request.
*
* @param request The incoming ChatCompletionRequest
* @note `request.messages[-1]` is not included as it would be treated as a normal input to
* `prefill()`.
*/
export function getConversationFromChatCompletionRequest(
request: ChatCompletionRequest,
config: ChatConfig,
): Conversation {
// 0. Instantiate a new Conversation object
const conversation = getConversation(
config.conv_template,
config.conv_config,
);

// 1. Populate function-calling-related fields
// TODO: either remove these or support gorilla-like function calling models.
// This commented-out code was used to support Gorilla, but we could not use grammar to
// guarantee its output, nor make it conform to OpenAI's function calling output. Kept for now.
// const functionCallUsage = this.getFunctionCallUsage(request);
// conversation.function_string = functionCallUsage;
// conversation.use_function_calling = functionCallUsage !== "";

// 2. Populate conversation.messages
const input = request.messages;
const lastId = input.length - 1;
if (
(input[lastId].role !== "user" && input[lastId].role !== "tool") ||
typeof input[lastId].content !== "string"
) {
// TODO(Charlie): modify condition after we support multimodal inputs
throw new MessageOrderError(
"The last message should be a string from the `user` or `tool`.",
);
}
for (let i = 0; i < input.length - 1; i++) {
const message: ChatCompletionMessageParam = input[i];
if (message.role === "system") {
if (i !== 0) {
throw new SystemMessageOrderError();
}
conversation.override_system_message = message.content;
} else if (message.role === "user") {
if (typeof message.content !== "string") {
// TODO(Charlie): modify condition after we support multimodal inputs
throw new ContentTypeError(message.role + "'s message");
}
conversation.appendMessage(Role.user, message.content, message.name);
} else if (message.role === "assistant") {
if (typeof message.content !== "string") {
throw new ContentTypeError(message.role + "'s message");
}
conversation.appendMessage(Role.assistant, message.content, message.name);
} else if (message.role === "tool") {
conversation.appendMessage(Role.tool, message.content);
} else {
// Use `["role"]` instead of `.role` to suppress "Property does not exist on type 'never'"
throw new UnsupportedRoleError(message["role"]);
}
}
return conversation;
}

/**
* Returns the function string based on request.tools and request.tool_choice; raises errors if
* it encounters an invalid request.
*
* @param request The ChatCompletionRequest we are about to prefill for.
* @returns The string used to set Conversation.function_string
*/
export function getFunctionCallUsage(request: ChatCompletionRequest): string {
if (
request.tools == undefined ||
(typeof request.tool_choice == "string" && request.tool_choice == "none")
) {
return "";
}
if (
typeof request.tool_choice == "string" &&
request.tool_choice !== "auto"
) {
throw new InvalidToolChoiceError(request.tool_choice);
}
if (
typeof request.tool_choice !== "string" &&
request.tool_choice?.type !== "function"
) {
throw new UnsupportedToolChoiceTypeError();
}

const singleFunctionToCall =
typeof request.tool_choice !== "string" &&
request.tool_choice?.function?.name;
if (singleFunctionToCall) {
for (const f of request.tools) {
if (singleFunctionToCall == f.function.name) {
return JSON.stringify([f.function]);
}
}
throw new FunctionNotFoundError(singleFunctionToCall);
}

const function_list = [];
for (const f of request.tools) {
if (f.type !== "function") {
throw new UnsupportedToolTypeError();
}
function_list.push(f.function);
}
return JSON.stringify(function_list);
}
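
As an illustration of the mapping above, consider a hypothetical request (the `getWeather` tool is made up for this sketch, not part of the PR):

```typescript
const request: ChatCompletionRequest = {
  messages: [{ role: "user", content: "What is the weather in Paris?" }],
  tools: [
    {
      type: "function",
      function: { name: "getWeather", parameters: { type: "object" } },
    },
  ],
  tool_choice: { type: "function", function: { name: "getWeather" } },
};
// tool_choice names a single function, so only `getWeather` is serialized;
// with tool_choice "auto" (or omitted), every function in request.tools is.
const functionString = getFunctionCallUsage(request);
```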