StableLM support #3586
Conversation
Why warn for safetensors? The current version of the falcon script supports both formats; it's not hard to do.
He might not have tested it. Fine with me. @Galunid: tokenizers are bad, did you check?
I wasn't aware we already support it and don't need to pull any new dependencies. I just took a look at
@goerch: No, not yet. Right now I can load the model, but it generates gibberish. I wanted to get something that runs first. I'm going to look at the tokenizer now. Could you confirm whether the output below looks like a tokenizer issue to you?
After a little modification, the output was correct.
This model has parameters like LLaMA, but the logic is GPT-NeoX: https://huggingface.co/stabilityai/stablelm-3b-4e1t. You need to add the attn-norm bias.
Thanks, I realized I forgot to add the biases, but it would have taken me quite a while to realize RoPE was set wrong!
@Galunid Hello, just wanted to ask if there was any update on this? Were you able to convert the stablelm-3b-4e1t model to GGUF? Thank you.
@niranjanakella There's no update. I had a family emergency and didn't have time to touch the code. I got back today and I'm planning to look at this more tomorrow. You can download the converted model here. You won't get anything useful running it though, just random nonsense.
@Galunid How did you convert 4e1t? I don't see any safetensors code in the conversion script.
I tried hacking in safetensors loading (by copying from the falcon convert script), but it fails with a size mismatch, idk; just dumping my diff here for reference:
diff --git a/convert-stablelm-hf-to-gguf.py b/convert-stablelm-hf-to-gguf.py
index 4a6fc66a..e163bb87 100755
--- a/convert-stablelm-hf-to-gguf.py
+++ b/convert-stablelm-hf-to-gguf.py
@@ -4,6 +4,7 @@
from __future__ import annotations
import argparse
+import contextlib
import json
import os
import struct
@@ -20,17 +21,16 @@ if 'NO_LOCAL_GGUF' not in os.environ:
import gguf
-def count_model_parts(dir_model: Path) -> int:
+def count_model_parts(dir_model: Path, prefix: str) -> int:
num_parts = 0
for filename in os.listdir(dir_model):
- if filename.startswith("pytorch_model-"):
+ if filename.startswith(prefix):
num_parts += 1
if num_parts > 0:
print("gguf: found " + str(num_parts) + " model parts")
return num_parts
-
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Convert a stablelm model to a GGML compatible file")
parser.add_argument(
@@ -80,10 +80,17 @@ with open(dir_model / "config.json", "r", encoding="utf-8") as f:
if hparams["architectures"][0] != "StableLMEpochForCausalLM":
print("Model architecture not supported: " + hparams["architectures"][0])
- sys.exit()
+ sys.exit(1)
# get number of model parts
-num_parts = count_model_parts(dir_model)
+#num_parts = count_model_parts(dir_model, "model-00")
+#if num_parts:
+num_parts = 0
+is_safetensors = True
+from safetensors import safe_open
+#else:
+ #is_safetensors = False
+ #num_parts = count_model_parts(dir_model, "pytorch_model-")
ARCH=gguf.MODEL_ARCH.STABLELM
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
@@ -140,13 +147,20 @@ special_vocab.add_to_gguf(gguf_writer)
# TENSORS
tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
-print(tensor_map)
+#print(tensor_map)
# tensor info
print("gguf: get tensor metadata")
if num_parts == 0:
- part_names = iter(("pytorch_model.bin",))
+ if is_safetensors:
+ part_names = iter(("model.safetensors",))
+ else:
+ part_names = iter(("pytorch_model.bin",))
+elif is_safetensors:
+ part_names = (
+ f"model-{n:05}-of-{num_parts:05}.safetensors" for n in range(1, num_parts + 1)
+ )
else:
part_names = (
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
@@ -156,47 +170,55 @@ for part_name in part_names:
if args.vocab_only:
break
print("gguf: loading model part '" + part_name + "'")
- model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
+ if is_safetensors:
+ ctx = safe_open(dir_model / part_name, framework="pt", device="cpu")
+ else:
+ ctx = contextlib.nullcontext(torch.load(dir_model / part_name, map_location="cpu"))
+
+ #model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
+
+ with ctx as model_part:
+ for name in model_part.keys():
+ #data = model_part[name]
+ data = model_part.get_tensor(name) if is_safetensors else model_part[name]
- for name in model_part.keys():
- data = model_part[name]
- # we don't need these
- if name.endswith(".attention.masked_bias") or name.endswith(".attention.bias") or name.endswith(".attention.rotary_emb.inv_freq"):
- continue
+ # we don't need these
+ if name.endswith(".attention.masked_bias") or name.endswith(".attention.bias") or name.endswith(".attention.rotary_emb.inv_freq"):
+ continue
- old_dtype = data.dtype
+ old_dtype = data.dtype
- # convert any unsupported data types to float32
- if data.dtype != torch.float16 and data.dtype != torch.float32:
- data = data.to(torch.float32)
+ # convert any unsupported data types to float32
+ if data.dtype != torch.float16 and data.dtype != torch.float32:
+ data = data.to(torch.float32)
- data = data.squeeze().numpy()
+ data = data.squeeze().numpy()
- # map tensor names
- new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
- if new_name is None:
- print("Can not map tensor '" + name + "'")
- sys.exit()
+ # map tensor names
+ new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
+ if new_name is None:
+ print("Can not map tensor '" + name + "'")
+ sys.exit()
- n_dims = len(data.shape)
- data_dtype = data.dtype
+ n_dims = len(data.shape)
+ data_dtype = data.dtype
- # if f32 desired, convert any float16 to float32
- if ftype == 0 and data_dtype == np.float16:
- data = data.astype(np.float32)
+ # if f32 desired, convert any float16 to float32
+ if ftype == 0 and data_dtype == np.float16:
+ data = data.astype(np.float32)
- # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
- if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
- data = data.astype(np.float32)
+ # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+ if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+ data = data.astype(np.float32)
- # if f16 desired, convert any float32 2-dim weight tensors to float16
- if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
- data = data.astype(np.float16)
+ # if f16 desired, convert any float32 2-dim weight tensors to float16
+ if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+ data = data.astype(np.float16)
- print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+ print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
- gguf_writer.add_tensor(new_name, data)
+ gguf_writer.add_tensor(new_name, data)
print("gguf: write header")
the error:
I used:
from transformers import AutoModelForCausalLM

token = "<your token>"
model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stablelm-3b-4e1t",
    trust_remote_code=True,
    torch_dtype="auto",
    token=token
)
model.save_pretrained("output")
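Alternatively, the safetensors file can be read directly instead of being re-saved as a PyTorch checkpoint. A minimal sketch of that pattern, assuming the safetensors package is installed and a local directory containing a single-part model.safetensors (the directory name here is illustrative):

```python
from pathlib import Path
from safetensors import safe_open

dir_model = Path("stablelm-3b-4e1t")  # assumed local snapshot containing model.safetensors

# safe_open memory-maps the file; tensors are only materialized by get_tensor()
with safe_open(str(dir_model / "model.safetensors"), framework="pt", device="cpu") as f:
    for name in f.keys():
        t = f.get_tensor(name)
        print(name, tuple(t.shape), t.dtype)
```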
@Green-Sky I changed the conversion script to use safetensors. It works for me; could you give it a try?
Thanks for the update. I still have the same issue with the safetensors file. I also compared the sha256 of model.safetensors to the one on Hugging Face and they match. And no, I have >300 GiB of free space on the disk :)
Running in f32 mode, the error is somewhat different and more explicit:
How? Where is this temp file located?
It's putting the temp file on tmpfs. You need to run the command with
Thanks, that worked. Here is the current state:
Looks like the vocabulary is missing.
I think it's more that the model architecture is incorrectly implemented, so the model goes nuts and "hallucinates" non-existent tokens. I mostly copied the implementation from the existing llama/gptneox ones and stitched them together, but Stability made their own modifications that I haven't checked yet, so it'll probably be a pain. On a different note, I verified that
I think I've found where the problem is:
We use RMSNorm here; I'll look into this more tomorrow.
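For context on why that matters, here is a minimal NumPy sketch (shapes and epsilon are illustrative, not taken from the model config): LayerNorm centers the activations and applies both a weight and a bias, while RMSNorm only rescales by the root mean square, so using one in place of the other silently changes every hidden state.

```python
import numpy as np

def layer_norm(x, weight, bias, eps=1e-5):
    # LayerNorm: subtract the mean, normalize by the variance, then apply weight AND bias
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps) * weight + bias

def rms_norm(x, weight, eps=1e-5):
    # RMSNorm: no centering and no bias -- only rescale by the root mean square
    rms = np.sqrt((x * x).mean(axis=-1, keepdims=True) + eps)
    return x / rms * weight

x = np.random.randn(4, 2560).astype(np.float32)
w = np.ones(2560, dtype=np.float32)
b = np.zeros(2560, dtype=np.float32)
# Even with identity weight and zero bias, the two norms disagree whenever the mean is non-zero
print(np.abs(layer_norm(x, w, b) - rms_norm(x, w)).max())
```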
The author of Hamlet is
The author of Hamlet is unknown. Although Shakespeare was known to have written the play, it wasn't until after his death in 1616 that his identity as its author became generally accepted by scholars; a theory advanced in the 17th century by Francis Meres, a playwright and poet who had been born near Stratford-upon-Avon, England. This was because Hamlet is so full of allusions to Shakespeare's plays and life that it would be very unusual for any other writer to have made them. Shakespeare wrote many different types of plays over his long career -- tragedies like Macbeth and Romeo and Juliet; comedies like Much Ado About Nothing and A Midsummer Night's Dream; and historical dramas such as Henry V, Richard III and Julius Caesar. The most famous of all is Hamlet, which Shakespeare probably wrote between 1600 and 1602, though critics debate this point. In the play Hamlet, Prince of Denmark, a Danish prince who was heir to the throne but was not allowed to marry his father's widow because she had been married before. After his uncle killed Hamlet's father, Hamlet swore revenge on Claudius, Hamlet's new stepfather. In the end, it is Hamlet, however, who kills himself in despair over his failure to avenge his father and restore order to Denmark. Shakespeare wrote Hamlet when he was around 30 years old, but it wasn't performed until more than 200 years after its composition because of censorship rules that prohibited performance of plays with religious or political themes.<|endoftext|> [end of text]

The best music is
The best music is often found in the most unlikely places. This week on The Sound of Young America, we're bringing you some classic tracks that are all about finding love and romance...in the strangest of ways! We'll start with a song from one of our favorite bands - Death Cab For Cutie - who has an album out this month (which is also being released in a limited-ed

On this day humanity received a grim reminder
trigger warning (mass shooting)
On this day humanity received a grim reminder that it is not, and never will be, invincible. The world suffered the loss of thousands of lives due to one man's hatred for people who don't look like him or think like he does. It was an act so atrocious that it would make even the strongest person question their faith in humanity, and that is exactly what happened. After taking down a building, this sick individual set fire to himself with his own lighter and ended up burning alive on live television during the attack. What makes this all that more disturbing? The fact that it was supposed to be a terroristic attack, but due to the suspect's poor aim, no one else died in the incident. This terrorist attack is now known as the worst act of terrorism ever committed by an individual on U.S soil and has been ranked the deadliest for a single attacker in world history, ever since it occurred. The man who started this tragedy was named Omar Mateen. He was born in New York to Afghan parents, but moved to Florida after his birth. His father is reportedly a former member of Afghanistan's Communist Party and was considered an enemy to the U.S government when he came to America. This terrorist attack occurred at Pulse Nightclub, which is located in Orlando, Florida. The nightclub is very popular with people who identify as LGBTQ+, a group that is often discriminated against by religious extremists like Omar Mateen. The nightclub was hosting a Latin Night on the night of June 12 when this tragedy took place. It was open to both men and women, but all patrons were required to purchase admission tickets in order to enter. Mateen entered Pulse with an AR-15 semi-automatic rifle which he had purchased legally at a gun store less than a week prior. He also brought two handguns to the club during this attack, all of these weapons came from his father, who was unaware that his son would use them for such a horrific purpose. After entering Pulse, Omar Mateen began shooting people inside the nightclub and eventually set fire to the main entrance. The police arrived at the scene within minutes after the first shot rang out and confronted Mateen outside of the building, at which point he pulled out his own weapon and opened fire on them too! The man killed in this terrorist attack was named Eddie Justice, who worked as a security guard at Pulse nightclub during its opening hours. He was able to save lives by wrestling away Omar Mateen's gun from him so that other people could escape without being harmed. During the attack, many patrons of the club were forced outside into a nearby parking lot where they hid under cars or stood on top of them in order not to be shot at from above, all while watching helplessly as their friends and family members inside continued to get killed one by one! The police used explosives to enter Pulse nightclub after Mateen's gunfire had stopped. It was then that they found the bodies of 49 people who were murdered during this attack, which makes it second only in number to 9/11 when it comes down to deadliest terrorist attacks ever committed on U.S soil! After he died from his injuries, Omar Mateen left a note behind which contained quotes from Islamic holy texts as well as some personal messages directed towards America's LGBTQ community… but most importantly, he mentioned that Allah had instructed him to commit this act of terror against them all in order for God's will be done! Omar Mateen was only 29 years old when he committed suicide after being fatally shot by police officers during the Orlando shooting. He is reported to have been married with two children, and his wife still lives in Florida where they were previously living together before this tragedy occurred on June 12th 2016.<|endoftext|> [end of text]

I'd say not bad, it's still crashing sometimes and I'm not sure why.
I am still hitting a CUDA assert with -ngl 35.
I think you need to ggml_cont either qrot or krot, or both, since CUDA does not yet support non-contiguous rope. Not ideal, but maybe we can fix this later.
Let's merge and support this from master.
Any idea why it doesn't stop output? Is it misconfigured control tokens in the (converted) model metadata?
It looks like there are some missing ggml CUDA functions. I'll merge master tomorrow and we can merge this then.
@Green-Sky Would you mind giving it one final look? If all is good, feel free to merge (if CI runs green) :)
Reconverted the model and re-ran some perplexities - values slightly differ, but within variance. Looks pretty good.
The CI oracle gives the green light :)
I was toying around with it a bit more and realized it's very slow now. llama-bench on master (36eed0c):
stablelm-3b-4e1t
TinyLlama-1.1B
I know comparing 1.1B to 2.8B is not very fair, but the token generation speed seems to scale properly. Or am I just looking at the difference in architecture (mostly GQA)?
Next problem: I re-ran the benchmark on master... and now the prompt processing speeds are vastly inferior...
... now I have no idea what is going on q.q
Could you try 6be3356? It should be faster. Right now there are more matrix operations we need to compute (for rope). Previously we just computed rope on the whole tensor; now we need to split the tensor into two halves (the roped one and the non-roped one), calculate rope for the first one, and concatenate both (which also requires permuting both tensors) at the end. We also use
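As a concrete illustration of that split-rope-concat scheme, here is a minimal NumPy sketch. The head count, head size, and n_rot values are illustrative, and the real ggml graph additionally permutes tensors because ggml_concat only supports one concatenation dimension; this is not a transcription of that graph.

```python
import numpy as np

def rope_neox(x, pos, base=10000.0):
    # x: (n_tokens, n_head, n_rot) -- rotate dimension pairs (i, i + n_rot/2)
    n_rot = x.shape[-1]
    half = n_rot // 2
    inv_freq = 1.0 / (base ** (np.arange(half) / half))
    theta = pos[:, None] * inv_freq[None, :]                 # (n_tokens, half)
    cos, sin = np.cos(theta)[:, None, :], np.sin(theta)[:, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

def partial_rope(q, pos, n_rot):
    # q: (n_tokens, n_head, head_dim) -- rope only the first n_rot dims, pass the rest through
    qrot, qpass = q[..., :n_rot], q[..., n_rot:]
    return np.concatenate([rope_neox(qrot, pos), qpass], axis=-1)

q = np.random.randn(5, 32, 80).astype(np.float32)            # n_tokens=5, n_head=32, head_dim=80 (illustrative)
out = partial_rope(q, pos=np.arange(5), n_rot=20)             # rotate only the first 20 dims per head
print(out.shape)                                              # (5, 32, 80)
```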
Did some more testing on master, and it turns out that watching YouTube and moving windows around has a positive correlation with both performance AND variance...
Sooo... my system seems to be unsuited for llama.cpp benchmarking.
Oh yeah, 6be3356 is significantly faster:
Thanks for making this happen! There is now a growing list of the mentioned fine-tuned models converted to GGUF format in this Hugging Face collection. Some of the mentioned models will not convert at all, or produce incorrect files. Do you want a list of them?
Yes, I'll take a look if you could list them.
// self-attention
{
    // compute Q and K and RoPE them
    struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
    cb(tmpq, "tmpq", il);

    struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
    cb(tmpk, "tmpk", il);

    struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
    cb(Vcur, "Vcur", il);

    // RoPE the first n_rot of q/k, pass the other half, and concat.
    struct ggml_tensor * qrot = ggml_cont(ctx0, ggml_view_3d(
        ctx0, tmpq, hparams.n_rot, n_head, n_tokens,
        ggml_element_size(tmpq) * n_embd_head,
        ggml_element_size(tmpq) * n_embd_head * n_head,
        0
    ));
    cb(qrot, "qrot", il);

    struct ggml_tensor * krot = ggml_cont(ctx0, ggml_view_3d(
        ctx0, tmpk, hparams.n_rot, n_head, n_tokens,
        ggml_element_size(tmpk) * n_embd_head,
        ggml_element_size(tmpk) * n_embd_head * n_head_kv,
        0
    ));
    cb(krot, "krot", il);

    // get the second half of tmpq, e.g. tmpq[n_rot:, :, :]
    struct ggml_tensor * qpass = ggml_view_3d(
        ctx0, tmpq, (n_embd_head - hparams.n_rot), n_head, n_tokens,
        ggml_element_size(tmpq) * n_embd_head,
        ggml_element_size(tmpq) * n_embd_head * n_head,
        ggml_element_size(tmpq) * hparams.n_rot
    );
    cb(qpass, "qpass", il);

    struct ggml_tensor * kpass = ggml_view_3d(
        ctx0, tmpk, (n_embd_head - hparams.n_rot), n_head_kv, n_tokens,
        ggml_element_size(tmpk) * (n_embd_head),
        ggml_element_size(tmpk) * (n_embd_head) * n_head_kv,
        ggml_element_size(tmpk) * hparams.n_rot
    );
    cb(kpass, "kpass", il);

    struct ggml_tensor * qrotated = ggml_rope_custom(
        ctx0, qrot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
    );
    cb(qrotated, "qrotated", il);

    struct ggml_tensor * krotated = ggml_rope_custom(
        ctx0, krot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
    );
    cb(krotated, "krotated", il);

    // ggml currently only supports concatenation on dim=2
    // so we need to permute qrot, qpass, concat, then permute back.
    qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3));
    cb(qrotated, "qrotated", il);

    krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3));
    cb(krotated, "krotated", il);

    qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3));
    cb(qpass, "qpass", il);

    kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3));
    cb(kpass, "kpass", il);

    struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass);
    cb(Qcur, "Qcur", il);

    struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass);
    cb(Kcur, "Kcur", il);

    struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 2, 1, 0, 3));
    cb(Q, "Q", il);

    Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3));
    cb(Kcur, "Kcur", il);

    llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
Similar to Persimmon, we should look into simplifying this.
Either we introduce some custom operation, or extend rope to support this kind of case. Or, if necessary, we can prepare the model data upon conversion to be more friendly to ggml ops.
We could revert a371a8b. I think that's the most sensible, since it requires just one extra rope implementation, compared to multiple operations we need to implement now (similar to persimmon). This should also allow for simplifying persimmon.
Or if necessary, we can prepare the model data upon conversion to be more friendly to ggml ops.
The disadvantage of doing that is that it makes it harder to convert LoRAs from HF, since the tensors no longer match, and we would need to apply the same conversions to the LoRAs (if that's possible at all). Related: #3519
* Add support for stablelm-3b-4e1t
* Supports GPU offloading of (n-1) layers
Did you make any progress with this? It blabbers on.
Seems to have stopped doing that now, using nisten/obsidian-3b-multimodal-q6-gguf/obsidian-q6.gguf. If anything, it's very terse now, can't get it to respond with more than a sentence 🤷
I managed to use the stop=["###"] parameter when calling the server.
Didn't make it any more concise though - it kept banging on about how it's a nice artistic arrangement.
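For anyone who wants to replicate the stop-string approach mentioned above, here is a minimal sketch of a request to the llama.cpp server; the address, prompt template, and stop strings are assumptions for illustration, not values taken from the thread.

```python
import json
import urllib.request

payload = {
    "prompt": "### Instruction: Describe the image briefly.\n### Response:",
    "n_predict": 128,
    "stop": ["###"],  # generation halts as soon as any of these strings is produced
}
req = urllib.request.Request(
    "http://127.0.0.1:8080/completion",  # assumed default llama.cpp server address
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["content"])
```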
Add warning when trying to convert .safetensors model
model.safetensors model conversion
GGML_ASSERT: ggml-cuda.cu:6402: ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet"
gpt2 tokenizer produces the same results as GPTNeoxFast from transformers
std::unordered_map
added_tokens
fixes to conversion script
.bin models too
closes #3456