Commit 464d420
* LLaMA
* sharding and docs
* tweak
* black
* inits
* ruff
* LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP
* init
* no checkpoint
* docs
* ruff
* type_vocab_size
* tokenizer fixes
* tokenizer fixes
* Update tokenization_llama.py
* Update tokenization_llama.py
* Update configuration_llama.py
* Update modeling_llama.py
* tokenizer add_bos by default
* licenses
* remove decoder
* norms and mlp
* rope overhaul
* tweaks
* black
* mention OPT implementation
* off-by-one naming
* typo
* fix
* tokenization fix and slicing bug
* padding config
* cleanup
* black
* update tests
* undo typo
* fix vocab caching logic
* ruff
* docbuilder
* attn fix from BlackSamorez
* initial feedback
* typo
* docs
* llama case
* llama case
* load checkpoint docs
* comment about tokenizer
* tokenizer defaults
* clear past_key_values if use_cache=False
* last tweaks
* last tweaks
* last tweaks
* last tweaks

---------

Co-authored-by: Stella Biderman <stellabiderman@gmail.com>
Any idea why I might be getting this warning, @zphang @StellaAthena?
Warning:
"The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization.
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'.
The class this function is called from is 'LlamaTokenizer'."
Code:

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained(<tokenizer_name>)
model = LlamaForCausalLM.from_pretrained(
    <model_name>, load_in_8bit=True, device_map="auto", torch_dtype=torch.float16
)
The classes were originally named LLaMA, and the decapoda-research weights on HuggingFace were saved with these class names. Later the classes were renamed to Llama (notice the case difference), but the HF models still have the old naming convention, causing this error. Everything works fine, but you will get this warning until the model config files are updated.
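In case it helps anyone in the meantime, here is a minimal sketch of a local workaround, assuming the checkpoint stores the old class name under the "tokenizer_class" key of tokenizer_config.json (which is what the warning compares against); the checkpoint path below is a placeholder, not taken from this thread:

import json
from pathlib import Path

# Hypothetical path to a locally downloaded LLaMA checkpoint; adjust to your setup.
config_path = Path("path/to/llama-checkpoint") / "tokenizer_config.json"

config = json.loads(config_path.read_text())
# The checkpoint still records the old class name ("LLaMATokenizer"); pointing it
# at the renamed transformers class should make the warning go away.
config["tokenizer_class"] = "LlamaTokenizer"
config_path.write_text(json.dumps(config, indent=2))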
@zphang and @StellaAthena - thank you so much for providing Llama model support in Hugging Face! I have been going through the inner workings of your wonderful Llama implementation to get a better understanding, and I have hit a roadblock with apply_rotary_pos_emb, which in turn calls the rotate_half function.
Working through an example, I can see this code working perfectly well for the two-dimensional case [x1, x2], but I am unable to see how it works for more than two dimensions. To the best of my understanding, I have tried to show this with an example below. If you could help me understand where I have gone wrong, that would be immensely useful; and if this is indeed a bug, I would be happy to help fix it.
Let us take an example with 4 dimensions [x1, x2, x3, x4].
Hence the cos tensor would be [Cos(theta_0), Cos(theta_1), Cos(theta_0), Cos(theta_1)] as per Line 168, and similarly for the sine tensor.
The rotate_half function takes in [x1, x2, x3, x4] and returns [-x3, -x4, x1, x2]. Line 186 would then compute the new value of the first component as x1*Cos(theta_0) - x3*Sin(theta_0), when I think it should probably be x1*Cos(theta_0) - x2*Sin(theta_0)?
Can you please help me understand at which step I am making a mistake (if any)?
One way I could see this working is if the input itself were [x1, x3, x2, x4]; then I think this would be correct. But have the inputs been permuted like that?
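For reference, here is a small runnable sketch of the computation described above. The rotate_half body follows the half-split-and-swap form discussed here (as I read it in modeling_llama.py); the 4-dimensional toy values and the theta angles are made up purely for illustration:

import torch

def rotate_half(x):
    # Split the last dimension in half and swap the halves, negating the second:
    # [x1, x2, x3, x4] -> [-x3, -x4, x1, x2]
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

# Toy 4-dimensional vector at a single position.
x = torch.tensor([1.0, 2.0, 3.0, 4.0])      # [x1, x2, x3, x4]
theta = torch.tensor([0.1, 0.2])            # [theta_0, theta_1], arbitrary angles
emb = torch.cat((theta, theta), dim=-1)     # duplicated, as in the cos/sin cache
cos, sin = emb.cos(), emb.sin()

# apply_rotary_pos_emb-style update: x * cos + rotate_half(x) * sin
out = x * cos + rotate_half(x) * sin

print(rotate_half(x))   # tensor([-3., -4.,  1.,  2.])
print(out[0])           # equals x1*Cos(theta_0) - x3*Sin(theta_0)

This reproduces the pairing described above: the first output component combines x1 with x3 rather than x2, so the open question is whether the inputs (or the projection weights) have been permuted to match that layout.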