
Commit 6475679

fix pre-commit check

Signed-off-by: FENP <32334296+FENP@users.noreply.github.com>
1 parent 62de2ea

File tree

3 files changed: +75, -73 lines changed


vllm/attention/ops/common.py

Lines changed: 2 additions & 2 deletions

@@ -135,7 +135,7 @@ def cp_lse_ag_out_rs(
     cp_attn_lse: torch.Tensor,
     cp_group: GroupCoordinator,
     ctx: CPTritonContext = None,
-    return_lse = False,
+    return_lse=False,
 ):
     """
     cp_attn_out: [ B, H, D ]
@@ -162,7 +162,7 @@ def cp_lse_ag_out_rs(
     if return_lse:
         cp_num_heads = lse.shape[1] // cp_group.world_size
         cp_rank = cp_group.rank_in_group
-        lse = lse[:, cp_num_heads * cp_rank:cp_num_heads * (cp_rank + 1)]
+        lse = lse[:, cp_num_heads * cp_rank : cp_num_heads * (cp_rank + 1)]
         return out, lse
     return out
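For context on the return_lse path touched above, here is a minimal sketch of how the reformatted slice picks one rank's heads out of the all-gathered LSE. Shapes, values, and the rank are made up for illustration and are not part of the diff:

    import torch

    # Illustrative only: each of world_size ranks contributed H heads to the gathered LSE.
    world_size, H, B = 4, 2, 3
    lse = torch.arange(B * world_size * H, dtype=torch.float32).reshape(B, world_size * H)

    cp_rank = 2  # hypothetical rank_in_group
    cp_num_heads = lse.shape[1] // world_size
    local_lse = lse[:, cp_num_heads * cp_rank : cp_num_heads * (cp_rank + 1)]
    assert local_lse.shape == (B, cp_num_heads)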

vllm/config/model.py

Lines changed: 5 additions & 4 deletions

@@ -1209,21 +1209,22 @@ def verify_with_parallel_config(
                 "Supported models implement the `SupportsPP` interface."
             )
 
-        decode_context_parallel_size = \
-            parallel_config.decode_context_parallel_size
+        decode_context_parallel_size = parallel_config.decode_context_parallel_size
         if decode_context_parallel_size > 1 and not self.use_mla:
             total_num_kv_heads = self.get_total_num_kv_heads()
             assert tensor_parallel_size > total_num_kv_heads, (
                 f"tensor parallel size {tensor_parallel_size} must be greater "
                 f"than total num kv heads {total_num_kv_heads} when enable "
-                f"decode context parallel for GQA/MQA")
+                f"decode context parallel for GQA/MQA"
+            )
 
             max_dcp_size = tensor_parallel_size // total_num_kv_heads
             assert decode_context_parallel_size <= max_dcp_size, (
                 f"decode context parallel size must less than or equal to "
                 f"(tensor parallel size {tensor_parallel_size} // total "
                 f"num kv heads {total_num_kv_heads}) = {max_dcp_size}, "
-                f"but got {decode_context_parallel_size}")
+                f"but got {decode_context_parallel_size}"
+            )
 
     def get_sliding_window(self) -> Optional[int]:
         """Get the sliding window size from the HF text config if present."""

vllm/v1/attention/backends/flash_attn.py

Lines changed: 68 additions & 67 deletions

@@ -340,28 +340,32 @@ def schedule(
         prefix_scheduler_metadata = None
 
         if self.dcp_world_size > 1:
-            query_kv_lens_cpu = common_attn_metadata.query_start_loc_cpu[1:] \
+            query_kv_lens_cpu = (
+                common_attn_metadata.query_start_loc_cpu[1:]
                 - common_attn_metadata.query_start_loc_cpu[:-1]
+            )
             dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
-            dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu \
-                // self.dcp_world_size + (self.dcp_rank \
-                <= (dcp_context_kv_lens_cpu-1) % self.dcp_world_size)
+            dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + (
+                self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size
+            )
             dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
             max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
 
-            scheduler_metadata = schedule(batch_size=num_reqs,
-                                          cu_query_lens=query_start_loc,
-                                          max_query_len=max_query_len,
-                                          seqlens=dcp_context_kv_lens,
-                                          max_seq_len=max_dcp_context_kv_len,
-                                          causal=False)
+            scheduler_metadata = schedule(
+                batch_size=num_reqs,
+                cu_query_lens=query_start_loc,
+                max_query_len=max_query_len,
+                seqlens=dcp_context_kv_lens,
+                max_seq_len=max_dcp_context_kv_len,
+                causal=False,
+            )
         elif use_cascade:
-            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
-                                                dtype=torch.int32,
-                                                device=self.device)
-            prefix_kv_lens = torch.tensor([common_prefix_len],
-                                          dtype=torch.int32,
-                                          device=self.device)
+            cu_prefix_query_lens = torch.tensor(
+                [0, num_actual_tokens], dtype=torch.int32, device=self.device
+            )
+            prefix_kv_lens = torch.tensor(
+                [common_prefix_len], dtype=torch.int32, device=self.device
+            )
             suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
                 self.device, non_blocking=True
            )
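For reference, the rewrapped statements in this hunk first recover per-request query lengths as adjacent differences of the cumulative query_start_loc offsets, then apply the per-rank context-KV-length expression unchanged. A small CPU-only sketch with made-up values (none of these numbers come from the diff):

    import torch

    query_start_loc_cpu = torch.tensor([0, 2, 3, 7])  # cumulative query offsets, 3 requests
    seq_lens_cpu = torch.tensor([10, 5, 9])           # total tokens per request
    dcp_world_size, dcp_rank = 4, 1

    # Adjacent differences give the number of new query tokens per request.
    query_kv_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]  # [2, 1, 4]

    # Context tokens (everything before this step's queries), then this rank's
    # share as computed by the expression in the diff.
    dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu              # [8, 4, 5]
    dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // dcp_world_size + (
        dcp_rank <= (dcp_context_kv_lens_cpu - 1) % dcp_world_size
    )
    print(dcp_context_kv_lens_cpu.tolist())  # [3, 2, 1]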
@@ -683,60 +687,57 @@ def _forward_with_dcp(
 
         query = query.contiguous()
         query_across_dcp = get_dcp_group().all_gather(query, dim=1)
-        context_attn_out, context_lse = \
-            flash_attn_varlen_func(
-                q=query_across_dcp,
-                k=key_cache,
-                v=value_cache,
-                out=None,
-                cu_seqlens_q=cu_seqlens_q,
-                max_seqlen_q=max_seqlen_q,
-                seqused_k=attn_metadata.dcp_context_kv_lens,
-                max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
-                softmax_scale=self.scale,
-                causal=False,
-                alibi_slopes=self.alibi_slopes,
-                window_size=self.sliding_window,
-                block_table=block_table,
-                softcap=self.logits_soft_cap,
-                return_softmax_lse=True,
-                scheduler_metadata=attn_metadata.scheduler_metadata,
-                fa_version=self.vllm_flash_attn_version,
-                q_descale=q_descale,
-                k_descale=k_descale,
-                v_descale=v_descale,
-            )
+        context_attn_out, context_lse = flash_attn_varlen_func(
+            q=query_across_dcp,
+            k=key_cache,
+            v=value_cache,
+            out=None,
+            cu_seqlens_q=cu_seqlens_q,
+            max_seqlen_q=max_seqlen_q,
+            seqused_k=attn_metadata.dcp_context_kv_lens,
+            max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
+            softmax_scale=self.scale,
+            causal=False,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            block_table=block_table,
+            softcap=self.logits_soft_cap,
+            return_softmax_lse=True,
+            scheduler_metadata=attn_metadata.scheduler_metadata,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=q_descale,
+            k_descale=k_descale,
+            v_descale=v_descale,
+        )
         # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
-        context_attn_out_cor, context_lse_cor = \
-            cp_lse_ag_out_rs(
-                context_attn_out,
-                context_lse.transpose(0, 1),
-                get_dcp_group(),
-                return_lse=True
-            )
+        context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
+            context_attn_out,
+            context_lse.transpose(0, 1),
+            get_dcp_group(),
+            return_lse=True,
+        )
         context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()
 
-        query_attn_out, query_lse = \
-            flash_attn_varlen_func(
-                q=query,
-                k=key,
-                v=value,
-                out=None,
-                cu_seqlens_q=cu_seqlens_q,
-                max_seqlen_q=max_seqlen_q,
-                cu_seqlens_k=cu_seqlens_q,
-                max_seqlen_k=max_seqlen_q,
-                softmax_scale=self.scale,
-                causal=attn_metadata.causal,
-                alibi_slopes=self.alibi_slopes,
-                window_size=self.sliding_window,
-                softcap=self.logits_soft_cap,
-                return_softmax_lse=True,
-                fa_version=self.vllm_flash_attn_version,
-                q_descale=q_descale,
-                k_descale=k_descale,
-                v_descale=v_descale,
-            )
+        query_attn_out, query_lse = flash_attn_varlen_func(
+            q=query,
+            k=key,
+            v=value,
+            out=None,
+            cu_seqlens_q=cu_seqlens_q,
+            max_seqlen_q=max_seqlen_q,
+            cu_seqlens_k=cu_seqlens_q,
+            max_seqlen_k=max_seqlen_q,
+            softmax_scale=self.scale,
+            causal=attn_metadata.causal,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            softcap=self.logits_soft_cap,
+            return_softmax_lse=True,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=q_descale,
+            k_descale=k_descale,
+            v_descale=v_descale,
+        )
         assert context_attn_out_cor.shape == query_attn_out.shape
         assert context_lse_cor.shape == query_lse.shape
         merge_attn_states(
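The restructured calls above produce two partial attention results for the same query tokens (one non-causal over this rank's slice of the cached context, one causal over the current tokens), which merge_attn_states then combines using their log-sum-exp statistics. As a rough illustration of that merge rule only, and not of the actual kernel or its tensor layout, here is a toy LSE-weighted combination:

    import torch

    def merge_two_states(out_a, lse_a, out_b, lse_b):
        # Each partial output is a softmax-weighted sum over its own KV subset;
        # the LSE values let the two be reweighted into one exact result.
        lse = torch.logaddexp(lse_a, lse_b)
        w_a = torch.exp(lse_a - lse).unsqueeze(-1)
        w_b = torch.exp(lse_b - lse).unsqueeze(-1)
        return out_a * w_a + out_b * w_b, lse

    # Toy shapes: 3 query tokens, 2 heads, head dim 4.
    out_a, out_b = torch.randn(3, 2, 4), torch.randn(3, 2, 4)
    lse_a, lse_b = torch.randn(3, 2), torch.randn(3, 2)
    merged, merged_lse = merge_two_states(out_a, lse_a, out_b, lse_b)
    print(merged.shape, merged_lse.shape)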
