@@ -74,15 +74,13 @@ def set_forward_context(attn_metadata: Any,
     if vllm_config.parallel_config.data_parallel_size > 1:
         dp_size = vllm_config.parallel_config.data_parallel_size
         dp_rank = vllm_config.parallel_config.data_parallel_rank
-        if attn_metadata is not None:
-            if hasattr(attn_metadata, "num_prefill_tokens"):
-                # for v0 attention backends
-                batchsize = attn_metadata.num_prefill_tokens + \
-                    attn_metadata.num_decode_tokens
-            else:
-                # for v1 attention backends
-                batchsize = attn_metadata.num_input_tokens
+        if attn_metadata is not None and hasattr(attn_metadata,
+                                                 "num_prefill_tokens"):
+            # for v0 attention backends
+            batchsize = attn_metadata.num_prefill_tokens + \
+                attn_metadata.num_decode_tokens
         else:
+            # for v1 attention backends or no attn_metadata
             batchsize = num_tokens
         num_tokens_across_dp = [0] * dp_size
         num_tokens_across_dp[dp_rank] = batchsize
@@ -124,7 +122,7 @@ def set_forward_context(attn_metadata: Any,
                     attn_metadata.num_decode_tokens
             else:
                 # for v1 attention backends
-                batchsize = attn_metadata.num_input_tokens
+                batchsize = num_tokens
             # we use synchronous scheduling right now,
             # adding a sync point here should not affect
             # scheduling of the next batch
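Both hunks apply the same rule: use the v0 metadata's prefill/decode token counts when they exist, otherwise fall back to the num_tokens argument (v1 attention backends, or no attn_metadata at all, as in dummy/profile runs). A minimal sketch of that selection logic is below; compute_batchsize and the SimpleNamespace stand-ins are hypothetical illustrations, not vLLM APIs.

    from types import SimpleNamespace

    def compute_batchsize(attn_metadata, num_tokens: int) -> int:
        # v0 attention backends expose prefill/decode token counts on the metadata
        if attn_metadata is not None and hasattr(attn_metadata,
                                                 "num_prefill_tokens"):
            return (attn_metadata.num_prefill_tokens +
                    attn_metadata.num_decode_tokens)
        # v1 attention backends, or no metadata at all (e.g. a dummy run),
        # fall back to the num_tokens value passed to set_forward_context
        return num_tokens

    # v0-style metadata: counts come from the metadata object itself
    v0_meta = SimpleNamespace(num_prefill_tokens=5, num_decode_tokens=3)
    assert compute_batchsize(v0_meta, num_tokens=8) == 8

    # v1-style metadata lacks num_prefill_tokens; None behaves the same way
    v1_meta = SimpleNamespace(num_input_tokens=16)
    assert compute_batchsize(v1_meta, num_tokens=16) == 16
    assert compute_batchsize(None, num_tokens=16) == 16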