1717 "state-spaces/mamba-130m-hf" ,
1818 "tiiuae/falcon-mamba-tiny-dev" ,
1919 # TODO: Compare to a Mamba2 model. The HF transformers implementation of
20- # Mamba2 is buggy for Codestral as it doesn't handle n_groups.
20+ # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
21+ # doesn't compare vLLM output with HF output.
2122 # See https://github.com/huggingface/transformers/pull/35943
22- # "mistralai/Mamba-Codestral-7B-v0.1",
23+ "mistralai/Mamba-Codestral-7B-v0.1" ,
2324]
2425
2526HYBRID_MODELS = [
3536 "hmellor/tiny-random-BambaForCausalLM" ,
3637]
3738
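+# Subset of the models above that this test also runs on the V1 engine.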
+V1_SUPPORTED_MODELS = [
+    "mistralai/Mamba-Codestral-7B-v0.1",
+]
+
 # Avoid OOM
 MAX_NUM_SEQS = 4
 
@@ -46,24 +51,50 @@ def test_models(
     hf_runner,
     vllm_runner,
     example_prompts,
+    monkeypatch,
     model: str,
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
     with hf_runner(model) as hf_model:
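+        # HF's Mamba2 implementation is broken for Codestral (see the TODO on
+        # SSM_MODELS above), so no HF reference outputs are generated for it.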
-        hf_outputs = hf_model.generate_greedy_logprobs_limit(
-            example_prompts, max_tokens, num_logprobs)
+        if model != "mistralai/Mamba-Codestral-7B-v0.1":
+            hf_outputs = hf_model.generate_greedy_logprobs_limit(
+                example_prompts, max_tokens, num_logprobs)
+        else:
+            hf_outputs = None
 
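+    # Baseline run on the default (V0) engine.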
     with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy_logprobs(
+        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)
 
-    check_logprobs_close(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
-    )
+    if model in V1_SUPPORTED_MODELS:
+        with monkeypatch.context() as m:
+            m.setenv("VLLM_USE_V1", "1")
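+            # Eager mode with prefix caching disabled; presumably the
+            # configuration V1 currently supports for Mamba-style models.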
+            with vllm_runner(model,
+                             max_num_seqs=MAX_NUM_SEQS,
+                             enforce_eager=True,
+                             enable_prefix_caching=False) as vllm_model:
+                vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                    example_prompts, max_tokens, num_logprobs)
+    else:
+        vllm_v1_outputs = None
+
+    if hf_outputs is not None:
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_v0_outputs,
+            name_0="hf",
+            name_1="vllm-v0",
+        )
+
+    if model in V1_SUPPORTED_MODELS:
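+        # Use HF as the reference when it exists; for Codestral, where the HF
+        # run is skipped, compare V1 against the V0 outputs instead.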
+        ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
+        check_logprobs_close(
+            outputs_0_lst=ref_outputs,
+            outputs_1_lst=vllm_v1_outputs,
+            name_0="hf" if hf_outputs is not None else "vllm-v0",
+            name_1="vllm-v1",
+        )
 
 
 @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)