 from ...utils import check_outputs_equal
 
 # This test is for the hybrid models
-MODELS = ["ai21labs/Jamba-tiny-dev"]
+MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"]
 # Bamba at Fp32 is too big for the CI (L4 GPU).
 # MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
 
@@ -27,17 +27,19 @@ def test_models(
 ) -> None:
 
     # numeric error produces different generation
-    if 'Bamba' in model:
+    if "Bamba" in model:
         example_prompts.pop(3)
 
-    with hf_runner(
-            model,
-            dtype=dtype,
-            model_kwargs={
-                "use_mamba_kernels":
-                False,  # mamba kernels are not installed so HF
-                # don't use them
-            }) as hf_model:
+    model_kwargs = {
+        "use_mamba_kernels": False,  # mamba kernels are not installed, so HF
+        # shouldn't use them
+    }
+    if "Zamba2" in model:
+        # the Zamba2 HF implementation checks on its own whether mamba
+        # kernels are installed
+        model_kwargs = {}
+
+    with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
 
     with vllm_runner(model, dtype=dtype) as vllm_model:
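
For context between the two hunks: the new pattern computes `model_kwargs` once and special-cases Zamba2, whose HF implementation detects mamba-kernel availability on its own. A minimal standalone sketch of that selection (the helper name `hf_model_kwargs` is hypothetical, not part of this change):

def hf_model_kwargs(model: str) -> dict:
    """Illustrative helper (hypothetical): pick HF model_kwargs per model."""
    if "Zamba2" in model:
        # Zamba2's HF implementation probes for mamba kernels itself,
        # so no override is needed.
        return {}
    # Other hybrid models must be told not to use the mamba kernels,
    # which are not installed in the CI image.
    return {"use_mamba_kernels": False}
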
@@ -112,26 +114,31 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
 def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
                                 model: str, dtype: str,
                                 max_tokens: int) -> None:
-    # numeric error during prefill chucking produces different generation
+    # numeric error during prefill chunking produces different generation
     # compared to w/o prefill chunking for those examples, removed them for now
-    if 'Jamba' in model:
+    if "Jamba" in model:
         example_prompts.pop(7)
         example_prompts.pop(2)
         example_prompts.pop(1)
-    elif 'Bamba' in model:
+    elif "Bamba" in model:
         example_prompts.pop(6)
         example_prompts.pop(3)
         example_prompts.pop(2)
         dtype = "half"  # use a different dtype for Bamba
-
-    with hf_runner(
-            model,
-            dtype=dtype,
-            model_kwargs={
-                "use_mamba_kernels":
-                False,  # mamba kernels are not installed so HF
-                # don't use them
-            }) as hf_model:
+    elif "Zamba2" in model:
+        example_prompts.pop(7)
+        dtype = "half"
+
+    model_kwargs = {
+        "use_mamba_kernels": False,  # mamba kernels are not installed, so HF
+        # shouldn't use them
+    }
+    if "Zamba2" in model:
+        # the Zamba2 HF implementation checks on its own whether mamba
+        # kernels are installed
+        model_kwargs = {}
+
+    with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
         non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)
 
     with vllm_runner(model,
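
The hunk is truncated at the `with vllm_runner(model,` call. For orientation, a hedged sketch of how the chunked-vs-non-chunked comparison plausibly concludes, reusing `check_outputs_equal` imported at the top of the file; the engine arguments and values below are assumptions for illustration, not taken from this diff:

# Hedged sketch, not part of the diff: run vLLM with chunked prefill
# enabled and assert parity with HF's non-chunked greedy outputs.
with vllm_runner(model,
                 dtype=dtype,
                 enable_chunked_prefill=True,  # assumed engine arg
                 max_num_batched_tokens=5,  # force tiny chunks (assumed value)
                 max_num_seqs=2) as vllm_model:
    chunked = vllm_model.generate_greedy(example_prompts, max_tokens)

check_outputs_equal(outputs_0_lst=chunked,
                    outputs_1_lst=non_chunked,
                    name_0="chunked",
                    name_1="non_chunked")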