@@ -337,13 +337,19 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
337337        "target_attn_1" : mock .MagicMock (),
338338        "target_attn_2" : mock .MagicMock ()
339339    }
340+     target_indx_layers : dict [str , mock .MagicMock ] =  {}
340341    # Draft model has one extra attention layer compared to target model 
341342    all_attn_layers  =  {
342343        ** target_attn_layers , "draft_extra_attn" : mock .MagicMock ()
343344    }
344345
346+     all_indx_layers : dict [str , mock .MagicMock ] =  {}
347+ 
345348    # Make mock_get_layers return different values for each call 
346-     mock_get_layers .side_effect  =  [target_attn_layers , all_attn_layers ]
349+     mock_get_layers .side_effect  =  [
350+         target_attn_layers , target_indx_layers , all_attn_layers ,
351+         all_indx_layers 
352+     ]
347353
348354    # Setup mock for pp group to return the appropriate value for world size 
349355    mock_pp_group  =  mock .MagicMock ()
@@ -658,6 +664,9 @@ def create_deterministic_logits(token_ids, k: int):
658664    # Mock runner for attention metadata building. 
659665    proposer .runner  =  mock .MagicMock ()
660666    proposer .runner .attn_groups .append ([mock .MagicMock ()])
667+     proposer .runner .attn_groups [0 ][0 ].metadata_builders  =  [
668+         attn_metadata_builder 
669+     ]
661670    proposer .runner .attn_groups [0 ][0 ].get_metadata_builder .return_value  =  \
662671        attn_metadata_builder 
663672    proposer ._get_attention_metadata_builder  =  mock .MagicMock (
0 commit comments