[New Model]: meta-llama/Llama-Guard-3-1B #9294
For some reason the chat template for the 8B model is different from the 1B model on HF (see
gives the output
Also, FWIW, it doesn't look like the pruned version is in the HF repo; the output layer weight shape is the full ~128k vocabulary dimension rather than the pruned 20.
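A quick way to double-check that is to read the tensor shapes straight from the checkpoint. This is only a sketch: it assumes the repo ships a single `model.safetensors` shard and that you have access to the gated repo.

```python
# Inspect output/embedding tensor shapes without loading the model.
# Assumes a single `model.safetensors` shard and access to the gated repo.
from huggingface_hub import hf_hub_download
from safetensors import safe_open

path = hf_hub_download("meta-llama/Llama-Guard-3-1B", "model.safetensors")
with safe_open(path, framework="pt") as f:
    for name in f.keys():
        if "lm_head" in name or "embed_tokens" in name:
            # A pruned output layer would show 20 rows instead of the
            # full ~128k vocabulary dimension.
            print(name, f.get_slice(name).get_shape())
```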
Thanks for the tip here @conwayz! I did a little debugging and the issue does in fact lie in the parsing of the chat template. For the Llama 2 based guard model, you can see that the chat template gets parsed correctly and the prompt includes the messages sent by the user (look for the content inserted in between the conversation brackets).
Whereas for a similar interaction with the llama-guard-1B model, the user message does not get inserted into the conversation bracket:
The reason the code snippet provided by @conwayz works as expected is that the chat template for models based on the Llama 3.2 variant has a new field that needs to be passed along with the message content.
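For reference, the message shape the 1B template appears to expect looks roughly like the following (a sketch based on the HF model card usage; the prompt text is just an example):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-Guard-3-1B")

# The 1B template iterates over message["content"] as a list of typed parts,
# so the user text goes inside a {"type": "text", ...} item rather than
# being a plain string.
messages = [
    {"role": "user",
     "content": [{"type": "text", "text": "I forgot my password, how do I reset it?"}]},
]
print(tok.apply_chat_template(messages, tokenize=False))
# The user text should show up between the conversation brackets of the rendered prompt.
```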
The only part that still puzzles me is that, even when we specify the request in the same format as on Hugging Face, vLLM does not perform the substitution properly.
Is there some default substitution happening that does not take this different type of chat template into account?
So after a little more digging, it seems the primary issue comes from this part of the code in vLLM. That segment converts the content field into a single string, whereas the 1B Guard model's chat template expects the content to remain a list of typed parts.
And since the template doesn't find any matching elements while trying to loop through it, only the two conversation bracket markers end up in the rendered prompt.
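To make the failure mode concrete, here is a stripped-down toy template (not the real Llama Guard 3 1B template) showing why flattening the content to a string makes the loop emit nothing:

```python
# Toy illustration of the mismatch: a loop over typed content parts
# silently emits nothing once `content` has been flattened to a string.
from jinja2 import Template

toy = Template(
    "<BEGIN CONVERSATION>"
    "{% for part in message['content'] %}"
    "{% if part['type'] == 'text' %} {{ part['text'] }} {% endif %}"
    "{% endfor %}"
    "<END CONVERSATION>"
)

as_list = {"content": [{"type": "text", "text": "hello"}]}
as_string = {"content": "hello"}  # what the content looks like after flattening

print(toy.render(message=as_list))    # the user text appears between the brackets
print(toy.render(message=as_string))  # loops over characters, nothing matches
```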
For anyone reading this issue who is looking for a temporary workaround, using this chat template for the 1B model seems to fix the issue in vLLM (a usage sketch follows the snippet below).
Relevant code snippet
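Once the fixed template is saved to a file, one way to apply it is via the offline API. This is a sketch: the file name is made up, and if your vLLM version's `LLM.chat` does not accept a `chat_template` override, the OpenAI-compatible server's `--chat-template` flag should achieve the same thing.

```python
# Apply an overridden chat template (sketch; the template file name is
# hypothetical, and temperature 0.0 keeps the verdict deterministic).
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-Guard-3-1B")
with open("llama_guard_1b_fixed.jinja") as f:
    template = f.read()

messages = [{"role": "user", "content": "I forgot my password, how do I reset it?"}]
outputs = llm.chat(
    messages,
    SamplingParams(temperature=0.0, max_tokens=20),
    chat_template=template,
)
print(outputs[0].outputs[0].text)  # expected: "safe" (or "unsafe" plus a category)
```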
The model to consider.
meta-llama/Llama-Guard-3-1B
The closest model vllm already supports.
meta-llama/Llama-Guard-3-8B
What's your difficulty of supporting the model you want?
Currently the model runs, but its outputs are effectively random: the same prompt can come back as safe or unsafe from one request to the next. Setting the temperature to 0.0 makes EVERY prompt return safe.
My hunch is the issue comes from the model pruning:
Output Layer Pruning
The Llama Guard model is trained to generate 128k output tokens out of which only 20 tokens (e.g. safe, unsafe, S, 1,...) are used. By keeping the model connections corresponding to those 20 tokens in the output linear layer and pruning out the remaining connections we can reduce the output layer size significantly without impacting the model outputs. Using output layer pruning, we reduced the output layer size from 262.6M parameters (2048x128k) to 40.96k parameters (2048x20), giving us a total savings of 131.3MB with 4-bit quantized weights. Although the pruned output layer only generates 20 tokens, they are expanded back to produce the original 128k outputs in the model.
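For what it's worth, the numbers in that quote line up, assuming a ~128k-token vocabulary and 0.5 bytes per 4-bit weight:

```python
# Sanity check of the quoted pruning numbers (hidden size 2048, ~128k vocab,
# 4-bit weights = 0.5 bytes per parameter).
hidden, vocab, kept = 2048, 128_256, 20

full_params = hidden * vocab     # ~262.7M, matching the quoted 262.6M (2048 x 128k)
pruned_params = hidden * kept    # 40,960 = 40.96k
saved_mb = (full_params - pruned_params) * 0.5 / 1e6  # ~131.3 MB saved at 4 bits

print(full_params, pruned_params, round(saved_mb, 1))
```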