llama : add grammar-based sampling #1773
Conversation
Suggest taking a file as the grammar parameter and putting several examples in the repo, like what we did for prompts (in the .\prompts folder).
Incredibly useful contribution. It's really amazing how much this simplifies many use cases. I agree that it would be better if the grammar came from a file. Two snags I hit while trying this out:
Some additional thoughts:
Very nice! I am wondering what the rationale is for not including the parser in the llama.cpp API. Without it, most downstream users will be forced to manually make a copy of the parser in their code to support the feature, which is not great.
Thanks all! Just added support for grammar files (with examples) and updated the grammar syntax to add shell-style comments and allow empty lines between rules, as well as newlines inside parenthesized groups.
I wonder if that was #1699? If so, it should be fixed now.
Sounds cool, I don't know enough about GPU programming to comment on that myself. The grammar participates in the sampling layer, and I'm not sure if that leverages the GPU currently.
This is definitely possible. That said, AFAIK the token would still need to be evaluated, and that seems to be the bottleneck. Maybe the optimization comes in being able to batch eval strings of such tokens?
Neat idea. Would that be more of an optimization or to reference tokens that can't be expressed textually?
Honestly, I was trying to reduce the changes to
Makes sense. I left that out of this round of changes - if it's desired to have the grammar parser in the llama API, this may naturally fit with that change.
First, this is amazing work. This makes me wonder whether the entire sampling API should be pulled into something like a separate interface; the existing functions can be kept for compatibility. This reminds me a lot of: https://lmql.ai/.
Would it though?
AFAIK, that's the case.
The current sampling is CPU.
One of the discussion points for adding more LLM-generic tooling back into the ggml repo was moving the sampler there, but AFAIK nothing has happened yet :)
Was planning to tackle this next. I've got it more or less working locally in a branch off of this, at least with the examples on jsonformer's README. It uses a Python script to generate a JSON BNF that conforms to the schema.
@@ -263,6 +289,9 @@ extern "C" {
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);

/// @details Apply constraints from grammar
LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
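To show where these calls sit in a per-token sampling step, here is a minimal sketch. llama_sample_grammar is taken from the diff above; the other signatures (llama_grammar_accept_token, llama_sample_token_greedy, llama_n_vocab, llama_get_logits) are my assumptions about the surrounding API of that era, not quoted from this diff.

```cpp
#include <vector>
#include "llama.h"

// Sketch: constrain one sampling step with a grammar (assumed signatures).
static llama_token sample_with_grammar(llama_context * ctx, llama_grammar * grammar) {
    const int n_vocab = llama_n_vocab(ctx);
    float * logits    = llama_get_logits(ctx);

    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token id = 0; id < n_vocab; id++) {
        candidates.push_back({ id, logits[id], 0.0f });
    }
    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

    // 1) reject candidates that cannot continue any of the grammar's parse stacks
    llama_sample_grammar(ctx, &candidates_p, grammar);

    // 2) pick a token from what remains (greedy here for brevity)
    const llama_token tok = llama_sample_token_greedy(ctx, &candidates_p);

    // 3) advance the grammar state with the chosen token
    llama_grammar_accept_token(ctx, grammar, tok);
    return tok;
}
```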
Can we make llama_grammar a structure with two callbacks, so that other implementations of it can support a context-aware state machine instead?
Do you mean like, the caller would provide the implementation of llama_grammar (via callbacks), from which the llama API determines which tokens are valid?
Yes, so the llama code will not assume a particular grammar implementation.
Yeah, I'm open to that idea, assuming the grammar interface itself generalizes well to other implementations. I kind of designed this with the specific implementation in mind so that's not a guarantee.
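For concreteness, here is a hypothetical shape such a callback-based interface could take. This is purely illustrative of the suggestion above (the struct and field names are made up), not something this PR implements.

```cpp
#include "llama.h"

// Hypothetical callback-based grammar interface (illustrative only):
// llama would query these callbacks instead of assuming one grammar implementation.
struct llama_grammar_callbacks {
    void * user_data;

    // return true if `token` is a valid continuation of the current state
    bool (*accepts_token)(void * user_data, llama_token token);

    // commit `token` and advance the implementation's internal state
    void (*accept_token)(void * user_data, llama_token token);
};
```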
Great stuff! I'm still wrapping my head around this.
To clarify, this PR adds the core sampling functionality in
Yes, that's correct. I think that's doable; I can take a stab at that.
EDIT: I read through the code. The grammar doesn't care about the text encoding. It could work with any encoding, provided that the rules match the characters correctly. The parser doesn't understand UTF-8 so it will create rules that don't match as the user expects. For example, if I wanted to create a rule to match all Hiragana characters, I should be able to write:
However, the parser doesn't see it as two characters separated by
But the correct rule should be something like this?
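To make the mismatch concrete: each Hiragana code point occupies three bytes in UTF-8, so a parser that treats the pattern as raw bytes sees those bytes individually rather than as one character. A tiny standalone illustration (not code from this PR):

```cpp
#include <cstdio>

int main() {
    // "ぁ" is U+3041, the first code point in the Hiragana block;
    // in UTF-8 it is encoded as the three bytes E3 81 81.
    const char * a = "ぁ";
    for (const unsigned char * p = (const unsigned char *) a; *p != 0; ++p) {
        std::printf("%02X ", *p);  // prints: E3 81 81
    }
    std::printf("\n");
    return 0;
}
```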
Just don't use repeat penalties, to get the best grammar output that llama can produce.
It's fine the way it is.
FWIW I'm adapting this code into an analogous feature for models running on torch. In my implementation, I'm doing grammar enforcement logit masking on the GPU across the full token set before selecting candidates: https://github.com/Shopify/torch-grammar/blob/df23e354083c909c70120e256ed34036c93f6714/grammar_sampler.py#L232-L239. The same strategy would probably work here if anyone was super motivated to try it.
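For anyone skimming, the general idea of logit masking (a generic sketch, not the linked torch implementation) is to push the logit of every grammar-rejected token to negative infinity before sampling, so the whole vocabulary can be handled in one pass:

```cpp
#include <cmath>
#include <vector>

// Generic sketch of logit masking: any token the grammar rejects gets a logit
// of -infinity, so its probability is zero after softmax and it is never sampled.
void mask_rejected_logits(std::vector<float> & logits, const std::vector<bool> & allowed) {
    for (size_t i = 0; i < logits.size(); ++i) {
        if (!allowed[i]) {
            logits[i] = -INFINITY;
        }
    }
}
```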
Adding a similar grammar-based sampling to whisper.cpp would be great.
Can we improve the result by fine-tuning the model? If so, is there an example of how to improve it?
@ggerganov agreed! Although I'm not sure when or if I'll be able to contribute that. @x4080 this approach is independent of the model variant and can be used with a fine-tune. In the comments above there's a demonstration of using WizardLM, for example.
@ejones thanks
What is the generation speed on this? How does it compare to unconstrained generation when using cpp? I am trying to replicate this in Python using an A100 on an 8-bit quantized llama 7b, and it is extremely slow per token (compared to unconstrained generation) due to all the extra encoding and decoding that needs to happen.
I've only done CPU inference, but the performance impact has been insignificant for everything I've tried. On the M2 Max I'm seeing about 0.5 ms/token sampling for unconstrained vs about 6 ms with a grammar, with token eval taking about 70 ms for 13B (Q4_K). I'm not sure I know enough about GPU programming to meaningfully comment; I know that other folks are working on approaches that are more generic or are GPU- and/or Python-focused. There are some examples of this upthread: https://github.com/Shopify/torch-grammar and https://github.com/normal-computing/outlines.
@ggerganov alright, I've done it: ggerganov/whisper.cpp#1229 |
Is it a natural side effect of the grammar sampling method that I am seeing a significant degradation in tokens per second speed during generation, even with permissive sampling rules? |
It tells you the time spent on sampling at the end. You can confirm it there.
Yeah, I generally see about 5 ms/token overhead for grammars on the M2 Max, which is usually a fraction of the per-token eval time. But recently I was testing with a grammar and saw a more significant impact. I should investigate; there may be some pathological cases.
(A cross-referenced commit in another repository applied LocalAI's constrained-grammars feature (https://localai.io/features/constrained_grammars/), which passes a BNF-defined grammar to the chat completions endpoint and cites this PR along with #1887. The full commit message, a lengthy AI-assistant transcript that duplicates the LocalAI docs page, is omitted here.)
EDITED after updates
Inspired by #1397 and grantslatton's CFG work, this adds an API that takes a serialized context-free grammar to guide and constrain sampling. Also adds a sample Backus-Naur form (BNF)-like syntax in main for specifying a grammar for generations.

Testing
(M2 Max, 30B)
Chess
"Chess" without grammar
Arithmetic
Arithmetic - no grammar
JSON
"JSON" - no grammar
Japanese
Japanese - no grammar
Approach
Grammar API
The llama API accepts a data structure representing a context-free grammar over 32-bit code points:
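The structure itself is not reproduced in this excerpt; the sketch below shows the shape I take it to have (a flat array of typed elements per rule). The exact enum names and values are my reading of the change and may not match llama.h verbatim.

```cpp
#include <stdint.h>

// Sketch of the grammar data structure (names/values per my reading of the PR;
// treat them as illustrative). Each rule is a sequence of typed elements.
enum llama_gretype {
    LLAMA_GRETYPE_END            = 0,  // end of a rule definition
    LLAMA_GRETYPE_ALT            = 1,  // start of an alternate definition for a rule
    LLAMA_GRETYPE_RULE_REF       = 2,  // non-terminal: reference to another rule
    LLAMA_GRETYPE_CHAR           = 3,  // terminal: match a code point
    LLAMA_GRETYPE_CHAR_NOT       = 4,  // inverse character set ([^a], [^a-b])
    LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,  // upper bound of a character range ([a-z])
    LLAMA_GRETYPE_CHAR_ALT       = 6,  // additional character in a set ([ab])
};

typedef struct llama_grammar_element {
    enum llama_gretype type;
    uint32_t           value;  // code point or rule id, depending on type
} llama_grammar_element;
```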
Sampling
The grammar sampling code models a nondeterministic pushdown automaton, maintaining N stacks for the possible parse states. Sampling a token is done in two steps: a sampling API that filters candidates to those that match one of the parse stacks (llama_sample_grammar) and adding the chosen token to the grammar (llama_grammar_accept_token).

Examples
Adds --grammar and --grammar-file arguments to main, taking a simple extended BNF to constrain generations. The parser for this format is implemented in examples/grammar-parser.{h,cpp} (see the usage sketch after the caveats below). The root rule identifies the start of the grammar.

Caveats
- the binary format makes the code harder to understand and more brittle
- the grammar contemplates 16-bit chars but it's just being applied to the 8-bit UTF-8 chars in token strings currently
- the 1-char lookahead sampling is probably biasing generations in a weird way; further investigation on quality of outputs is probably needed
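As a rough sketch of how the example parser and the new API fit together: the grammar_parser entry points shown here (parse, c_rules, symbol_ids) and the include path are from my recollection of the example code, so treat them as assumptions rather than exact quotes from this PR.

```cpp
#include <vector>
#include "llama.h"
#include "grammar-parser.h"  // header path assumed; adjust to your checkout

// Sketch: parse a grammar in the extended-BNF format and build a llama_grammar.
static llama_grammar * load_grammar(const char * grammar_text) {
    grammar_parser::parse_state parsed = grammar_parser::parse(grammar_text);
    if (parsed.rules.empty()) {
        return nullptr;  // parse failed
    }
    std::vector<const llama_grammar_element *> rules = parsed.c_rules();
    // the "root" rule is the start symbol of the grammar
    return llama_grammar_init(rules.data(), rules.size(), parsed.symbol_ids.at("root"));
}

// Example grammar text: constrain output to "yes" or "no".
// llama_grammar * g = load_grammar("root ::= \"yes\" | \"no\"");
```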