@@ -24,20 +24,21 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out,
                                torch::Tensor const& seq_lens,
                                torch::Tensor const& page_table, double scale);
 #else
-// define fallback stubs
-void cutlass_mla_decode_sm100a(torch::Tensor const& out,
-                               torch::Tensor const& q_nope,
-                               torch::Tensor const& q_pe,
-                               torch::Tensor const& kv_c_and_k_pe_cache,
-                               torch::Tensor const& seq_lens,
-                               torch::Tensor const& page_table, double scale) {
+// fallback stubs
+void sm100_cutlass_mla_decode(
+    torch::Tensor const& out, torch::Tensor const& q_nope,
+    torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache,
+    torch::Tensor const& seq_lens, torch::Tensor const& page_table,
+    torch::Tensor const& workspace, double sm_scale,
+    int64_t num_kv_splits =
+        1 /* Set to 1 to avoid cuda_graph issue by default. */) {
   TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA");
 }
 
-int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len,
-                                             int64_t num_batches,
-                                             int64_t sm_count,
-                                             int64_t num_kv_splits) {
+int64_t sm100_cutlass_mla_get_workspace_size(
+    int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0,
+    int64_t num_kv_splits =
+        1 /* Set to 1 to avoid cuda_graph issue by default. */) {
   TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA");
 }
 #endif
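
For context, here is a minimal caller-side sketch (not part of the patch) of how the two declarations above fit together: query the workspace size, allocate a byte buffer on the right device, then invoke the decode entry point. The tensor arguments, the 4096 max sequence length, and the 1/sqrt(576) scale are illustrative assumptions, not values taken from vLLM; the example assumes the header containing the declarations above is included.

#include <cmath>
#include <torch/torch.h>

void mla_decode_example(torch::Tensor const& out, torch::Tensor const& q_nope,
                        torch::Tensor const& q_pe,
                        torch::Tensor const& kv_c_and_k_pe_cache,
                        torch::Tensor const& seq_lens,
                        torch::Tensor const& page_table) {
  int64_t const max_seq_len = 4096;             // assumed upper bound
  int64_t const num_batches = seq_lens.size(0);
  // sm_count = 0 defers to the kernel's own choice; num_kv_splits = 1
  // matches the CUDA-graph-safe default in the signatures above.
  int64_t const bytes = sm100_cutlass_mla_get_workspace_size(
      max_seq_len, num_batches, /*sm_count=*/0, /*num_kv_splits=*/1);

  // The workspace must live on the same device as the inputs.
  torch::Tensor workspace = torch::empty(
      {bytes},
      torch::TensorOptions().dtype(torch::kUInt8).device(out.device()));

  double const sm_scale = 1.0 / std::sqrt(576.0);  // assumed head-dim scaling
  // On a build without CUTLASS MLA this resolves to the stub above and
  // throws via TORCH_CHECK_NOT_IMPLEMENTED.
  sm100_cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens,
                           page_table, workspace, sm_scale,
                           /*num_kv_splits=*/1);
}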