16 changes: 11 additions & 5 deletions csrc/attention/mla/sm100_cutlass_mla_kernel.cu
@@ -99,6 +99,7 @@ struct MlaSm100 {
template <typename T>
typename T::Fmha::Arguments args_from_options(
at::Tensor const& out,
at::Tensor const& lse,
at::Tensor const& q_nope,
at::Tensor const& q_pe,
at::Tensor const& kv_c_and_k_pe_cache,
@@ -162,7 +163,10 @@ typename T::Fmha::Arguments args_from_options(
stride_PT,
page_count_total,
page_size},
{static_cast<ElementOut*>(out.data_ptr()), stride_O, static_cast<ElementAcc*>(nullptr), stride_LSE},
{static_cast<ElementOut*>(out.data_ptr()),
stride_O,
static_cast<ElementAcc*>(lse.defined() ? lse.data_ptr() : nullptr),
stride_LSE},
hw_info,
// TODO(trevor-m): Change split_kv back to -1 when
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
@@ -181,6 +185,7 @@ typename T::Fmha::Arguments args_from_options(
template <typename Element, bool IsPaged128, typename PersistenceOption>
void runMla(
at::Tensor const& out,
at::Tensor const& lse,
at::Tensor const& q_nope,
at::Tensor const& q_pe,
at::Tensor const& kv_c_and_k_pe_cache,
@@ -192,7 +197,7 @@ void runMla(
cudaStream_t stream) {
using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
typename MlaSm100Type::Fmha fmha;
auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
auto arguments = args_from_options<MlaSm100Type>(out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);

CUTLASS_CHECK(fmha.can_implement(arguments));

@@ -214,6 +219,7 @@ void runMla(

void sm100_cutlass_mla_decode(
torch::Tensor const& out,
torch::Tensor const& lse,
torch::Tensor const& q_nope,
torch::Tensor const& q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
@@ -234,13 +240,13 @@ void sm100_cutlass_mla_decode(
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
if (in_dtype == at::ScalarType::Half) {
runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
} else if (in_dtype == at::ScalarType::BFloat16) {
runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
} else {
TORCH_CHECK(false, "Unsupported input data type of MLA");
}
3 changes: 2 additions & 1 deletion csrc/torch_bindings.cpp
@@ -520,7 +520,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

// SM100 CUTLASS MLA decode
ops.def(
"sm100_cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
"sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope, "
"Tensor q_pe,"
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
" Tensor page_table, Tensor workspace, float "
"scale,"
6 changes: 3 additions & 3 deletions vllm/_custom_ops.py
@@ -1843,13 +1843,13 @@ def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
return out


def sm100_cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
q_pe: torch.Tensor,
def sm100_cutlass_mla_decode(out: torch.Tensor, lse: torch.Tensor,
q_nope: torch.Tensor, q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
seq_lens: torch.Tensor, page_table: torch.Tensor,
workspace: torch.Tensor, scale: float,
num_kv_splits: int) -> torch.Tensor:
torch.ops._C.sm100_cutlass_mla_decode(out, q_nope, q_pe,
torch.ops._C.sm100_cutlass_mla_decode(out, lse, q_nope, q_pe,
kv_c_and_k_pe_cache, seq_lens,
page_table, workspace, scale,
num_kv_splits)
70 changes: 70 additions & 0 deletions vllm/attention/ops/merge_attn_states.py
@@ -41,3 +41,73 @@ def supported_headdim(o: torch.Tensor) -> bool:
merge_attn_states)
return merge_attn_states(output, prefix_output, prefix_lse,
suffix_output, suffix_lse, output_lse)


def merge_multi_attn_states(partials: torch.Tensor,
lse: torch.Tensor) -> torch.Tensor:
"""Merge attention partials across a parallel dimension using LSE.

Args:
partials: [tp, B, H_owned, D]
lse: [tp, B, H_owned]

Returns:
merged: [B, H_owned, D]
"""
assert partials.dim() == 4 and lse.dim() == 3, (
f"partials shape {partials.shape}, lse shape {lse.shape}")
tp, batch_size, heads_owned, dim = partials.shape
# [tp, B, H_owned] -> [B, H_owned]
max_lse, _ = torch.max(lse, dim=0)
# Avoid -inf producing NaNs
max_lse = torch.where(torch.isfinite(max_lse), max_lse,
torch.zeros_like(max_lse))

# Compute exp-corrected weights and normalize across tp
# [tp, B, H_owned]
weights = torch.exp(lse - max_lse.unsqueeze(0))
denom = torch.clamp(weights.sum(dim=0, keepdim=False), min=1e-20)
weights = weights / denom

# Apply weights to partials: broadcast weights to dim
# [tp, B, H_owned, D]
weighted = partials * weights.unsqueeze(-1)
merged = weighted.sum(dim=0)
return merged


def reduce_lse_over_tp(lse: torch.Tensor) -> torch.Tensor:
"""Reduce per-rank LSE across TP via stable log-sum-exp.

Args:
lse: [tp, B, H_owned]

Returns:
reduced_lse: [B, H_owned]
"""
assert lse.dim() == 3
tp_max, _ = torch.max(lse, dim=0)
tp_max = torch.where(torch.isfinite(tp_max), tp_max,
torch.zeros_like(tp_max))
weights = torch.exp(lse - tp_max.unsqueeze(0))
denom = torch.clamp(weights.sum(dim=0, keepdim=False), min=1e-20)
return torch.log(denom) + tp_max


def merge_multi_attn_states_with_lse(
partials: torch.Tensor,
lse: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Fused helper that returns merged outputs and reduced LSE.

Args:
partials: [tp, B, H_owned, D]
lse: [tp, B, H_owned]

Returns:
(merged, reduced_lse):
merged: [B, H_owned, D]
reduced_lse: [B, H_owned]
"""
merged = merge_multi_attn_states(partials, lse)
reduced = reduce_lse_over_tp(lse)
return merged, reduced
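
A sanity-check sketch, not part of this diff: it verifies the LSE-weighted merge above against attention computed directly over the concatenated KV shards, using synthetic per-shard scores and values whose shapes follow the docstrings ([tp, B, H_owned, ...]).

import torch
from vllm.attention.ops.merge_attn_states import (merge_multi_attn_states,
                                                  reduce_lse_over_tp)

# Synthetic per-shard scores/values; each "rank" owns S KV positions.
tp, B, H, S, D = 4, 2, 3, 5, 8
scores = torch.randn(tp, B, H, S)
values = torch.randn(tp, B, H, S, D)

# Per-rank partial outputs and the log-sum-exp of each shard's scores.
partials = (torch.softmax(scores, dim=-1).unsqueeze(-2) @ values).squeeze(-2)
lse = torch.logsumexp(scores, dim=-1)

merged = merge_multi_attn_states(partials, lse)   # [B, H, D]
reduced = reduce_lse_over_tp(lse)                 # [B, H]

# Reference: one softmax over the concatenation of all shards.
full_scores = scores.permute(1, 2, 0, 3).reshape(B, H, tp * S)
full_values = values.permute(1, 2, 0, 3, 4).reshape(B, H, tp * S, D)
ref = (torch.softmax(full_scores, dim=-1).unsqueeze(-2) @ full_values).squeeze(-2)

torch.testing.assert_close(merged, ref, rtol=1e-5, atol=1e-5)
torch.testing.assert_close(reduced, torch.logsumexp(full_scores, dim=-1),
                           rtol=1e-5, atol=1e-5)
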
8 changes: 8 additions & 0 deletions vllm/config/__init__.py
@@ -464,6 +464,13 @@ class ModelConfig:
- "transformers" will use the Transformers model implementation."""
override_attention_dtype: Optional[str] = None
"""Override dtype for attention"""
enable_mla_sharded_kv: bool = False
"""Enable MLA sharded KV mode for tensor parallelism.

When enabled with tensor parallelism (TP > 1), MLA decode gathers query
tensors across TP ranks to form the full query on each rank. Without this
flag, MLA with TP > 1 is disallowed to avoid silent fallbacks.
"""

def compute_hash(self) -> str:
"""
@@ -490,6 +497,7 @@ def compute_hash(self) -> str:
factors.append(self.generation_config)
factors.append(self.model_impl)
factors.append(self.override_generation_config)
factors.append(self.enable_mla_sharded_kv)
factors.append(self.rope_scaling)
factors.append(self.rope_theta)
# hf_config can control how the model looks!
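
A hedged usage sketch, not from this PR: the --enable-mla-sharded-kv CLI flag comes from the vllm/engine/arg_utils.py change later in this diff, while passing the same option as an LLM keyword assumes the new ModelConfig field is plumbed through EngineArgs. The model name and TP size are placeholders.

# CLI (flag added in this PR; model and TP size are placeholders):
#   vllm serve deepseek-ai/DeepSeek-V2-Lite \
#       --tensor-parallel-size 8 --enable-mla-sharded-kv
#
# Python API (assumes the kwarg is forwarded to EngineArgs/ModelConfig):
from vllm import LLM

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",
          tensor_parallel_size=8,
          enable_mla_sharded_kv=True)
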
17 changes: 17 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -400,6 +400,23 @@ def all_gatherv(self,
raise ValueError("No device communicator found")
return self.device_communicator.all_gatherv(input_, dim, sizes)

def all_to_all(self, input_: torch.Tensor, dim: int = 0) -> torch.Tensor:
"""All-to-all over the device group, splitting along dim equally.

Note: This is a simple wrapper for torch.distributed.all_to_all_single
with equal splits across ranks.
"""
world_size = self.world_size
if world_size == 1:
return input_
if dim < 0:
dim += input_.dim()
x = input_.movedim(dim, 0).contiguous()
assert x.shape[0] % world_size == 0
out = torch.empty_like(x)
torch.distributed.all_to_all_single(out, x, group=self.device_group)
return out.movedim(0, dim).contiguous()

def reduce_scatter(self,
input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
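
For reference, a single-process sketch, not part of this diff, of the semantics the all_to_all wrapper above relies on: with equal splits, rank d ends up with the d-th chunk (along dim) from every rank, concatenated in source-rank order.

import torch

def all_to_all_reference(inputs: list[torch.Tensor],
                         dim: int = 0) -> list[torch.Tensor]:
    # inputs[i] plays the role of rank i's tensor; the size along `dim`
    # must be divisible by the number of "ranks".
    world_size = len(inputs)
    chunks = [x.movedim(dim, 0).chunk(world_size, dim=0) for x in inputs]
    return [
        torch.cat([chunks[src][dst] for src in range(world_size)],
                  dim=0).movedim(0, dim)
        for dst in range(world_size)
    ]

# Example with two "ranks", each holding a [4, 3] tensor split along dim 0:
# outs[0] is rows 0-1 of rank 0 followed by rows 0-1 of rank 1, and so on.
outs = all_to_all_reference([torch.arange(12).reshape(4, 3),
                             torch.arange(12, 24).reshape(4, 3)])
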
2 changes: 2 additions & 0 deletions vllm/engine/arg_utils.py
@@ -547,6 +547,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
**model_kwargs["model_impl"])
model_group.add_argument("--override-attention-dtype",
**model_kwargs["override_attention_dtype"])
model_group.add_argument("--enable-mla-sharded-kv",
**model_kwargs["enable_mla_sharded_kv"])

# Model loading arguments
load_kwargs = get_kwargs(LoadConfig)