Skip to content

Commit 79573f5

Browse files
committed
fix: oopsie
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
1 parent 954b5ed commit 79573f5

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

csrc/attention/mla/cutlass_mla_entry.cu

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -24,20 +24,21 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out,
24 24
torch::Tensor const& seq_lens,
25 25
torch::Tensor const& page_table, double scale);
26 26
#else
27-
// define fallback stubs
28-
void cutlass_mla_decode_sm100a(torch::Tensor const& out,
29-
torch::Tensor const& q_nope,
30-
torch::Tensor const& q_pe,
31-
torch::Tensor const& kv_c_and_k_pe_cache,
32-
torch::Tensor const& seq_lens,
33-
torch::Tensor const& page_table, double scale) {
27+
// fallback stubs
28+
void sm100_cutlass_mla_decode(
29+
torch::Tensor const& out, torch::Tensor const& q_nope,
30+
torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache,
31+
torch::Tensor const& seq_lens, torch::Tensor const& page_table,
32+
torch::Tensor const& workspace, double sm_scale,
33+
int64_t num_kv_splits =
34+
1 /* Set to 1 to avoid cuda_graph issue by default. */) {
34 35
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA");
35 36
}
36 37

37-
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len,
38-
int64_t num_batches,
39-
int64_t sm_count,
40-
int64_t num_kv_splits) {
38+
int64_t sm100_cutlass_mla_get_workspace_size(
39+
int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0,
40+
int64_t num_kv_splits =
41+
1 /* Set to 1 to avoid cuda_graph issue by default. */) {
41 42
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA");
42 43
}
43 44
#endif

0 commit comments

Comments
 (0)