-
Notifications
You must be signed in to change notification settings - Fork 9.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Command R Plus support #6491
Conversation
Probably shouldn't do the model_max_length mapping and should instead force it to be in config.json, otherwise these changes worked for me as well |
Do the ggufs produced w/ this work w/ the main branch once they are quantized? |
@N8python i'll let you know as soon as I have a finished one, I would be slightly surprised if they worked without this PR at inference time but i'm not sure how it all works my conversion to f16.gguf is almost done, will be making a Q2 immediately after and seeing if that runs from master |
I think bunch of people are rushing to implement this. I have a slightly more complete code here (https://github.com/Noeda/llama.cpp/tree/commandr-plus I made Q4_K and Q8_0 quants for myself; those seem fine but inference is not. If you want you can pull my code into yours but it doesn't work. I have a bit of limited time and might have to stop hacking until evening or later tomorrow; but I'll try to get it working. I think adding the new layernorms for query and value should be enough; didn't see other differences in the Transformers code. (I'll comment here if I have to get off, so no one waits for me if I have to go. I'm currently hacking and trying to figure out what's going on with my assert failures. This is another of those models that gets lots of excited people out of woods including myself :D to hacking but I don't want people to wait on me because the times I can work are unpredictable and tend to come in bursts and I might have to suddenly disappear) Link for easier reading from my diff: it's not exactly lots of lines of code: master...Noeda:llama.cpp:commandr-plus |
as Noeda suspected this change was not enough to make it work, conversion to f16.gguf "worked" but going to Q2 failed with "gguf_init_from_file: invalid magic characters ''" |
I pinged Cohere on HF and they added model_max_length to the config.json. So no more need to compensate for that oversight in the code. |
Here's the mlx impl: |
FYI converting to fp16 on macOS works with this PR, but quantizing segfaults. ~/git/llama.cpp/convert-hf-to-gguf.py ./CohereForAI_c4ai-command-r-plus --outtype f16 --outfile CohereForAI_c4ai-command-r-plus.fp16.bin
|
Quantizing to Q5_0 works, but the llm_build_norm() function doesn't accept a 2D layer norm for the new q_norm and k_norm parameters. The tensor is 12288 but appears it should be evaluated as 128x96 by the layer norm. |
Your latest push seems to have fixed Q3_K_M and Q4_K_M creation:
Here's the Q3_K_M quantization log if it's interesting: quantize CohereForAI_c4ai-command-r-plus.fp16.bin CohereForAI_c4ai-command-r-plus-Q3_K_M.gguf Q3_K_M
|
So they work now? |
The norm layer is probably being quantized if it is exported as a 2d tensor, but it needs to be f32. Exporting it as 1d (reshaped) may work. |
Does the inference appear to be sane? |
I am now exporting it in 1d f32 in the latest commit, and the issue remains - nonsense output because the Layer Norm should be 2D not 1D. |
In the current implementation seems like most of the values in the computation graph are zero. (Also I learned how to more systematically track intermediate computation values). It's very different compared to old Command-R model even before it hits the code path that uses the new norms. Command-R+ (new model; 5 first and 5 last values from the first intermediate computed values)
Command-R (old model, same tensors)
It's not zero across the board but it looks fairly broken. A bit surprising since it isn't that different of a model. The quants I'm working with have a suspiciously good compression rates with zstd. They don't look like entire zeroes in hexedit but Maybe worth checking if the GGUF converter isn't throwing away data somehow. Although possibly I have corrupted files. Checksums from HF seem to match....hrm. It would be annoying if I've had trouble only because of corrupted files. Edit: I don't get the zeroes in intermediate computations with f16. It's just so big the test workflow takes forever. I wonder if there might be another quant bug with tensors being larger than 2**31-1 like we found with the previous model, but more subtle this time. |
|
@dranger003 Ah thanks! Yeah, that indicate that you also have a zero hole in your file. Okay good, so it's not just me. I think I may have found the part that overflows. It's the same tensor as last Command-R model that also had an overflow, but it overflows in a different part this time. Maybe the tensor is even larger this time. |
Congrats on getting it working :D :D :D. My ballpark for a m3 max (my device) at Q8 would be 2-3 tok/sec (logic -> 3-5 tok/sec at Q3 for 120 b)... what mac studio do you have? Maybe there's a slower part of the inference code? |
I'll be honest; I straight up might not have time. I have some possibly high-stakes interviews this week and will spend time prepping for those instead, starting right about when I finish typing this comment :D I'm not sure if @dranger003 has an older version in their HF if you look in past commits if there's a working Q4_K_M one uploaded.
Just rechecked my setup: quantize, gguf-dump.py, main and perplexity from Q4_0
This is a case of "works for me". Wondering what could be different. MacOS version and snippet from `llama.cpp` loading itself when using Q4_0.
|
@teis-e Why not use IQ4_XS? |
I saw your HuggingFace, but I don't understand, there are 2 files for IQ4_XS. Which one do I download? |
|
And then i merge them? ./gguf-split --merg How exactly? |
You don't need to merge them at all. Download both files and just point LCPP at the first one. It will load both parts properly. If you want to merge them though, it's
|
This is incorrect, you need to use |
Models that have been split with the now-built-in splitting utility can't simply be concatenated. You can either leave them in multiple pieces and LCPP will load them as-is, or you can use the utility to recombine the pieces into a single large GGUF. |
I got it now. Thnx for all the answers. I'm now moving to the next step using it on LocalAI Has anybody got that working? |
I have perplexity working again for this model using CUDA. I pushed the changes here dranger003@0bcfc87 and here dranger003@835d702
EDIT: @Carolinabanana I'm running PPL on all the quants to test the code, looks like we'll need more updates, I'll continue to commit as I find them. |
The above comment caught my eye. Forgive a simplistic question but I've long seen "split" files which are for GGUF or any other model formats on, e.g. hf, and usually in the file name there's not a lot of information / format consistency other than typically something along the lines of ggml-c4ai-command-r-plus-104b-f16-00005-of-00005.gguf or pytorch_model-00001-of-00004.bin. So if there's no reliably consistent nomenclature as to file naming / "extension" to indicate "specially split" files, and many files are just "ordinarily split" by sharding without other transformation, how should a user tell whether the gguf file can be trivially concatenated or whether it's been somehow altered / wrapped with headers / trailers / whatever and may / must be gguf-split processed or left sharded and loaded that way? I assume there's some kind of identifiable header / magic flag ("/usr/bin/file") or command line option to gguf-using utilities that can check the format / integrity / usability of sharded / not files? |
The format is specified in Lines 1038 to 1041 in cc4a954
I have tried to summarize all here: Feel free to add an improvment request with |
The problem was on my end - somehow I had the QK normalization tensors quantized to I wouldn't be surprised if we have integer overflows in the Metal kernels (I'm actually more surprised that we don't 😄 ). We'll fix those as they occur |
@@ -160,7 +160,7 @@ def write_tensors(self): | |||
data = data.astype(np.float32) | |||
|
|||
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 | |||
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: | |||
if self.ftype == 1 and data_dtype == np.float16 and (n_dims == 1 or new_name.endswith("_norm.weight")): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would be nice to update the comment
@@ -1225,7 +1225,7 @@ static void ggml_cuda_op_mul_mat_cublas( | |||
|
|||
// the main device has a larger memory buffer to hold the results from all GPUs | |||
// ldc == nrows of the matrix that cuBLAS writes into | |||
int ldc = id == ctx.device ? ne0 : row_diff; | |||
int64_t ldc = id == ctx.device ? ne0 : row_diff; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be great to update the PR description to summarize why we upcasted all int params to int64 in this cntext.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@phymbert I reverted that one in dranger003@835d702 but it looks like the PR got merged before it could be pulled. My level of knowledge of these is no where near to be on par with those that created the code and so I definitely rely on your reviews. I looked at some of the values through the debugger but since we have so many overflowing I had to change them in batches, so this means I most likely changed some that don't need to be changed. Hopefully this makes some sense. I can submit another PR to master with that last commit, otherwise perplexity was still broken using CUDA for this model without that commit.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the explanation, please raise with @ggerganov as I am out of the subject regarding CommandR+
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I opened a PR as a follow-up (#6563)
@dranger003 is there a chance you could upload an imatrix q4_K_S quant? (and/or imatrix q3_K_L) EDIT: Apparently IQ3 is quite slow, while q4_K_S is equivalent to IQ4_XS. |
@kalomaze Sure, I'l see what I can do. EDIT: I uploaded the new quants, I'll update the perplexity table shortly. |
Thanks a lot. For now I'm using IQ_3XXS and it seems fairly servicable |
I've tried IQ3_M variant for several hours on my Apple silicon. |
Can you show me your command line for that? When I use Q1_M I can use command-r on apple silicon M1 (64GB), but when I use a Q3_M I only get garbage and the logs (I use the server) show the following for each token:
|
I'm currently using llama-cpp-python backed by the latest Llama.cpp, so I'm not using CLI now. |
Please help building gguf-split on windows |
* Add Command R Plus GGUF * Add Command R Plus GGUF * Loading works up to LayerNorm2D * Export new tensors in 1D so they are not quantized. * Fix embedding layer based on Noeda's example * Whitespace * Add line * Fix unexpected tokens on MPS. Re-add F16 fix. ((Noeda) * dranger003: Fix block index overflow in CUDA dequantizing. * Reverted blocked multiplication code as it still has issues and could affect other Llama arches * export norms as f32 * fix overflow issues during quant and other cleanup * Type convention Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * dranger003: Fix more int overflow during quant. --------- Co-authored-by: S <seast@Ss-Mac-Studio.local> Co-authored-by: S <s@example.com> Co-authored-by: slaren <slarengh@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Updated tensor mapping to add Command R Plus support for GGUF conversion.