 from ...utils import check_outputs_equal
 
 # This test is for the hybrid models
-MODELS = ["ai21labs/Jamba-tiny-dev"]
+MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"]
 # Bamba at Fp32 is too big for the CI (L4 GPU).
 # MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
 
@@ -27,17 +27,19 @@ def test_models(
 ) -> None:
 
     # numeric error produces different generation
-    if 'Bamba' in model:
+    if "Bamba" in model:
         example_prompts.pop(3)
 
-    with hf_runner(
-            model,
-            dtype=dtype,
-            model_kwargs={
-                "use_mamba_kernels":
-                False,  # mamba kernels are not installed so HF
-                # don't use them
-            }) as hf_model:
+    model_kwargs = {
+        "use_mamba_kernels": False,  # mamba kernels are not installed so HF
+        # don't use them
+    }
+    if "Zamba2" in model:
+        # Zamba2 HF implementation automatically checks if mamba kernels are
+        # installed
+        model_kwargs = {}
+
+    with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
 
     with vllm_runner(model, dtype=dtype) as vllm_model:
@@ -112,26 +114,31 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
 def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
                                 model: str, dtype: str,
                                 max_tokens: int) -> None:
-    # numeric error during prefill chucking produces different generation
+    # numeric error during prefill chunking produces different generation
     # compared to w/o prefill chunking for those examples, removed them for now
-    if 'Jamba' in model:
+    if "Jamba" in model:
         example_prompts.pop(7)
         example_prompts.pop(2)
         example_prompts.pop(1)
-    elif 'Bamba' in model:
+    elif "Bamba" in model:
         example_prompts.pop(6)
         example_prompts.pop(3)
         example_prompts.pop(2)
         dtype = "half"  # use a different dtype for Bamba
-
-    with hf_runner(
-            model,
-            dtype=dtype,
-            model_kwargs={
-                "use_mamba_kernels":
-                False,  # mamba kernels are not installed so HF
-                # don't use them
-            }) as hf_model:
+    elif "Zamba2" in model:
+        example_prompts.pop(7)
+        dtype = "half"
+
+    model_kwargs = {
+        "use_mamba_kernels": False,  # mamba kernels are not installed so HF
+        # don't use them
+    }
+    if "Zamba2" in model:
+        # Zamba2 HF implementation automatically checks if mamba kernels are
+        # installed
+        model_kwargs = {}
+
+    with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
         non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)
 
     with vllm_runner(model,
0 commit comments