@@ -482,41 +482,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "                   Tensor page_table, float scale) -> ()");
   ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
 
-  // Mamba selective scan kernel
-  ops.def(
-      "selective_scan_fwd(Tensor! u, Tensor! delta,"
-      "Tensor! A, Tensor! B, Tensor! C,"
-      "Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
-      "bool delta_softplus,"
-      "Tensor? query_start_loc,"
-      "Tensor? cache_indices,"
-      "Tensor? has_initial_state,"
-      "Tensor! ssm_states,"
-      "int pad_slot_id) -> ()");
-  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
-
-  ops.def(
-      "causal_conv1d_update(Tensor! x,"
-      "Tensor! conv_state,"
-      "Tensor! weight,"
-      "Tensor? bias_,"
-      "bool silu_activation,"
-      "Tensor? cache_seqlens_,"
-      "Tensor? conv_state_indices,"
-      "int pad_slot_id) -> ()");
-  ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
-
-  ops.def(
-      "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
-      "Tensor? bias_,"
-      "Tensor!? conv_states,"
-      "Tensor? query_start_loc,"
-      "Tensor? cache_indices,"
-      "Tensor? has_initial_state,"
-      "bool silu_activation,"
-      "int pad_slot_id) -> ()");
-  ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
-
   // Compute NVFP4 block quantized tensor.
   ops.def(
       "scaled_fp4_quant(Tensor! output, Tensor input,"
@@ -584,6 +549,41 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
            &dynamic_scaled_int8_quant);
 
+  // Mamba selective scan kernel
+  ops.def(
+      "selective_scan_fwd(Tensor! u, Tensor! delta,"
+      "Tensor! A, Tensor! B, Tensor! C,"
+      "Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
+      "bool delta_softplus,"
+      "Tensor? query_start_loc,"
+      "Tensor? cache_indices,"
+      "Tensor? has_initial_state,"
+      "Tensor! ssm_states,"
+      "int pad_slot_id) -> ()");
+  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
+
+  ops.def(
+      "causal_conv1d_update(Tensor! x,"
+      "Tensor! conv_state,"
+      "Tensor! weight,"
+      "Tensor? bias_,"
+      "bool silu_activation,"
+      "Tensor? cache_seqlens_,"
+      "Tensor? conv_state_indices,"
+      "int pad_slot_id) -> ()");
+  ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
+
+  ops.def(
+      "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
+      "Tensor? bias_,"
+      "Tensor!? conv_states,"
+      "Tensor? query_start_loc,"
+      "Tensor? cache_indices,"
+      "Tensor? has_initial_state,"
+      "bool silu_activation,"
+      "int pad_slot_id) -> ()");
+  ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
+
 #ifndef USE_ROCM
   // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
   ops.def(
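The two hunks above only relocate the Mamba registrations (selective_scan_fwd, causal_conv1d_update, causal_conv1d_fwd) so they sit after dynamic_scaled_int8_quant instead of after cutlass_mla_decode; the schemas and kernel bindings themselves are unchanged. For readers unfamiliar with the pattern: ops.def() declares a schema string in which Tensor! marks an argument the kernel mutates in place, Tensor? marks an optional argument, Tensor!? combines the two, and -> () means the op returns nothing; ops.impl() then binds a C++ kernel to a dispatch key such as torch::kCUDA. (TORCH_LIBRARY_EXPAND here appears to be a thin wrapper that forwards to PyTorch's TORCH_LIBRARY macro.) Below is a minimal self-contained sketch of the same registration pattern, assuming a hypothetical CPU op my_ext::inplace_scale rather than any op from this file:

// Minimal sketch of the TORCH_LIBRARY registration pattern used above.
// The namespace "my_ext" and op "inplace_scale" are hypothetical.
#include <optional>

#include <torch/library.h>
#include <torch/torch.h>

// Schema "inplace_scale(Tensor! x, Tensor? bias_, float scale) -> ()":
//   Tensor!  - x is mutated in place by the kernel
//   Tensor?  - bias_ is optional (std::optional<Tensor> on the C++ side)
//   float    - arrives as double on the C++ side
//   -> ()    - the op returns nothing
void inplace_scale(const torch::Tensor& x,
                   const std::optional<torch::Tensor>& bias_,
                   double scale) {
  x.mul_(scale);  // in-place update, as declared by Tensor!
  if (bias_.has_value()) {
    x.add_(*bias_);  // optional argument, as declared by Tensor?
  }
}

TORCH_LIBRARY(my_ext, ops) {
  ops.def("inplace_scale(Tensor! x, Tensor? bias_, float scale) -> ()");
  ops.impl("inplace_scale", torch::kCPU, &inplace_scale);
}

Once such an extension library is loaded, the op becomes callable from Python as torch.ops.my_ext.inplace_scale(x, bias, 2.0); the ops registered in the diff are reachable the same way under the extension's own namespace.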