                          GenerationConfig)

from vllm import LLM, SamplingParams
+from vllm.v1.executor.abstract import Executor
+from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
+                                         FullAttentionSpec)

from ....utils import multi_gpu_test

@@ -69,6 +72,26 @@ def run_maverick_serving(model: str):
        raise


+def get_rope_layers_config(model_path: str) -> list[int]:
+    """
+    Get the interleaved RoPE configuration from HuggingFace config
+
+    Args:
+        model_path: Path to the local directory containing the reduced
+            Maverick model checkpoint
+
+    Returns:
+        List with one entry per layer indicating whether that layer uses RoPE
+        and local attention: 0 means RoPE is not used, 1 means it is.
+    """
+    config_path = Path(model_path) / "config.json"
+    model_config = json.loads(config_path.read_text())
+    text_config = model_config["text_config"]
+    no_rope_layers = text_config["no_rope_layers"]
+    print(f"Found no_rope_layers: {no_rope_layers}")
+    return no_rope_layers
+
+
def create_reduced_maverick_model(
    original_model_name:
    str = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
@@ -113,7 +136,6 @@ def create_reduced_maverick_model(
        print("Loading original model configuration...")
        original_config = AutoConfig.from_pretrained(original_model_name,
                                                     trust_remote_code=True)
-
        print("Creating reduced configuration...")
        reduced_config = create_reduced_config(original_config, text_layers,
                                               num_experts, vision_layers)
@@ -510,21 +532,32 @@ def save_weights_to_safetensors(weights: dict[str, torch.Tensor],
          f"{index_data['metadata']['total_size'] / (1024**3):.2f} GB")


-def run_reduced_model(model_path: str,
-                      should_profile: bool = False,
-                      **kwargs) -> None:
-    """Test the created reduced model with vLLM."""
-
-    print(f"\nTesting reduced model at {model_path}...")
-
-    llm = LLM(
-        model=model_path,
-        trust_remote_code=True,
-        max_model_len=512,  # Small context for testing
-        gpu_memory_utilization=0.3,  # Conservative memory usage
-        **kwargs,
+def check_attention_spec_interleaved_rope(
+    llm: LLM,
+    num_attention_layers: int,
+    num_ranks: int,
+    rope_layers: list[int],
+):
+    """Check that each layer's KV cache spec matches its RoPE config."""
+    assert isinstance(llm.llm_engine.model_executor, Executor)
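+    # One KV cache spec mapping (layer name -> spec) is expected per rank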
+    kv_cache_specs_per_rank = llm.llm_engine.model_executor.get_kv_cache_specs(
    )
-
+    for rank in range(num_ranks):
+        kv_cache_specs = kv_cache_specs_per_rank[rank]
+        assert len(kv_cache_specs.keys()) == num_attention_layers
+        for i in range(num_attention_layers):
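+            # no_rope_layers[i] == 0: layer without RoPE -> full attention;
+            # otherwise: RoPE layer -> chunked local attention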
+            if rope_layers[i] == 0:
+                expected_spec = FullAttentionSpec
+            else:
+                expected_spec = ChunkedLocalAttentionSpec
+            assert isinstance(
+                kv_cache_specs[
+                    f"language_model.model.layers.{i}.self_attn.attn"],
+                expected_spec)
+
+
+def run_reduced_model(llm: LLM, should_profile: bool = False) -> None:
+    """Test the created reduced model with vLLM."""
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=50)
@@ -551,6 +584,7 @@ def run_reduced_model(model_path: str,
@pytest.mark.parametrize("tp,ep", [(2, True)])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_dummy_maverick(
+    monkeypatch,
    original_model_name: str,
    text_layers: int,
    num_experts: int,
@@ -562,6 +596,10 @@ def test_dummy_maverick(
    force_recreate: bool = True,
    profile: bool = False,
) -> None:
+    # Disabling multiprocessing allows access to the model executor from the LLM engine
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+
    model_path = create_reduced_maverick_model(
        original_model_name=original_model_name,
        output_dir=output_dir,
@@ -573,11 +611,27 @@ def test_dummy_maverick(

    print(f"\nReduced model created successfully at: {model_path}")

-    run_reduced_model(model_path=model_path,
-                      should_profile=profile,
-                      enforce_eager=enforce_eager,
-                      tensor_parallel_size=tp,
-                      enable_expert_parallel=ep)
+    rope_layers = get_rope_layers_config(model_path)
+
+    llm = LLM(
+        model=model_path,
+        trust_remote_code=True,
+        max_model_len=512,  # Small context for testing
+        gpu_memory_utilization=0.3,  # Conservative memory usage
+        enforce_eager=enforce_eager,
+        tensor_parallel_size=tp,
+        enable_expert_parallel=ep,
+    )
+
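+    # Verify the per-layer KV cache specs follow the interleaved RoPE pattern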
+    check_attention_spec_interleaved_rope(
+        llm,
+        text_layers,
+        tp,
+        rope_layers,
+    )
+
+    print(f"\nTesting reduced model at {model_path}...")
+    run_reduced_model(llm=llm, should_profile=profile)


def main():