I tried to reproduce zephyr-7b-gemma-v0.1 using the exact code provided in this repository on 4xA100 GPUs. However, the resulting MT-bench score was much lower than reported: 6.63, versus the value of 7.81 reported on the Hugging Face model page.
Is anyone else encountering this issue?
Command run (the same as what's mentioned in the repo, but with gradient accumulation adjusted since I am using only 4xA100):
It seems that the issue is with the chat template used by fastchat during evaluation. Registering the following template before testing H4's gemma models recovers the reported performance:
```python
from fastchat.conversation import (
    Conversation,
    SeparatorStyle,
    register_conv_template,
)

# Register a ChatML-style template matching the format the model was trained on.
register_conv_template(
    Conversation(
        name="templ=h4_gemma_chatml",
        system_template="<bos><|im_start|>system\n{system_message}",
        system_message="You are an AI assistant.",
        roles=("<|im_start|>user", "<|im_start|>assistant"),
        sep_style=SeparatorStyle.CHATML,
        sep="<|im_end|>",
        stop_str=["<|im_end|>", "<|endoftext|>"],
    )
)
# other init code omitted
```
This stems from how the model is trained in the run_dpo.py script. There, chat data is first formatted with the tokenizer's chat template and then fed into the Trainer. Unless you use (perhaps) the latest version of fschat, which hardcodes templates per model, fschat will not use that same template at evaluation time, which leads to the performance degradation.
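To make the mismatch concrete, here is an illustrative, dependency-free reimplementation of the ChatML-style formatting that the tokenizer's chat template produces at training time (the helper name `apply_chatml_template` is hypothetical; the special tokens follow the template strings above). An evaluator that formats prompts any other way feeds the model text it never saw during training:

```python
def apply_chatml_template(messages, add_generation_prompt=True):
    """Sketch of ChatML-style formatting (illustrative, not the actual
    tokenizer code): each turn is wrapped in <|im_start|>/<|im_end|>."""
    prompt = "<bos>"
    for msg in messages:
        prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
    if add_generation_prompt:
        # Training examples end with the assistant header, so generation
        # must start from the same position.
        prompt += "<|im_start|>assistant\n"
    return prompt

prompt = apply_chatml_template([
    {"role": "system", "content": "You are an AI assistant."},
    {"role": "user", "content": "Write a haiku about GPUs."},
])
print(prompt)
```

If fschat instead falls back to a generic `USER: ... ASSISTANT:` style prompt, none of these delimiters appear, and the model's outputs degrade even though the weights are identical.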
When generating model answers for MT-bench, I used the default commands:
Related library versions I used:
torch==2.1.2+cu118, transformers==4.39.1, trl==0.8.1, flash-attn==2.5.6, fschat==0.2.36
training curves from wandb:

eval reward curves:
