@@ -117,34 +117,13 @@ def test_prepare_inputs():
117117    ]) 
118118@mock .patch ('vllm.v1.spec_decode.eagle.get_pp_group' ) 
119119@mock .patch ('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config' ) 
120- @mock .patch ('vllm.v1.spec_decode.eagle.ModelRegistry' ) 
121- @mock .patch ('vllm.v1.spec_decode.eagle.get_model_loader' ) 
122- @mock .patch ('vllm.v1.spec_decode.eagle.set_default_torch_dtype' ) 
123- @mock .patch ('vllm.v1.spec_decode.eagle.set_current_vllm_config' ) 
124- def  test_load_model (mock_set_config , mock_set_dtype , mock_get_loader ,
125-                     mock_registry , mock_get_layers , mock_get_pp_group , method ,
120+ @mock .patch ('vllm.v1.spec_decode.eagle.get_model' ) 
121+ def  test_load_model (mock_get_model , mock_get_layers , mock_get_pp_group , method ,
126122                    proposer_helper , draft_model_dir , target_attribute_path ):
127123
128-     # Setup mock for model class 
129-     mock_model_cls  =  mock .MagicMock ()
130-     mock_registry .resolve_model_cls .return_value  =  (mock_model_cls ,
131-                                                     "test_arch" )
132- 
133-     # Create a real context manager for mocks 
134-     class  MockContextManager :
135- 
136-         def  __init__ (self ):
137-             pass 
138- 
139-         def  __enter__ (self ):
140-             return  None 
141- 
142-         def  __exit__ (self , exc_type , exc_val , exc_tb ):
143-             return  False 
144- 
145-     # Make the mocks return actual context manager objects 
146-     mock_set_dtype .return_value  =  MockContextManager ()
147-     mock_set_config .return_value  =  MockContextManager ()
124+     # Setup model mock 
125+     mock_model  =  mock .MagicMock ()
126+     mock_get_model .return_value  =  mock_model 
148127
149128    # Setup mocks for attention layers 
150129    target_attn_layers  =  {
@@ -164,25 +143,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
164143    mock_pp_group .world_size  =  2  if  method  ==  "eagle"  else  1 
165144    mock_get_pp_group .return_value  =  mock_pp_group 
166145
167-     # Setup model loader mock 
168-     mock_loader  =  mock .MagicMock ()
169-     mock_get_loader .return_value  =  mock_loader 
170- 
171-     # Setup model mock 
172-     mock_model  =  mock .MagicMock ()
173-     mock_model_cls .return_value  =  mock_model 
174-     mock_model .to .return_value  =  mock_model 
175- 
176-     # Configure mock to test the attribute sharing path 
177-     if  method  ==  "eagle" :
178-         # For eagle, test the lm_head path 
179-         mock_model .load_weights .return_value  =  {
180-             "model.embed_tokens.weight" : torch .zeros (1 )
181-         }
182-     else :
183-         # For eagle3, test the embed_tokens path 
184-         mock_model .load_weights .return_value  =  {}
185- 
186146    # Setup target model with the appropriate attributes 
187147    target_model  =  mock .MagicMock ()
188148
@@ -204,13 +164,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
204164    proposer .load_model (target_model )
205165
206166    # Verify common interactions 
207-     mock_get_loader .assert_called_once ()
208-     mock_model_cls .assert_called_once ()
209-     mock_model .to .assert_called_once ()
210-     mock_model .load_weights .assert_called_once ()
211- 
212-     # Verify the loader was called with the right config 
213-     mock_get_loader .assert_called_once_with (proposer .vllm_config .load_config )
167+     mock_get_model .assert_called_once ()
214168
215169    # Verify the specific attribute sharing based on the method 
216170    if  method  ==  "eagle" :
0 commit comments