@@ -895,6 +895,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
895895 config = vllm_config .model_config .hf_config
896896 quant_config = vllm_config .quant_config
897897 self .config = config
898+ self .enforce_eager = vllm_config .model_config .enforce_eager
898899 self .first_k_dense_replace = config .first_k_dense_replace
899900 self .afd_config = vllm_config .afd_config
900901 self .connector_name = self .afd_config .afd_connector if self .afd_config is not None else None
@@ -958,15 +959,17 @@ def forward_m2n(
958959 if layer .layer_idx < self .first_k_dense_replace :
959960 hidden_states , residual = layer (positions , hidden_states , residual )
960961 continue
961-
962- logger .info (f"jcz deepseekv2 layer_idx:{ layer .layer_idx } metadata:{ afd_metadata } hidden_states:{ hidden_states .shape } " )
962+
963963 afd_connector = afd_metadata .afd_connector
964964 afd_metadata .afd_stage_idx = dbo_current_ubatch_id ()
965965 start_idx = afd_metadata .afd_tokens_start_loc [afd_metadata .afd_stage_idx ]
966966 end_idx = start_idx + afd_metadata .afd_tokens_lens [afd_metadata .afd_stage_idx ]
967- logger .info (f"jcz deepseekv2 layer_idx:{ layer .layer_idx } start_loc:{ afd_metadata .afd_tokens_start_loc } "
968- f"start_idx:{ start_idx } end_idx:{ end_idx } "
969- f"stage_idx:{ afd_metadata .afd_stage_idx } " )
967+
968+ if self .enforce_eager :
969+ logger .info (f"jcz deepseekv2 layer_idx:{ layer .layer_idx } metadata:{ afd_metadata } hidden_states:{ hidden_states .shape } " )
970+ logger .info (f"jcz deepseekv2 layer_idx:{ layer .layer_idx } start_loc:{ afd_metadata .afd_tokens_start_loc } "
971+ f"start_idx:{ start_idx } end_idx:{ end_idx } "
972+ f"stage_idx:{ afd_metadata .afd_stage_idx } " )
970973
971974 if recv_handle is not None :
972975 for work in recv_handle :
@@ -1008,7 +1011,7 @@ def forward_m2n(
10081011 if self .connector_name == "m2nconnector" :
10091012 handle = afd_connector .send_attn_output (current_hidden ,topk_weights ,topk_ids ,metadata )
10101013 metadata .m2n_afdconnector_data .handle = handle
1011- hidden_states , recv_handle = afd_connector .recv_ffn_output (hidden_states ,metadata )
1014+ hidden_states = afd_connector .recv_ffn_output (hidden_states ,metadata )
10121015 elif self .connector_name == "camconnector" :
10131016 afd_connector .send_attn_output (current_hidden , topk_weights , topk_ids , metadata )
10141017 hidden_states = afd_connector .recv_ffn_output (metadata )
0 commit comments