Grammar sampler implementation causes non-trivial token speed degradation #3980
I can confirm a huge slowdown.
So overall, the initial grammar implementation was not optimized in any way in particular, so there are certainly optimization opportunities. I also know of certain pathological grammars that seem to cause large slowdowns or hanging; I haven't looked into that yet. I will say though that:
With the assumption that sampler code runs on the CPU rather than the GPU, I imagine that an M2 Mac does better than an i7 CPU from 2016, lol.
Something like this should help a bit, no?

```cpp
struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };

    // create a map of old pointers to new pointers
    std::unordered_map<const llama_grammar_element *, const llama_grammar_element *> pointer_map;
    for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
        for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
            pointer_map[&grammar->rules[ir0][ir1]] = &result->rules[ir0][ir1];
        }
    }

    // redirect elements in stacks to point to new rules
    for (size_t is = 0; is < result->stacks.size(); is++) {
        for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
            result->stacks[is][ie] = pointer_map[grammar->stacks[is][ie]];
        }
    }

    return result;
}
```
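Assuming the map is built once per copy, this replaces the nested pointer search with O(1) average hash lookups, so the redirection pass becomes roughly linear in the total number of rule elements instead of quadratic.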
This issue was closed because it has been inactive for 14 days since being marked as stale.
The copy function of the Grammar sampler specifically is O(n^2) in time complexity.
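For context, the quadratic cost comes from how the stacks of the fresh copy are redirected into the copied rules. A minimal sketch of that pattern (a simplification of the upstream routine, not its exact code) looks like this:

```cpp
// For every pointer held in the copied stacks, linearly scan every
// element of every rule until the matching address in the old grammar
// is found -- O(stack elements * rule elements) overall.
for (size_t is = 0; is < result->stacks.size(); is++) {
    for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
        for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
            for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
                if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
                    result->stacks[is][ie] = &result->rules[ir0][ir1];
                }
            }
        }
    }
}
```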
On a 13B 4_K_M model, with all layers fully offloaded to my GPU (RTX 3060, 12 GB VRAM), I normally get a token speed of ~16 T/s. This degrades to ~10 T/s with grammar sampling on, regardless of the complexity of the grammar being used.
I'm not sure if the sampler code is being threaded at the moment, or if that would help, but hopefully the grammar implementation could be refactored in some way to accommodate this.
I'm also not sure if it's running through the entire list of 32,000 logits. If this time complexity is inherently necessary, maybe it would be smart to run the grammar sampler only after truncation samplers (Top K, Min P, ...), as sketched below?
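A rough sketch of that ordering, assuming the llama_sample_top_k / llama_sample_min_p / llama_sample_grammar entry points from the C API of this era (treat the exact signatures, the parameter values, and the candidates_data vector as assumptions):

```cpp
// Hypothetical sampler chain: truncate the candidate list first so the
// grammar sampler only has to validate a handful of tokens instead of
// the full 32,000-entry vocabulary.
llama_token_data_array candidates = { candidates_data.data(), candidates_data.size(), false };

llama_sample_top_k  (ctx, &candidates, 40, 1);    // keep only the top 40 candidates
llama_sample_min_p  (ctx, &candidates, 0.05f, 1); // Min P truncation on the survivors
llama_sample_grammar(ctx, &candidates, grammar);  // grammar check over what's left
llama_token id = llama_sample_token(ctx, &candidates);
```

One caveat with this ordering: the grammar sampler can reject every token that survived truncation, so a real implementation would need a fallback (e.g. rerunning the grammar check against the untruncated list) when that happens.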