speculative: Ensure draft and target model vocab matches #3812
Conversation
examples/speculative/speculative.cpp (Outdated)
```diff
@@ -64,6 +64,26 @@ int main(int argc, char ** argv) {
     params.n_gpu_layers = params.n_gpu_layers_draft;
     std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
 
+    {
+        int n_vocab_tgt = llama_n_vocab(model_tgt);
+        if (n_vocab_tgt != llama_n_vocab(model_dft)) {
```
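The collapsed hunk ends mid-block; presumably the strict check continued roughly like the sketch below (the error handling and message here are assumptions added for context, not the literal patch contents):

```cpp
// Plausible continuation of the strict check above (illustrative only):
// bail out if the draft vocab size differs at all from the target.
{
    int n_vocab_tgt = llama_n_vocab(model_tgt);
    if (n_vocab_tgt != llama_n_vocab(model_dft)) {
        fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation\n", __func__);
        return 1;
    }
}
```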
This is not ideal. Codellama 7B and 13B have vocab size of 32000 while Codellama 34B has vocab size of 32016. It's the same vocab but with some extra tokens.
We should not disallow such cases. Maybe just print errors / warnings, but still continue?
How about we error only if the size differs by more than 100, and also check the content of the first min(n_vocab_tgt, n_vocab_dft) tokens? Maybe even start at 5 or something, to allow for cases where something like BOS or EOS has different content.
100 and 5 are just random numbers I plucked out of the air. The actual values can be whatever you prefer.
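A sketch of what that check could look like, assuming the model-arg token API from #3720 (the helper name, thresholds, and error messages are just illustrative, not final):

```cpp
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include "llama.h"

// Sketch of the proposed check: hard-error only if the vocab sizes differ
// by more than 100 tokens, then compare token contents over the shared
// range, skipping the first 5 entries so that special tokens like
// BOS/EOS are allowed to differ.
static bool vocab_roughly_matches(const llama_model * model_tgt, const llama_model * model_dft) {
    const int n_vocab_tgt = llama_n_vocab(model_tgt);
    const int n_vocab_dft = llama_n_vocab(model_dft);

    if (std::abs(n_vocab_tgt - n_vocab_dft) > 100) {
        fprintf(stderr, "vocab size mismatch: target %d, draft %d\n", n_vocab_tgt, n_vocab_dft);
        return false;
    }

    for (int i = 5; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
        const char * text_tgt = llama_token_get_text(model_tgt, i);
        const char * text_dft = llama_token_get_text(model_dft, i);
        if (std::strcmp(text_tgt, text_dft) != 0) {
            fprintf(stderr, "token %d content differs - target '%s', draft '%s'\n", i, text_tgt, text_dft);
            return false;
        }
    }

    return true;
}
```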
Ok, I think it would work
I changed it. Also made the token content mismatch message a bit more helpful. For example, trying to use Orca 3B to draft Mistral 7B:
```
main: error: draft model vocab must match target model to use speculation but token 259 content differs - target ' ', draft ' t'
```
* master: (350 commits)
  * speculative : ensure draft and target model vocab matches (ggerganov#3812)
  * llama : correctly report GGUFv3 format (ggerganov#3818)
  * simple : fix batch handling (ggerganov#3803)
  * cuda : improve text-generation and batched decoding performance (ggerganov#3776)
  * server : do not release slot on image input (ggerganov#3798)
  * batched-bench : print params at start
  * log : disable pid in log filenames
  * server : add parameter -tb N, --threads-batch N (ggerganov#3584) (ggerganov#3768)
  * server : do not block system prompt update (ggerganov#3767)
  * sync : ggml (conv ops + cuda MSVC fixes) (ggerganov#3765)
  * cmake : add missed dependencies (ggerganov#3763)
  * cuda : add batched cuBLAS GEMM for faster attention (ggerganov#3749)
  * Add more tokenizer tests (ggerganov#3742)
  * metal : handle ggml_scale for n%4 != 0 (close ggerganov#3754)
  * Revert "make : add optional CUDA_NATIVE_ARCH (ggerganov#2482)"
  * issues : separate bug and enhancement template + no default title (ggerganov#3748)
  * Update special token handling in conversion scripts for gpt2 derived tokenizers (ggerganov#3746)
  * llama : remove token functions with `context` args in favor of `model` (ggerganov#3720)
  * Fix baichuan convert script not detecing model (ggerganov#3739)
  * make : add optional CUDA_NATIVE_ARCH (ggerganov#2482)
  * ...
speculative : Ensure draft and target model vocab matches (#3812)

* speculative: Ensure draft and target model vocab matches
* Tolerate small differences when checking dft vs tgt vocab
It's currently possible to shoot yourself in the foot by trying to speculate using a draft model with vocab that doesn't match the target, and weird stuff will happen in that case. Naturally the draft model will fail 100% of the time, but looking at the logs it'll appear that the draft is just generating random unrelated stuff (even draft candidates with NaN as the probability).
When there's a mismatch you'll now get an error reporting either the vocab size difference or the first token whose content differs (like the example shown above).
This approach may be too strict, since one might plausibly want to use a draft model with a few special tokens that differ. One way to deal with that might be to say there can be at most X mismatches. Running `strcmp` on every entry might also be overkill: on my system, checking 32,000 entries is instant, but something like looping with a step of 10 would probably be fine too (see the sketch below).
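For example, a hypothetical combination of both relaxations, tolerating up to `max_mismatches` differing tokens and only sampling every 10th entry (the helper name, parameter, and stride are illustrative, not an actual proposal from the patch):

```cpp
#include <algorithm>
#include <cstring>

#include "llama.h"

// Hypothetical relaxed variant: allow up to max_mismatches differing
// tokens, and only strcmp every 10th entry instead of the full vocab.
static bool vocab_close_enough(const llama_model * tgt, const llama_model * dft, int max_mismatches) {
    const int n = std::min(llama_n_vocab(tgt), llama_n_vocab(dft));

    int mismatches = 0;
    for (int i = 0; i < n; i += 10) { // stride of 10, not every entry
        if (std::strcmp(llama_token_get_text(tgt, i), llama_token_get_text(dft, i)) != 0) {
            if (++mismatches > max_mismatches) {
                return false;
            }
        }
    }
    return true;
}
```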