@@ -130,13 +130,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
130130 " ) -> ()" );
131131 ops.impl (" advance_step_flashinfer" , torch::kCUDA , &advance_step_flashinfer);
132132
133- // Compute MLA decode using cutlass.
134- ops.def (
135- " cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
136- " Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
137- " Tensor page_table, float scale) -> ()" );
138- ops.impl (" cutlass_mla_decode" , torch::kCUDA , &cutlass_mla_decode);
139-
140133 // Layernorm
141134 // Apply Root Mean Square (RMS) Normalization to the input tensor.
142135 ops.def (
@@ -450,6 +443,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
450443 ops.def (" cutlass_sparse_compress(Tensor a) -> Tensor[]" );
451444 ops.impl (" cutlass_sparse_compress" , &cutlass_sparse_compress);
452445
446+ // CUTLASS MLA decode
447+ ops.def (
448+ " cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
449+ " Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
450+ " Tensor page_table, float scale) -> ()" );
451+ ops.impl (" cutlass_mla_decode" , torch::kCUDA , &cutlass_mla_decode);
452+
453453 // Mamba selective scan kernel
454454 ops.def (
455455 " selective_scan_fwd(Tensor! u, Tensor! delta,"
0 commit comments