Skip to content

Commit cf5c0f8

Browse files
Merge pull request vllm-project#3 from jiangkuaixue123/jcz_afd_v0.11.0rc3_dev
fix log for graph
2 parents 3552481 + 658580f commit cf5c0f8

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

vllm/model_executor/models/deepseek_v2.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
895895
config = vllm_config.model_config.hf_config
896896
quant_config = vllm_config.quant_config
897897
self.config = config
898+
self.enforce_eager = vllm_config.model_config.enforce_eager
898899
self.first_k_dense_replace = config.first_k_dense_replace
899900
self.afd_config = vllm_config.afd_config
900901
self.connector_name = self.afd_config.afd_connector if self.afd_config is not None else None
@@ -958,15 +959,17 @@ def forward_m2n(
958959
if layer.layer_idx < self.first_k_dense_replace:
959960
hidden_states, residual = layer(positions, hidden_states, residual)
960961
continue
961-
962-
logger.info(f"jcz deepseekv2 layer_idx:{layer.layer_idx} metadata:{afd_metadata} hidden_states:{hidden_states.shape}")
962+
963963
afd_connector = afd_metadata.afd_connector
964964
afd_metadata.afd_stage_idx = dbo_current_ubatch_id()
965965
start_idx = afd_metadata.afd_tokens_start_loc[afd_metadata.afd_stage_idx]
966966
end_idx = start_idx + afd_metadata.afd_tokens_lens[afd_metadata.afd_stage_idx]
967-
logger.info(f"jcz deepseekv2 layer_idx:{layer.layer_idx} start_loc:{afd_metadata.afd_tokens_start_loc} "
968-
f"start_idx:{start_idx} end_idx:{end_idx} "
969-
f"stage_idx:{afd_metadata.afd_stage_idx}")
967+
968+
if self.enforce_eager:
969+
logger.info(f"jcz deepseekv2 layer_idx:{layer.layer_idx} metadata:{afd_metadata} hidden_states:{hidden_states.shape}")
970+
logger.info(f"jcz deepseekv2 layer_idx:{layer.layer_idx} start_loc:{afd_metadata.afd_tokens_start_loc} "
971+
f"start_idx:{start_idx} end_idx:{end_idx} "
972+
f"stage_idx:{afd_metadata.afd_stage_idx}")
970973

971974
if recv_handle is not None:
972975
for work in recv_handle:
@@ -1008,7 +1011,7 @@ def forward_m2n(
10081011
if self.connector_name == "m2nconnector":
10091012
handle = afd_connector.send_attn_output(current_hidden,topk_weights,topk_ids,metadata)
10101013
metadata.m2n_afdconnector_data.handle = handle
1011-
hidden_states, recv_handle = afd_connector.recv_ffn_output(hidden_states,metadata)
1014+
hidden_states = afd_connector.recv_ffn_output(hidden_states,metadata)
10121015
elif self.connector_name == "camconnector":
10131016
afd_connector.send_attn_output(current_hidden, topk_weights, topk_ids, metadata)
10141017
hidden_states = afd_connector.recv_ffn_output(metadata)

0 commit comments

Comments
 (0)