Skip to content

Commit

Permalink
[Grammar] Integrate with XGrammar (#635)
Browse files Browse the repository at this point in the history
This PR integrates with XGrammar: https://github.com/mlc-ai/xgrammar.

Prior to this PR, grammar is supported by the grammar portion of MLC-LLM
compiled into the model WASM. That portion is now a standalone project
XGrammar. Therefore, this PR adds `mlc-ai/web-xgrammar` as part of the
dependency and remove `src/grammar.ts`. We update `llm_chat.ts`
accordingly for xgrammar's APIs.

In addition, besides `json_schema`, we now also support requests with
EBNF-formatted strings by using the following in the chat completion
request. See `examples/json-schema`'s `ebnfGrammarExample()` for a full
example.

```typescript
    response_format: {
      type: "grammar",
      grammar: jsonGrammarStr,
    } as webllm.ResponseFormat,
```

We also add the following performance info:
- Add `grammar_init_ms` and `grammar_per_token_ms` to
`CompletionUsage.extra` when using grammar
- Add `time_to_first_token_s` (TTFT) and `time_per_output_token_s`
(TPOT), `e2e_latency_s` to `CompletionUsage.extra`

We also add `ignore_eos` to `Completion` and `ChatCompletion` requests,
which can be useful for benchmarking purposes.
  • Loading branch information
CharlieFRuan authored Nov 22, 2024
1 parent 6504047 commit c6b1b4e
Show file tree
Hide file tree
Showing 13 changed files with 467 additions and 301 deletions.
5 changes: 4 additions & 1 deletion examples/json-mode/src/json_mode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ async function main() {
const initProgressCallback = (report: webllm.InitProgressReport) => {
setLabel("init-label", report.text);
};
const selectedModel = "Llama-3.1-8B-Instruct-q4f32_1-MLC";
// Pick any one of these models to start trying -- most models in WebLLM support grammar
const selectedModel = "Llama-3.2-3B-Instruct-q4f16_1-MLC";
// const selectedModel = "Qwen2.5-1.5B-Instruct-q4f16_1-MLC";
// const selectedModel = "Phi-3.5-mini-instruct-q4f16_1-MLC";
const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
selectedModel,
{ initProgressCallback: initProgressCallback },
Expand Down
77 changes: 71 additions & 6 deletions examples/json-schema/src/json_schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,14 @@ async function simpleStructuredTextExample() {
const initProgressCallback = (report: webllm.InitProgressReport) => {
setLabel("init-label", report.text);
};

// Pick any one of these models to start trying -- most models in WebLLM support grammar
// const selectedModel = "Llama-3.2-3B-Instruct-q4f16_1-MLC";
// const selectedModel = "Qwen2.5-1.5B-Instruct-q4f16_1-MLC";
const selectedModel = "Phi-3.5-mini-instruct-q4f16_1-MLC";
const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
"Llama-3.1-8B-Instruct-q4f16_1-MLC",
{ initProgressCallback: initProgressCallback },
selectedModel,
{ initProgressCallback: initProgressCallback, logLevel: "INFO" },
);

// Note that you'd need to prompt the model to answer in JSON either in
Expand Down Expand Up @@ -106,9 +111,14 @@ async function harryPotterExample() {
setLabel("init-label", report.text);
};

// Pick any one of these models to start trying -- most models in WebLLM support grammar
const selectedModel = "Llama-3.2-3B-Instruct-q4f16_1-MLC";
// const selectedModel = "Qwen2.5-1.5B-Instruct-q4f16_1-MLC";
// const selectedModel = "Phi-3.5-mini-instruct-q4f16_1-MLC";

const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
"Llama-3.1-8B-Instruct-q4f16_1-MLC",
{ initProgressCallback: initProgressCallback },
selectedModel,
{ initProgressCallback: initProgressCallback, logLevel: "INFO" },
);

// Note that you'd need to prompt the model to answer in JSON either in
Expand All @@ -134,6 +144,7 @@ async function harryPotterExample() {
console.log(reply);
console.log("Output:\n" + (await engine.getMessage()));
console.log(reply.usage);
console.log(reply.usage!.extra);
}

async function functionCallingExample() {
Expand Down Expand Up @@ -214,10 +225,64 @@ async function functionCallingExample() {
console.log(reply.usage);
}

async function ebnfGrammarExample() {
// You can directly define an EBNFGrammar string with ResponseFormat.grammar
const jsonGrammarStr = String.raw`
root ::= basic_array | basic_object
basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object
basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"?
basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
basic_string ::= (([\"] basic_string_1 [\"]))
basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
basic_boolean ::= "true" | "false"
basic_null ::= "null"
basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
ws ::= [ \n\t]*
`;

const initProgressCallback = (report: webllm.InitProgressReport) => {
setLabel("init-label", report.text);
};

// Pick any one of these models to start trying -- most models in WebLLM support grammar
const selectedModel = "Llama-3.2-3B-Instruct-q4f16_1-MLC";
// const selectedModel = "Qwen2.5-1.5B-Instruct-q4f16_1-MLC";
// const selectedModel = "Phi-3.5-mini-instruct-q4f16_1-MLC";
const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
selectedModel,
{ initProgressCallback: initProgressCallback, logLevel: "INFO" },
);

// Note that you'd need to prompt the model to answer in JSON either in
// user's message or the system prompt
const request: webllm.ChatCompletionRequest = {
stream: false, // works with streaming, logprobs, top_logprobs as well
messages: [
{
role: "user",
content: "Introduce yourself in JSON",
},
],
max_tokens: 128,
response_format: {
type: "grammar",
grammar: jsonGrammarStr,
} as webllm.ResponseFormat,
};

const reply0 = await engine.chatCompletion(request);
console.log(reply0);
console.log("Output:\n" + (await engine.getMessage()));
console.log(reply0.usage);
}

async function main() {
// await simpleStructuredTextExample();
// await harryPotterExample();
await functionCallingExample();
await harryPotterExample();
// await functionCallingExample();
// await ebnfGrammarExample();
}

main();
28 changes: 28 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"ts-jest": "^29.1.2",
"tslib": "^2.3.1",
"@mlc-ai/web-runtime": "0.18.0-dev2",
"@mlc-ai/web-xgrammar": "../xgrammar/web",
"typescript": "^4.9.5"
},
"dependencies": {
Expand Down
1 change: 1 addition & 0 deletions src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ export interface MLCEngineConfig {
export interface GenerationConfig {
// Only used in MLC
repetition_penalty?: number;
ignore_eos?: boolean;
// Shared by MLC and OpenAI APIs
top_p?: number | null;
temperature?: number | null;
Expand Down
10 changes: 6 additions & 4 deletions src/conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,12 @@ export class Conversation {
}

getStopStr(): string[] {
if (this.config.stop_str.length > 0) {
return this.config.stop_str;
}
return [this.config.seps[this.config.seps.length - 1]];
// TODO(Charlie): Is this needed?
// if (this.config.stop_str.length > 0) {
// return this.config.stop_str;
// }
// return [this.config.seps[this.config.seps.length - 1]];
return this.config.stop_str;
}

getStopTokens() {
Expand Down
72 changes: 63 additions & 9 deletions src/engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -465,20 +465,23 @@ export class MLCEngine implements MLCEngineInterface {
pipeline: LLMChatPipeline,
chatConfig: ChatConfig,
genConfig: GenerationConfig,
timeReceived: number,
): AsyncGenerator<ChatCompletionChunk, void, void>;
asyncGenerate(
request: CompletionCreateParamsStreaming,
model: string,
pipeline: LLMChatPipeline,
chatConfig: ChatConfig,
genConfig: GenerationConfig,
timeReceived: number,
): AsyncGenerator<Completion, void, void>;
async *asyncGenerate(
request: ChatCompletionRequestStreaming | CompletionCreateParamsStreaming,
model: string,
pipeline: LLMChatPipeline,
chatConfig: ChatConfig,
genConfig: GenerationConfig,
timeReceived: number,
): AsyncGenerator<ChatCompletionChunk | Completion, void, void> {
// Since it is an async generator, we need to do fine-grained try-catch to ensure lock is
// released only when errors occur. Then release at the very end when no error occurs.
Expand Down Expand Up @@ -678,18 +681,39 @@ export class MLCEngine implements MLCEngineInterface {

// 4. Usage chunk
if (request.stream_options?.include_usage) {
const usedGrammar =
"response_format" in request &&
(request.response_format?.type === "grammar" ||
request.response_format?.type === "json_object");
const completion_tokens = pipeline.getCurRoundDecodingTotalTokens();
const prompt_tokens = pipeline.getCurRoundPrefillTotalTokens();
const prefill_tokens_per_s = pipeline.getCurRoundPrefillTokensPerSec();
const decode_tokens_per_s = pipeline.getCurRoundDecodingTokensPerSec();
const grammar_init_s = pipeline.getCurRoundGrammarInitTotalTime();
const prefill_time = pipeline.getCurRoundPrefillTotalTime();
const decode_time = pipeline.getCurRoundDecodingTotalTime();
const grammar_per_token_s =
pipeline.getCurRoundGrammarPerTokenTotalTime();
const defaultExtra = {
e2e_latency_s: (Date.now() - timeReceived) / 1000,
prefill_tokens_per_s: prefill_tokens_per_s,
decode_tokens_per_s: decode_tokens_per_s,
time_to_first_token_s: prefill_time,
time_per_output_token_s: decode_time / completion_tokens,
};
const usage: CompletionUsage = {
completion_tokens: completion_tokens,
prompt_tokens: prompt_tokens,
total_tokens: completion_tokens + prompt_tokens,
extra: {
prefill_tokens_per_s: prefill_tokens_per_s,
decode_tokens_per_s: decode_tokens_per_s,
},
extra: usedGrammar
? {
...defaultExtra,
...{
grammar_init_s: grammar_init_s,
grammar_per_token_s: grammar_per_token_s / completion_tokens,
},
}
: defaultExtra,
};
if (isChatCompletion) {
const usageChunk: ChatCompletionChunk = {
Expand Down Expand Up @@ -745,6 +769,7 @@ export class MLCEngine implements MLCEngineInterface {
async chatCompletion(
request: ChatCompletionRequest,
): Promise<AsyncIterable<ChatCompletionChunk> | ChatCompletion> {
const timeReceived = Date.now();
// 0. Check model loaded and preprocess inputs
const [selectedModelId, selectedPipeline, selectedChatConfig] =
this.getLLMStates("ChatCompletionRequest", request.model);
Expand All @@ -766,6 +791,7 @@ export class MLCEngine implements MLCEngineInterface {
logprobs: request.logprobs,
top_logprobs: request.top_logprobs,
response_format: request.response_format,
ignore_eos: request.ignore_eos,
};

// 0.5 Block wait until this pipeline finishes all previous requests
Expand All @@ -780,6 +806,7 @@ export class MLCEngine implements MLCEngineInterface {
selectedPipeline,
selectedChatConfig,
genConfig,
timeReceived,
);
}

Expand All @@ -796,6 +823,8 @@ export class MLCEngine implements MLCEngineInterface {
let prompt_tokens = 0;
let prefill_time = 0;
let decode_time = 0;
let grammar_init_s = 0;
let grammar_per_token_s = 0;
for (let i = 0; i < n; i++) {
let outputMessage: string;
if (this.interruptSignal) {
Expand Down Expand Up @@ -852,8 +881,21 @@ export class MLCEngine implements MLCEngineInterface {
prompt_tokens += selectedPipeline.getCurRoundPrefillTotalTokens();
prefill_time += selectedPipeline.getCurRoundPrefillTotalTime();
decode_time += selectedPipeline.getCurRoundDecodingTotalTime();
grammar_init_s += selectedPipeline.getCurRoundGrammarInitTotalTime();
grammar_per_token_s +=
selectedPipeline.getCurRoundGrammarPerTokenTotalTime();
}

const usedGrammar =
"response_format" in request &&
(request.response_format?.type === "grammar" ||
request.response_format?.type === "json_object");
const defaultExtra = {
e2e_latency_s: (Date.now() - timeReceived) / 1000,
prefill_tokens_per_s: prompt_tokens / prefill_time,
decode_tokens_per_s: completion_tokens / decode_time,
time_to_first_token_s: prefill_time,
time_per_output_token_s: decode_time / completion_tokens,
};
const response: ChatCompletion = {
id: crypto.randomUUID(),
choices: choices,
Expand All @@ -864,10 +906,15 @@ export class MLCEngine implements MLCEngineInterface {
completion_tokens: completion_tokens,
prompt_tokens: prompt_tokens,
total_tokens: completion_tokens + prompt_tokens,
extra: {
prefill_tokens_per_s: prompt_tokens / prefill_time,
decode_tokens_per_s: completion_tokens / decode_time,
},
extra: usedGrammar
? {
...defaultExtra,
...{
grammar_init_s: grammar_init_s,
grammar_per_token_s: grammar_per_token_s / completion_tokens,
},
}
: defaultExtra,
} as CompletionUsage,
};

Expand Down Expand Up @@ -901,6 +948,8 @@ export class MLCEngine implements MLCEngineInterface {
async completion(
request: CompletionCreateParams,
): Promise<AsyncIterable<Completion> | Completion> {
const timeReceived = Date.now();

// 0. Check model loaded and preprocess inputs
const [selectedModelId, selectedPipeline, selectedChatConfig] =
this.getLLMStates("CompletionCreateParams", request.model);
Expand All @@ -915,6 +964,7 @@ export class MLCEngine implements MLCEngineInterface {
logit_bias: request.logit_bias,
logprobs: request.logprobs,
top_logprobs: request.top_logprobs,
ignore_eos: request.ignore_eos,
};

// 0.5 Block wait until this pipeline finishes all previous requests
Expand All @@ -929,6 +979,7 @@ export class MLCEngine implements MLCEngineInterface {
selectedPipeline,
selectedChatConfig,
genConfig,
timeReceived,
);
}

Expand Down Expand Up @@ -989,8 +1040,11 @@ export class MLCEngine implements MLCEngineInterface {
prompt_tokens: prompt_tokens,
total_tokens: completion_tokens + prompt_tokens,
extra: {
e2e_latency_s: (Date.now() - timeReceived) / 1000,
prefill_tokens_per_s: prompt_tokens / prefill_time,
decode_tokens_per_s: completion_tokens / decode_time,
time_to_first_token_s: prefill_time,
time_per_output_token_s: decode_time / completion_tokens,
},
} as CompletionUsage,
};
Expand Down
Loading

0 comments on commit c6b1b4e

Please sign in to comment.