From 432e44679aefaae159f24c5dc4a89ade101500bc Mon Sep 17 00:00:00 2001
From: Yaxing Cai
Date: Thu, 2 Oct 2025 03:38:18 +0000
Subject: [PATCH 01/11] use tensor view

---
 csrc/batch_attention.cu                       | 22 +++-
 csrc/batch_attention_jit_binding.cu           | 23 ++---
 csrc/batch_decode.cu                          | 20 ++--
 csrc/batch_decode_jit_binding.cu              | 20 ++--
 csrc/batch_decode_mla_binding.cu              | 23 +++--
 csrc/batch_decode_mla_cute_sm80.cu            | 20 ++--
 csrc/batch_decode_mla_plan.cu                 |  8 +-
 csrc/batch_decode_mla_run.cu                  | 15 ++-
 csrc/batch_mla_binding.cu                     | 19 ++--
 csrc/batch_mla_plan.cu                        | 11 ++-
 csrc/batch_mla_run.cu                         |  8 +-
 csrc/batch_mla_sm90_binding.cu                | 20 ++--
 csrc/batch_mla_sm90_plan.cu                   | 11 ++-
 csrc/batch_mla_sm90_run.cu                    |  9 +-
 csrc/batch_prefill.cu                         | 32 ++++---
 csrc/batch_prefill_fp8_sm90.cu                | 38 ++++----
 csrc/batch_prefill_jit_binding.cu             | 32 ++++---
 csrc/batch_prefill_sm90.cu                    | 37 ++++---
 csrc/batch_prefill_sm90_jit_binding.cu        | 37 ++++---
 csrc/blackwell_fmha_plan.cu                   |  8 +-
 csrc/bmm_fp8.cu                               |  6 +-
 csrc/cascade.cu                               |  9 +-
 csrc/cudnn_sdpa_kernel_launcher.cu            | 43 +++++----
 csrc/cutlass_mla.cu                           |  6 +-
 csrc/flashinfer_cascade_binding.cu            |  9 +-
 csrc/flashinfer_gemm_binding.cu               | 10 +-
 csrc/flashinfer_gemm_sm90_binding.cu          |  9 +-
 csrc/flashinfer_mla_binding.cu                |  5 +-
 csrc/flashinfer_norm_binding.cu               | 12 +--
 csrc/flashinfer_page_binding.cu               | 23 +++--
 csrc/flashinfer_quantization_binding.cu       |  6 +-
 csrc/flashinfer_rope_binding.cu               | 50 +++++-----
 csrc/flashinfer_sampling_binding.cu           | 49 +++++-----
 csrc/flashinfer_xqa_binding.cu                | 13 +--
 csrc/fmha_cutlass_sm100.cu                    | 15 +--
 csrc/fmha_cutlass_sm100_binding.cu            | 20 ++--
 csrc/fp4_gemm_cutlass.cu                      | 23 +++--
 csrc/fp4_gemm_cutlass_sm120.cu                | 25 +++--
 csrc/fp8_gemm_cutlass.cu                      | 19 ++--
 ...shinfer_cutlass_fused_moe_sm100_binding.cu | 14 ++-
 csrc/gemm_groupwise_sm100.cu                  |  9 +-
 csrc/gemm_groupwise_sm120.cu                  | 10 +-
 csrc/gemm_sm100_binding.cu                    |  9 +-
 csrc/gemm_sm120_binding.cu                    |  8 +-
 csrc/group_gemm.cu                            |  6 +-
 csrc/group_gemm_fp8_groupwise_sm100.cu        | 12 +--
 csrc/group_gemm_fp8_groupwise_sm120.cu        |  7 +-
 csrc/group_gemm_mxfp4_groupwise_sm100.cu      |  9 +-
 csrc/group_gemm_sm100_binding.cu              | 21 ++--
 csrc/group_gemm_sm120_binding.cu              |  7 +-
 csrc/group_gemm_sm90.cu                       |  9 +-
 csrc/norm.cu                                  | 12 +--
 csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp  | 15 +--
 .../tensorrt_llm/thop/fp4Quantize.cpp         | 17 ++--
 .../tensorrt_llm/thop/fp4Quantize.h           | 18 ++--
 .../tensorrt_llm/thop/fp8Quantize.cpp         | 10 +-
 .../tensorrt_llm/thop/fp8Quantize.h           | 10 +-
 csrc/nvshmem_binding.cu                       | 27 +++---
 csrc/page.cu                                  | 23 +++--
 csrc/pod.cu                                   | 24 ++---
 csrc/pod_jit_binding.cu                       | 22 +++--
 csrc/quantization.cu                          |  8 +-
 csrc/renorm.cu                                | 12 +--
 csrc/rope.cu                                  | 40 ++++----
 csrc/sampling.cu                              | 37 +++----
 csrc/single_decode.cu                         |  4 +-
 csrc/single_decode_jit_binding.cu             |  4 +-
 csrc/single_prefill.cu                        |  8 +-
 csrc/single_prefill_fp8_sm90.cu               |  8 +-
 csrc/single_prefill_jit_binding.cu            |  8 +-
 csrc/single_prefill_sm90.cu                   |  8 +-
 csrc/single_prefill_sm90_jit_binding.cu       |  8 +-
 csrc/tgv_gemm.cu                              | 27 +++---
 csrc/trtllm_allreduce.cu                      | 22 ++---
 csrc/trtllm_allreduce_fusion.cu               | 16 ++--
 csrc/trtllm_alltoall.cu                       | 40 ++++----
 csrc/trtllm_fmha_kernel_launcher.cu           | 49 +++++-----
 csrc/trtllm_fused_moe_kernel_launcher.cu      | 96 ++++++++++---------
 csrc/trtllm_gemm_runner.cu                    |  7 +-
 csrc/trtllm_mnnvl_allreduce.cu                | 12 +--
 csrc/trtllm_moe_allreduce_fusion.cu           | 33 +++----
 csrc/tvm_ffi_utils.h                          | 12 ++-
 csrc/vllm_custom_all_reduce.cu                | 23 +++--
 csrc/xqa/xqa_wrapper.cu                       | 13 +--
 flashinfer/deep_gemm.py                       |  9 +-
 flashinfer/gemm.py                            |  5 +-
 86 files changed, 825 insertions(+), 748 deletions(-)

diff --git
a/csrc/batch_attention.cu b/csrc/batch_attention.cu index a9a3861dea..21c78e8a39 100644 --- a/csrc/batch_attention.cu +++ b/csrc/batch_attention.cu @@ -35,11 +35,12 @@ cudaError_t BatchPagedAttentionPersistent(const Params params_1, const Params pa using namespace flashinfer; -Array BatchPagedAttentionPlan(Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Tensor page_locked_int_workspace_buffer, Tensor qo_indptr, - Tensor kv_indptr, Tensor kv_len, int64_t batch_size, - int64_t num_qo_heads, int64_t num_kv_heads, - int64_t head_dim_o, bool causal) { +Array BatchPagedAttentionPlan(TensorView float_workspace_buffer, + TensorView int_workspace_buffer, + TensorView page_locked_int_workspace_buffer, + TensorView qo_indptr, TensorView kv_indptr, + TensorView kv_len, int64_t batch_size, int64_t num_qo_heads, + int64_t num_kv_heads, int64_t head_dim_o, bool causal) { size_t float_workspace_size_in_bytes = float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = @@ -63,11 +64,12 @@ Array BatchPagedAttentionPlan(Tensor float_workspace_buffer, Tensor int return Array(plan_info.ToVector()); } -void BatchPagedAttentionRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Array plan_info_vec, Tensor q, Tensor k_cache, Tensor v_cache, - Tensor kv_indices, Tensor o, Optional maybe_lse, - int64_t mask_mode_code, int64_t layout_code, int64_t num_qo_heads, - int64_t num_kv_heads, int64_t page_size, +void BatchPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_workspace_buffer, + Array plan_info_vec, TensorView q, TensorView k_cache, + TensorView v_cache, TensorView kv_indices, TensorView o, + Optional maybe_lse, int64_t mask_mode_code, + int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads, + int64_t page_size, double v_scale, // must use double due to pytorch binding double sm_scale, double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS) { diff --git a/csrc/batch_attention_jit_binding.cu b/csrc/batch_attention_jit_binding.cu index 2a4d558887..1b25eb0a48 100644 --- a/csrc/batch_attention_jit_binding.cu +++ b/csrc/batch_attention_jit_binding.cu @@ -19,18 +19,19 @@ using tvm::ffi::Array; using tvm::ffi::Optional; -Array BatchPagedAttentionPlan(Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Tensor page_locked_int_workspace_buffer, Tensor qo_indptr, - Tensor kv_indptr, Tensor kv_len, int64_t batch_size, - int64_t num_qo_heads, int64_t num_kv_heads, - int64_t head_dim_o, bool causal); +Array BatchPagedAttentionPlan(TensorView float_workspace_buffer, + TensorView int_workspace_buffer, + TensorView page_locked_int_workspace_buffer, + TensorView qo_indptr, TensorView kv_indptr, + TensorView kv_len, int64_t batch_size, int64_t num_qo_heads, + int64_t num_kv_heads, int64_t head_dim_o, bool causal); -void BatchPagedAttentionRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Array plan_info_vec, Tensor q, Tensor k_cache, Tensor v_cache, - Tensor kv_indices, Tensor o, Optional maybe_lse, - int64_t mask_mode_code, int64_t layout_code, int64_t num_qo_heads, - int64_t num_kv_heads, int64_t page_size, double v_scale, - double sm_scale, +void BatchPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_workspace_buffer, + Array plan_info_vec, TensorView q, TensorView k_cache, + TensorView v_cache, TensorView kv_indices, TensorView o, + Optional maybe_lse, int64_t mask_mode_code, + int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads, + int64_t page_size, 
double v_scale, double sm_scale, double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS); TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, &BatchPagedAttentionPlan); diff --git a/csrc/batch_decode.cu b/csrc/batch_decode.cu index 7ffa762e87..afb105f442 100644 --- a/csrc/batch_decode.cu +++ b/csrc/batch_decode.cu @@ -37,11 +37,11 @@ using tvm::ffi::Array; using tvm::ffi::Optional; Array BatchDecodeWithPagedKVCachePlan( - Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Tensor page_locked_int_workspace_buffer, Tensor indptr, int64_t batch_size, + TensorView float_workspace_buffer, TensorView int_workspace_buffer, + TensorView page_locked_int_workspace_buffer, TensorView indptr, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, int64_t head_dim_vo, - Tensor empty_q_data, Tensor empty_kv_data) { + TensorView empty_q_data, TensorView empty_kv_data) { size_t float_workspace_size_in_bytes = float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = @@ -78,12 +78,14 @@ Array BatchDecodeWithPagedKVCachePlan( return Array(plan_info.ToVector()); } -void BatchDecodeWithPagedKVCacheRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Array plan_info_vec, Tensor q, Tensor paged_k_cache, - Tensor paged_v_cache, Tensor paged_kv_indptr, - Tensor paged_kv_indices, Tensor paged_kv_last_page_len, - Tensor o, Optional maybe_lse, int64_t kv_layout_code, - int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS) { +void BatchDecodeWithPagedKVCacheRun(TensorView float_workspace_buffer, + TensorView int_workspace_buffer, Array plan_info_vec, + TensorView q, TensorView paged_k_cache, + TensorView paged_v_cache, TensorView paged_kv_indptr, + TensorView paged_kv_indices, TensorView paged_kv_last_page_len, + TensorView o, Optional maybe_lse, + int64_t kv_layout_code, int64_t window_left, + bool enable_pdl ADDITIONAL_FUNC_PARAMS) { DecodePlanInfo plan_info; plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); QKVLayout kv_layout = static_cast(kv_layout_code); diff --git a/csrc/batch_decode_jit_binding.cu b/csrc/batch_decode_jit_binding.cu index c621eaf1d6..0ce644fbc2 100644 --- a/csrc/batch_decode_jit_binding.cu +++ b/csrc/batch_decode_jit_binding.cu @@ -21,18 +21,20 @@ using tvm::ffi::Array; using tvm::ffi::Optional; Array BatchDecodeWithPagedKVCachePlan( - Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Tensor page_locked_int_workspace_buffer, Tensor indptr, int64_t batch_size, + TensorView float_workspace_buffer, TensorView int_workspace_buffer, + TensorView page_locked_int_workspace_buffer, TensorView indptr, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, int64_t head_dim_vo, - Tensor empty_q_data, Tensor empty_kv_data); + TensorView empty_q_data, TensorView empty_kv_data); -void BatchDecodeWithPagedKVCacheRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Array plan_info_vec, Tensor q, Tensor paged_k_cache, - Tensor paged_v_cache, Tensor paged_kv_indptr, - Tensor paged_kv_indices, Tensor paged_kv_last_page_len, - Tensor o, Optional maybe_lse, int64_t kv_layout_code, - int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS); +void BatchDecodeWithPagedKVCacheRun(TensorView float_workspace_buffer, + TensorView int_workspace_buffer, 
Array plan_info_vec, + TensorView q, TensorView paged_k_cache, + TensorView paged_v_cache, TensorView paged_kv_indptr, + TensorView paged_kv_indices, TensorView paged_kv_last_page_len, + TensorView o, Optional maybe_lse, + int64_t kv_layout_code, int64_t window_left, + bool enable_pdl ADDITIONAL_FUNC_PARAMS); // Batched decode with paged KV-Cache plan TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, BatchDecodeWithPagedKVCachePlan); diff --git a/csrc/batch_decode_mla_binding.cu b/csrc/batch_decode_mla_binding.cu index 8cb9a08db2..3bb416d971 100644 --- a/csrc/batch_decode_mla_binding.cu +++ b/csrc/batch_decode_mla_binding.cu @@ -5,21 +5,20 @@ using tvm::ffi::Array; using tvm::ffi::Optional; -Array BatchDecodeWithPagedKVCachePlanMLA(Tensor float_workspace_buffer, - Tensor int_workspace_buffer, - Tensor page_locked_int_workspace_buffer, - Tensor indptr, int64_t batch_size, +Array BatchDecodeWithPagedKVCachePlanMLA(TensorView float_workspace_buffer, + TensorView int_workspace_buffer, + TensorView page_locked_int_workspace_buffer, + TensorView indptr, int64_t batch_size, int64_t num_qo_heads, int64_t page_size, bool enable_cuda_graph); -void BatchDecodeWithPagedKVCacheRunMLA(Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Array plan_info_vec, Tensor q_nope, Tensor q_pe, - Tensor paged_ckv_cache, Tensor paged_kpe_cache, - Tensor paged_kv_indptr, Tensor paged_kv_indices, - Tensor paged_kv_last_page_len, Tensor o, double sm_scale, - int64_t window_left, double logits_soft_cap, - double rope_scale, double rope_theta, - Optional maybe_lse, bool enable_pdl); +void BatchDecodeWithPagedKVCacheRunMLA( + TensorView float_workspace_buffer, TensorView int_workspace_buffer, + Array plan_info_vec, TensorView q_nope, TensorView q_pe, TensorView paged_ckv_cache, + TensorView paged_kpe_cache, TensorView paged_kv_indptr, TensorView paged_kv_indices, + TensorView paged_kv_last_page_len, TensorView o, double sm_scale, int64_t window_left, + double logits_soft_cap, double rope_scale, double rope_theta, Optional maybe_lse, + bool enable_pdl); TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, BatchDecodeWithPagedKVCachePlanMLA); TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, BatchDecodeWithPagedKVCacheRunMLA); diff --git a/csrc/batch_decode_mla_cute_sm80.cu b/csrc/batch_decode_mla_cute_sm80.cu index 3ed6f0a41e..d96190e539 100644 --- a/csrc/batch_decode_mla_cute_sm80.cu +++ b/csrc/batch_decode_mla_cute_sm80.cu @@ -11,10 +11,10 @@ using namespace flashinfer; using tvm::ffi::Array; using tvm::ffi::Optional; -Array BatchDecodeWithPagedKVCachePlanMLA(ffi::Tensor float_workspace_buffer, - ffi::Tensor int_workspace_buffer, - ffi::Tensor page_locked_int_workspace_buffer, - ffi::Tensor indptr, int64_t batch_size, +Array BatchDecodeWithPagedKVCachePlanMLA(ffi::TensorView float_workspace_buffer, + ffi::TensorView int_workspace_buffer, + ffi::TensorView page_locked_int_workspace_buffer, + ffi::TensorView indptr, int64_t batch_size, int64_t num_qo_heads, int64_t page_size, bool enable_cuda_graph) { size_t float_workspace_size_in_bytes = @@ -43,11 +43,13 @@ Array BatchDecodeWithPagedKVCachePlanMLA(ffi::Tensor float_workspace_bu } void BatchDecodeWithPagedKVCacheRunMLA( - ffi::Tensor float_workspace_buffer, ffi::Tensor int_workspace_buffer, - Array plan_info_vec, ffi::Tensor q_nope, ffi::Tensor q_pe, ffi::Tensor paged_ckv_cache, - ffi::Tensor paged_kpe_cache, ffi::Tensor paged_kv_indptr, ffi::Tensor paged_kv_indices, - ffi::Tensor paged_kv_last_page_len, ffi::Tensor o, double sm_scale, int64_t window_left, - double logits_soft_cap, double rope_scale, 
double rope_theta, Optional maybe_lse, + ffi::TensorView float_workspace_buffer, ffi::TensorView int_workspace_buffer, + Array plan_info_vec, ffi::TensorView q_nope, ffi::TensorView q_pe, + ffi::TensorView paged_ckv_cache, ffi::TensorView paged_kpe_cache, + ffi::TensorView paged_kv_indptr, ffi::TensorView paged_kv_indices, + ffi::TensorView paged_kv_last_page_len, ffi::TensorView o, double sm_scale, int64_t window_left, + double logits_soft_cap, double rope_scale, double rope_theta, + Optional maybe_lse, bool enable_pdl // fake placeholder, sm80 does not support pdl ) { DecodePlanInfo plan_info; diff --git a/csrc/batch_decode_mla_plan.cu b/csrc/batch_decode_mla_plan.cu index f131589bb3..d7c41c90ca 100644 --- a/csrc/batch_decode_mla_plan.cu +++ b/csrc/batch_decode_mla_plan.cu @@ -9,10 +9,10 @@ using namespace flashinfer; using tvm::ffi::Array; -Array BatchDecodeWithPagedKVCachePlanMLA(Tensor float_workspace_buffer, - Tensor int_workspace_buffer, - Tensor page_locked_int_workspace_buffer, - Tensor indptr, int64_t batch_size, +Array BatchDecodeWithPagedKVCachePlanMLA(TensorView float_workspace_buffer, + TensorView int_workspace_buffer, + TensorView page_locked_int_workspace_buffer, + TensorView indptr, int64_t batch_size, int64_t num_qo_heads, int64_t page_size, bool enable_cuda_graph) { cudaSetDevice(float_workspace_buffer->device.device_id); diff --git a/csrc/batch_decode_mla_run.cu b/csrc/batch_decode_mla_run.cu index 907d28b927..f4cef6dc4a 100644 --- a/csrc/batch_decode_mla_run.cu +++ b/csrc/batch_decode_mla_run.cu @@ -10,14 +10,13 @@ using namespace flashinfer; using tvm::ffi::Array; using tvm::ffi::Optional; -void BatchDecodeWithPagedKVCacheRunMLA(Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Array plan_info_vec, Tensor q_nope, Tensor q_pe, - Tensor paged_ckv_cache, Tensor paged_kpe_cache, - Tensor paged_kv_indptr, Tensor paged_kv_indices, - Tensor paged_kv_last_page_len, Tensor o, double sm_scale, - int64_t window_left, double logits_soft_cap, - double rope_scale, double rope_theta, - Optional maybe_lse, bool enable_pdl) { +void BatchDecodeWithPagedKVCacheRunMLA( + TensorView float_workspace_buffer, TensorView int_workspace_buffer, + Array plan_info_vec, TensorView q_nope, TensorView q_pe, TensorView paged_ckv_cache, + TensorView paged_kpe_cache, TensorView paged_kv_indptr, TensorView paged_kv_indices, + TensorView paged_kv_last_page_len, TensorView o, double sm_scale, int64_t window_left, + double logits_soft_cap, double rope_scale, double rope_theta, Optional maybe_lse, + bool enable_pdl) { DecodePlanInfo plan_info; plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); diff --git a/csrc/batch_mla_binding.cu b/csrc/batch_mla_binding.cu index f5aba576a6..6822e28b93 100644 --- a/csrc/batch_mla_binding.cu +++ b/csrc/batch_mla_binding.cu @@ -20,16 +20,17 @@ using tvm::ffi::Array; using tvm::ffi::Optional; -Array BatchMLAPagedAttentionPlan(Tensor float_workspace_buffer, - Tensor int_workspace_buffer, - Tensor page_locked_int_workspace_buffer, Tensor qo_indptr, - Tensor kv_indptr, Tensor kv_len, int64_t num_heads, - int64_t head_dim_o, bool causal); +Array BatchMLAPagedAttentionPlan(TensorView float_workspace_buffer, + TensorView int_workspace_buffer, + TensorView page_locked_int_workspace_buffer, + TensorView qo_indptr, TensorView kv_indptr, + TensorView kv_len, int64_t num_heads, int64_t head_dim_o, + bool causal); -void BatchMLAPagedAttentionRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Array plan_info_vec, Tensor q_nope, 
Tensor q_pe, - Tensor ckv_cache, Tensor kpe_cache, Tensor kv_indices, Tensor o, - Optional maybe_lse, int64_t mask_mode_code, +void BatchMLAPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_workspace_buffer, + Array plan_info_vec, TensorView q_nope, TensorView q_pe, + TensorView ckv_cache, TensorView kpe_cache, TensorView kv_indices, + TensorView o, Optional maybe_lse, int64_t mask_mode_code, int64_t num_heads, int64_t page_size, double sm_scale); TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, BatchMLAPagedAttentionPlan); diff --git a/csrc/batch_mla_plan.cu b/csrc/batch_mla_plan.cu index 81855dc75c..6715b5db2f 100644 --- a/csrc/batch_mla_plan.cu +++ b/csrc/batch_mla_plan.cu @@ -23,11 +23,12 @@ using namespace flashinfer; using tvm::ffi::Array; -Array BatchMLAPagedAttentionPlan(Tensor float_workspace_buffer, - Tensor int_workspace_buffer, - Tensor page_locked_int_workspace_buffer, Tensor qo_indptr, - Tensor kv_indptr, Tensor kv_len, int64_t num_heads, - int64_t head_dim_o, bool causal) { +Array BatchMLAPagedAttentionPlan(TensorView float_workspace_buffer, + TensorView int_workspace_buffer, + TensorView page_locked_int_workspace_buffer, + TensorView qo_indptr, TensorView kv_indptr, + TensorView kv_len, int64_t num_heads, int64_t head_dim_o, + bool causal) { size_t float_workspace_size_in_bytes = float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = diff --git a/csrc/batch_mla_run.cu b/csrc/batch_mla_run.cu index f80ef47bbf..de7acedb04 100644 --- a/csrc/batch_mla_run.cu +++ b/csrc/batch_mla_run.cu @@ -27,10 +27,10 @@ using namespace flashinfer; using tvm::ffi::Array; using tvm::ffi::Optional; -void BatchMLAPagedAttentionRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Array plan_info_vec, Tensor q_nope, Tensor q_pe, - Tensor ckv_cache, Tensor kpe_cache, Tensor kv_indices, Tensor o, - Optional maybe_lse, int64_t mask_mode_code, +void BatchMLAPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_workspace_buffer, + Array plan_info_vec, TensorView q_nope, TensorView q_pe, + TensorView ckv_cache, TensorView kpe_cache, TensorView kv_indices, + TensorView o, Optional maybe_lse, int64_t mask_mode_code, int64_t num_heads, int64_t page_size, double sm_scale) { // q_nope: [n, num_heads, head_dim_ckv] // q_pe: [n, num_heads, head_dim_kpe] diff --git a/csrc/batch_mla_sm90_binding.cu b/csrc/batch_mla_sm90_binding.cu index 18bdd41549..2e6cd1aa7d 100644 --- a/csrc/batch_mla_sm90_binding.cu +++ b/csrc/batch_mla_sm90_binding.cu @@ -20,16 +20,18 @@ using tvm::ffi::Array; using tvm::ffi::Optional; -Array BatchMLAPagedAttentionSM90Plan(Tensor float_workspace_buffer, - Tensor int_workspace_buffer, - Tensor page_locked_int_workspace_buffer, - Tensor qo_indptr, Tensor kv_indptr, Tensor kv_len, - int64_t num_heads, int64_t head_dim_o, bool causal); +Array BatchMLAPagedAttentionSM90Plan(TensorView float_workspace_buffer, + TensorView int_workspace_buffer, + TensorView page_locked_int_workspace_buffer, + TensorView qo_indptr, TensorView kv_indptr, + TensorView kv_len, int64_t num_heads, + int64_t head_dim_o, bool causal); -void BatchMLAPagedAttentionSM90Run(Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Array plan_info_vec, Tensor q_nope, Tensor q_pe, - Tensor ckv_cache, Tensor kpe_cache, Tensor kv_indices, Tensor o, - Optional maybe_lse, int64_t mask_mode_code, +void BatchMLAPagedAttentionSM90Run(TensorView float_workspace_buffer, + TensorView int_workspace_buffer, Array plan_info_vec, + TensorView q_nope, 
TensorView q_pe, TensorView ckv_cache, + TensorView kpe_cache, TensorView kv_indices, TensorView o, + Optional maybe_lse, int64_t mask_mode_code, int64_t num_heads, int64_t page_size, double sm_scale ADDITIONAL_FUNC_PARAMS); diff --git a/csrc/batch_mla_sm90_plan.cu b/csrc/batch_mla_sm90_plan.cu index 76750598b0..78427ffe57 100644 --- a/csrc/batch_mla_sm90_plan.cu +++ b/csrc/batch_mla_sm90_plan.cu @@ -23,11 +23,12 @@ using namespace flashinfer; using tvm::ffi::Array; -Array BatchMLAPagedAttentionSM90Plan(Tensor float_workspace_buffer, - Tensor int_workspace_buffer, - Tensor page_locked_int_workspace_buffer, - Tensor qo_indptr, Tensor kv_indptr, Tensor kv_len, - int64_t num_heads, int64_t head_dim_o, bool causal) { +Array BatchMLAPagedAttentionSM90Plan(TensorView float_workspace_buffer, + TensorView int_workspace_buffer, + TensorView page_locked_int_workspace_buffer, + TensorView qo_indptr, TensorView kv_indptr, + TensorView kv_len, int64_t num_heads, + int64_t head_dim_o, bool causal) { size_t float_workspace_size_in_bytes = float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = diff --git a/csrc/batch_mla_sm90_run.cu b/csrc/batch_mla_sm90_run.cu index 98e76002a9..efb744e9f2 100644 --- a/csrc/batch_mla_sm90_run.cu +++ b/csrc/batch_mla_sm90_run.cu @@ -26,10 +26,11 @@ using namespace flashinfer; using tvm::ffi::Array; using tvm::ffi::Optional; -void BatchMLAPagedAttentionSM90Run(Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Array plan_info_vec, Tensor q_nope, Tensor q_pe, - Tensor ckv_cache, Tensor kpe_cache, Tensor kv_indices, Tensor o, - Optional maybe_lse, int64_t mask_mode_code, +void BatchMLAPagedAttentionSM90Run(TensorView float_workspace_buffer, + TensorView int_workspace_buffer, Array plan_info_vec, + TensorView q_nope, TensorView q_pe, TensorView ckv_cache, + TensorView kpe_cache, TensorView kv_indices, TensorView o, + Optional maybe_lse, int64_t mask_mode_code, int64_t num_heads, int64_t page_size, double sm_scale ADDITIONAL_FUNC_PARAMS) { // q_nope: [n, num_heads, head_dim_ckv] diff --git a/csrc/batch_prefill.cu b/csrc/batch_prefill.cu index a76ad1ae29..796295a2ea 100644 --- a/csrc/batch_prefill.cu +++ b/csrc/batch_prefill.cu @@ -45,11 +45,12 @@ using tvm::ffi::Array; using tvm::ffi::Optional; Array BatchPrefillWithKVCachePlan( - Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Tensor page_locked_int_workspace_buffer, Tensor qo_indptr, Tensor kv_indptr, Tensor kv_len_arr, - int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, - int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, - bool causal, int64_t window_left, int64_t fixed_split_size, bool disable_split_kv) { + TensorView float_workspace_buffer, TensorView int_workspace_buffer, + TensorView page_locked_int_workspace_buffer, TensorView qo_indptr, TensorView kv_indptr, + TensorView kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, + int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, + int64_t head_dim_vo, bool causal, int64_t window_left, int64_t fixed_split_size, + bool disable_split_kv) { size_t float_workspace_size_in_bytes = float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = @@ -73,10 +74,11 @@ Array BatchPrefillWithKVCachePlan( return Array(plan_info.ToVector()); } -void BatchPrefillWithRaggedKVCacheRun(Tensor float_workspace_buffer, Tensor 
int_workspace_buffer, - Array plan_info_vec, Tensor q, Tensor k, Tensor v, - Tensor qo_indptr, Tensor kv_indptr, Tensor o, - Optional maybe_lse, int64_t mask_mode_code, +void BatchPrefillWithRaggedKVCacheRun(TensorView float_workspace_buffer, + TensorView int_workspace_buffer, Array plan_info_vec, + TensorView q, TensorView k, TensorView v, + TensorView qo_indptr, TensorView kv_indptr, TensorView o, + Optional maybe_lse, int64_t mask_mode_code, int64_t layout, int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS) { PrefillPlanInfo plan_info; @@ -196,11 +198,13 @@ void BatchPrefillWithRaggedKVCacheRun(Tensor float_workspace_buffer, Tensor int_ }); } -void BatchPrefillWithPagedKVCacheRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Array plan_info_vec, Tensor q, Tensor paged_k_cache, - Tensor paged_v_cache, Tensor qo_indptr, Tensor paged_kv_indptr, - Tensor paged_kv_indices, Tensor paged_kv_last_page_len, - Tensor o, Optional maybe_lse, int64_t mask_mode_code, +void BatchPrefillWithPagedKVCacheRun(TensorView float_workspace_buffer, + TensorView int_workspace_buffer, Array plan_info_vec, + TensorView q, TensorView paged_k_cache, + TensorView paged_v_cache, TensorView qo_indptr, + TensorView paged_kv_indptr, TensorView paged_kv_indices, + TensorView paged_kv_last_page_len, TensorView o, + Optional maybe_lse, int64_t mask_mode_code, int64_t layout, int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS) { PrefillPlanInfo plan_info; diff --git a/csrc/batch_prefill_fp8_sm90.cu b/csrc/batch_prefill_fp8_sm90.cu index 82e979065a..5e221fa2ff 100644 --- a/csrc/batch_prefill_fp8_sm90.cu +++ b/csrc/batch_prefill_fp8_sm90.cu @@ -37,11 +37,12 @@ using tvm::ffi::Array; using tvm::ffi::Optional; Array BatchPrefillWithKVCacheSM90Plan( - ffi::Tensor float_workspace_buffer, ffi::Tensor int_workspace_buffer, - ffi::Tensor page_locked_int_workspace_buffer, ffi::Tensor qo_indptr, ffi::Tensor kv_indptr, - ffi::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, - int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, - int64_t head_dim_vo, bool causal, int64_t window_left) { + ffi::TensorView float_workspace_buffer, ffi::TensorView int_workspace_buffer, + ffi::TensorView page_locked_int_workspace_buffer, ffi::TensorView qo_indptr, + ffi::TensorView kv_indptr, ffi::TensorView kv_len_arr, int64_t total_num_rows, + int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, + bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal, + int64_t window_left) { size_t float_workspace_size_in_bytes = float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = @@ -66,25 +67,26 @@ Array BatchPrefillWithKVCacheSM90Plan( return Array(plan_info.ToVector()); } -void BatchPrefillWithRaggedKVCacheSM90Run(ffi::Tensor float_workspace_buffer, - ffi::Tensor int_workspace_buffer, - Array plan_info_vec, ffi::Tensor q, - ffi::Tensor k, ffi::Tensor v, ffi::Tensor qo_indptr, - ffi::Tensor kv_indptr, ffi::Tensor o, - Optional maybe_lse, int64_t mask_mode_code, - int64_t layout, int64_t window_left, +void BatchPrefillWithRaggedKVCacheSM90Run(ffi::TensorView float_workspace_buffer, + ffi::TensorView int_workspace_buffer, + Array plan_info_vec, ffi::TensorView q, + ffi::TensorView k, ffi::TensorView v, + ffi::TensorView qo_indptr, ffi::TensorView kv_indptr, + ffi::TensorView o, Optional maybe_lse, + int64_t mask_mode_code, int64_t layout, + int64_t 
window_left, bool enable_pdl // placeholder ADDITIONAL_FUNC_PARAMS) { return; // TODO: Implement this function } void BatchPrefillWithPagedKVCacheSM90Run( - ffi::Tensor float_workspace_buffer, ffi::Tensor int_workspace_buffer, - Array plan_info_vec, ffi::Tensor q, ffi::Tensor paged_k_cache, - ffi::Tensor paged_v_cache, ffi::Tensor qo_indptr, ffi::Tensor paged_kv_indptr, - ffi::Tensor paged_kv_indices, ffi::Tensor paged_kv_last_page_len, ffi::Tensor o, - Optional maybe_lse, int64_t mask_mode_code, int64_t layout, int64_t window_left, - bool enable_pdl ADDITIONAL_FUNC_PARAMS) { + ffi::TensorView float_workspace_buffer, ffi::TensorView int_workspace_buffer, + Array plan_info_vec, ffi::TensorView q, ffi::TensorView paged_k_cache, + ffi::TensorView paged_v_cache, ffi::TensorView qo_indptr, ffi::TensorView paged_kv_indptr, + ffi::TensorView paged_kv_indices, ffi::TensorView paged_kv_last_page_len, ffi::TensorView o, + Optional maybe_lse, int64_t mask_mode_code, int64_t layout, + int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS) { PrefillPlanSM90Info plan_info; plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); diff --git a/csrc/batch_prefill_jit_binding.cu b/csrc/batch_prefill_jit_binding.cu index 58a9a553e5..da1e1981dc 100644 --- a/csrc/batch_prefill_jit_binding.cu +++ b/csrc/batch_prefill_jit_binding.cu @@ -20,24 +20,28 @@ using tvm::ffi::Array; using tvm::ffi::Optional; Array BatchPrefillWithKVCachePlan( - Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Tensor page_locked_int_workspace_buffer, Tensor qo_indptr, Tensor kv_indptr, Tensor kv_len_arr, - int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, - int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, - bool causal, int64_t window_left, int64_t fixed_split_size, bool disable_split_kv); + TensorView float_workspace_buffer, TensorView int_workspace_buffer, + TensorView page_locked_int_workspace_buffer, TensorView qo_indptr, TensorView kv_indptr, + TensorView kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, + int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, + int64_t head_dim_vo, bool causal, int64_t window_left, int64_t fixed_split_size, + bool disable_split_kv); -void BatchPrefillWithRaggedKVCacheRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Array plan_info_vec, Tensor q, Tensor k, Tensor v, - Tensor qo_indptr, Tensor kv_indptr, Tensor o, - Optional maybe_lse, int64_t mask_mode_code, +void BatchPrefillWithRaggedKVCacheRun(TensorView float_workspace_buffer, + TensorView int_workspace_buffer, Array plan_info_vec, + TensorView q, TensorView k, TensorView v, + TensorView qo_indptr, TensorView kv_indptr, TensorView o, + Optional maybe_lse, int64_t mask_mode_code, int64_t layout, int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS); -void BatchPrefillWithPagedKVCacheRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Array plan_info_vec, Tensor q, Tensor paged_k_cache, - Tensor paged_v_cache, Tensor qo_indptr, Tensor paged_kv_indptr, - Tensor paged_kv_indices, Tensor paged_kv_last_page_len, - Tensor o, Optional maybe_lse, int64_t mask_mode_code, +void BatchPrefillWithPagedKVCacheRun(TensorView float_workspace_buffer, + TensorView int_workspace_buffer, Array plan_info_vec, + TensorView q, TensorView paged_k_cache, + TensorView paged_v_cache, TensorView qo_indptr, + TensorView paged_kv_indptr, TensorView paged_kv_indices, + 
TensorView paged_kv_last_page_len, TensorView o, + Optional maybe_lse, int64_t mask_mode_code, int64_t layout, int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS); diff --git a/csrc/batch_prefill_sm90.cu b/csrc/batch_prefill_sm90.cu index 92817ea5c7..a43060eba3 100644 --- a/csrc/batch_prefill_sm90.cu +++ b/csrc/batch_prefill_sm90.cu @@ -43,11 +43,12 @@ using tvm::ffi::Array; using tvm::ffi::Optional; Array BatchPrefillWithKVCacheSM90Plan( - ffi::Tensor float_workspace_buffer, ffi::Tensor int_workspace_buffer, - ffi::Tensor page_locked_int_workspace_buffer, ffi::Tensor qo_indptr, ffi::Tensor kv_indptr, - ffi::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, - int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, - int64_t head_dim_vo, bool causal, int64_t window_left) { + ffi::TensorView float_workspace_buffer, ffi::TensorView int_workspace_buffer, + ffi::TensorView page_locked_int_workspace_buffer, ffi::TensorView qo_indptr, + ffi::TensorView kv_indptr, ffi::TensorView kv_len_arr, int64_t total_num_rows, + int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, + bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal, + int64_t window_left) { size_t float_workspace_size_in_bytes = float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = @@ -72,14 +73,12 @@ Array BatchPrefillWithKVCacheSM90Plan( return Array(plan_info.ToVector()); } -void BatchPrefillWithRaggedKVCacheSM90Run(ffi::Tensor float_workspace_buffer, - ffi::Tensor int_workspace_buffer, - Array plan_info_vec, ffi::Tensor q, - ffi::Tensor k, ffi::Tensor v, ffi::Tensor qo_indptr, - ffi::Tensor kv_indptr, ffi::Tensor o, - Optional maybe_lse, int64_t mask_mode_code, - int64_t layout, int64_t window_left, - bool enable_pdl ADDITIONAL_FUNC_PARAMS) { +void BatchPrefillWithRaggedKVCacheSM90Run( + ffi::TensorView float_workspace_buffer, ffi::TensorView int_workspace_buffer, + Array plan_info_vec, ffi::TensorView q, ffi::TensorView k, ffi::TensorView v, + ffi::TensorView qo_indptr, ffi::TensorView kv_indptr, ffi::TensorView o, + Optional maybe_lse, int64_t mask_mode_code, int64_t layout, + int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS) { PrefillPlanSM90Info plan_info; plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); @@ -164,12 +163,12 @@ void BatchPrefillWithRaggedKVCacheSM90Run(ffi::Tensor float_workspace_buffer, } void BatchPrefillWithPagedKVCacheSM90Run( - ffi::Tensor float_workspace_buffer, ffi::Tensor int_workspace_buffer, - Array plan_info_vec, ffi::Tensor q, ffi::Tensor paged_k_cache, - ffi::Tensor paged_v_cache, ffi::Tensor qo_indptr, ffi::Tensor paged_kv_indptr, - ffi::Tensor paged_kv_indices, ffi::Tensor paged_kv_last_page_len, ffi::Tensor o, - Optional maybe_lse, int64_t mask_mode_code, int64_t layout, int64_t window_left, - bool enable_pdl ADDITIONAL_FUNC_PARAMS) { + ffi::TensorView float_workspace_buffer, ffi::TensorView int_workspace_buffer, + Array plan_info_vec, ffi::TensorView q, ffi::TensorView paged_k_cache, + ffi::TensorView paged_v_cache, ffi::TensorView qo_indptr, ffi::TensorView paged_kv_indptr, + ffi::TensorView paged_kv_indices, ffi::TensorView paged_kv_last_page_len, ffi::TensorView o, + Optional maybe_lse, int64_t mask_mode_code, int64_t layout, + int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS) { PrefillPlanSM90Info plan_info; 
plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); diff --git a/csrc/batch_prefill_sm90_jit_binding.cu b/csrc/batch_prefill_sm90_jit_binding.cu index 5146d324c8..259a1e22d9 100644 --- a/csrc/batch_prefill_sm90_jit_binding.cu +++ b/csrc/batch_prefill_sm90_jit_binding.cu @@ -22,28 +22,27 @@ using tvm::ffi::Array; using tvm::ffi::Optional; Array BatchPrefillWithKVCacheSM90Plan( - ffi::Tensor float_workspace_buffer, ffi::Tensor int_workspace_buffer, - ffi::Tensor page_locked_int_workspace_buffer, ffi::Tensor qo_indptr, ffi::Tensor kv_indptr, - ffi::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, - int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, - int64_t head_dim_vo, bool causal, int64_t window_left); + ffi::TensorView float_workspace_buffer, ffi::TensorView int_workspace_buffer, + ffi::TensorView page_locked_int_workspace_buffer, ffi::TensorView qo_indptr, + ffi::TensorView kv_indptr, ffi::TensorView kv_len_arr, int64_t total_num_rows, + int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, + bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal, + int64_t window_left); -void BatchPrefillWithRaggedKVCacheSM90Run(ffi::Tensor float_workspace_buffer, - ffi::Tensor int_workspace_buffer, - Array plan_info_vec, ffi::Tensor q, - ffi::Tensor k, ffi::Tensor v, ffi::Tensor qo_indptr, - ffi::Tensor kv_indptr, ffi::Tensor o, - Optional maybe_lse, int64_t mask_mode_code, - int64_t layout, int64_t window_left, - bool enable_pdl ADDITIONAL_FUNC_PARAMS); +void BatchPrefillWithRaggedKVCacheSM90Run( + ffi::TensorView float_workspace_buffer, ffi::TensorView int_workspace_buffer, + Array plan_info_vec, ffi::TensorView q, ffi::TensorView k, ffi::TensorView v, + ffi::TensorView qo_indptr, ffi::TensorView kv_indptr, ffi::TensorView o, + Optional maybe_lse, int64_t mask_mode_code, int64_t layout, + int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS); void BatchPrefillWithPagedKVCacheSM90Run( - ffi::Tensor float_workspace_buffer, ffi::Tensor int_workspace_buffer, - Array plan_info_vec, ffi::Tensor q, ffi::Tensor paged_k_cache, - ffi::Tensor paged_v_cache, ffi::Tensor qo_indptr, ffi::Tensor paged_kv_indptr, - ffi::Tensor paged_kv_indices, ffi::Tensor paged_kv_last_page_len, ffi::Tensor o, - Optional maybe_lse, int64_t mask_mode_code, int64_t layout, int64_t window_left, - bool enable_pdl ADDITIONAL_FUNC_PARAMS); + ffi::TensorView float_workspace_buffer, ffi::TensorView int_workspace_buffer, + Array plan_info_vec, ffi::TensorView q, ffi::TensorView paged_k_cache, + ffi::TensorView paged_v_cache, ffi::TensorView qo_indptr, ffi::TensorView paged_kv_indptr, + ffi::TensorView paged_kv_indices, ffi::TensorView paged_kv_last_page_len, ffi::TensorView o, + Optional maybe_lse, int64_t mask_mode_code, int64_t layout, + int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS); TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, BatchPrefillWithKVCacheSM90Plan); TVM_FFI_DLL_EXPORT_TYPED_FUNC(ragged_run, BatchPrefillWithRaggedKVCacheSM90Run); diff --git a/csrc/blackwell_fmha_plan.cu b/csrc/blackwell_fmha_plan.cu index 2720fce024..f976ce5b0b 100644 --- a/csrc/blackwell_fmha_plan.cu +++ b/csrc/blackwell_fmha_plan.cu @@ -17,10 +17,10 @@ #include "flashinfer/attention/blackwell/plan.cuh" #include "tvm_ffi_utils.h" -void blackwell_fmha_plan(Tensor qo_segment_offsets, Tensor kv_segment_offsets, Tensor work_indptr, - Tensor qo_tile_indices, Tensor head_indices, Tensor batch_indices, - 
int64_t qo_tile_size, int64_t num_heads, int64_t num_buckets, - bool causal) { +void blackwell_fmha_plan(TensorView qo_segment_offsets, TensorView kv_segment_offsets, + TensorView work_indptr, TensorView qo_tile_indices, + TensorView head_indices, TensorView batch_indices, int64_t qo_tile_size, + int64_t num_heads, int64_t num_buckets, bool causal) { cudaSetDevice(qo_segment_offsets->device.device_id); const cudaStream_t stream = get_stream(qo_tile_indices->device); int batch_size = qo_segment_offsets->shape[0] - 1; diff --git a/csrc/bmm_fp8.cu b/csrc/bmm_fp8.cu index 9115bb6378..2709191316 100644 --- a/csrc/bmm_fp8.cu +++ b/csrc/bmm_fp8.cu @@ -20,8 +20,8 @@ #include "tvm_ffi_utils.h" -void bmm_fp8(Tensor A, Tensor B, Tensor D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, - int64_t cublas_handle) { +void bmm_fp8(TensorView A, TensorView B, TensorView D, TensorView A_scale, TensorView B_scale, + TensorView workspace_buffer, int64_t cublas_handle) { CHECK_CUDA(A); CHECK_CUDA(B); CHECK_CUDA(D); @@ -50,7 +50,7 @@ void bmm_fp8(Tensor A, Tensor B, Tensor D, Tensor A_scale, Tensor B_scale, Tenso auto stream = get_stream(A->device); auto status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt( - workspace_buffer->data, get_numel(workspace_buffer), static_cast(B->data), + workspace_buffer->data, workspace_buffer.numel(), static_cast(B->data), static_cast(A->data), static_cast(D->data), batch_size, n, m, k, static_cast(B_scale->data), static_cast(A_scale->data), lt_handle, stream); diff --git a/csrc/cascade.cu b/csrc/cascade.cu index cb4dfe7ab6..186a5c113b 100644 --- a/csrc/cascade.cu +++ b/csrc/cascade.cu @@ -20,7 +20,8 @@ using namespace flashinfer; using tvm::ffi::Optional; -void merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor v_merged, Tensor s_merged) { +void merge_state(TensorView v_a, TensorView s_a, TensorView v_b, TensorView s_b, + TensorView v_merged, TensorView s_merged) { CHECK_INPUT(v_a); CHECK_INPUT(s_a); CHECK_INPUT(v_b); @@ -56,8 +57,8 @@ void merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor v_merged TVM_FFI_ICHECK(success) << "MergeState kernel launch failed: unsupported data type."; } -void merge_state_in_place(Tensor v, Tensor s, Tensor v_other, Tensor s_other, - Optional mask) { +void merge_state_in_place(TensorView v, TensorView s, TensorView v_other, TensorView s_other, + Optional mask) { CHECK_INPUT(v); CHECK_INPUT(s); CHECK_INPUT(v_other); @@ -99,7 +100,7 @@ void merge_state_in_place(Tensor v, Tensor s, Tensor v_other, Tensor s_other, TVM_FFI_ICHECK(success) << "MergeStateInPlace kernel launch failed: unsupported data type."; } -void merge_states(Tensor v, Tensor s, Tensor v_merged, Tensor s_merged) { +void merge_states(TensorView v, TensorView s, TensorView v_merged, TensorView s_merged) { CHECK_INPUT(v); CHECK_INPUT(s); CHECK_DEVICE(s, v); diff --git a/csrc/cudnn_sdpa_kernel_launcher.cu b/csrc/cudnn_sdpa_kernel_launcher.cu index c5c7d58e7d..a8e2682e40 100644 --- a/csrc/cudnn_sdpa_kernel_launcher.cu +++ b/csrc/cudnn_sdpa_kernel_launcher.cu @@ -322,10 +322,13 @@ __global__ static void __launch_bounds__(128) } } -static void create_packed_tma_desc_kv_prefill( - int b, int32_t* actual_seq_lens_kv_data, int64_t d_qk, int64_t d_vo, int64_t h_kv, - uint32_t* tensor_traversal_stride_qkv, uint32_t* tensor_box_size_kv, - tma::cudaTmaDesc* packed_tma_desc_k, tma::cudaTmaDesc* packed_tma_desc_v, Tensor k, Tensor v) { +static void create_packed_tma_desc_kv_prefill(int b, int32_t* actual_seq_lens_kv_data, int64_t d_qk, + int64_t d_vo, 
int64_t h_kv, + uint32_t* tensor_traversal_stride_qkv, + uint32_t* tensor_box_size_kv, + tma::cudaTmaDesc* packed_tma_desc_k, + tma::cudaTmaDesc* packed_tma_desc_v, TensorView k, + TensorView v) { int64_t batch_offset_k = 0; int64_t batch_offset_v = 0; // tma descriptors for packed q and o @@ -363,8 +366,8 @@ static void create_packed_tma_desc_qo_prefill(int b, int32_t* actual_seq_lens_q_ uint32_t* tensor_traversal_stride_qkv, uint32_t* tensor_box_size_q, tma::cudaTmaDesc* packed_tma_desc_q, - tma::cudaTmaDesc* packed_tma_desc_o, Tensor q, - Tensor out, int64_t* batch_offset_array) { + tma::cudaTmaDesc* packed_tma_desc_o, TensorView q, + TensorView out, int64_t* batch_offset_array) { int64_t batch_offset_q = 0; int64_t batch_offset_o = 0; // tma descriptors for packed q and o @@ -511,12 +514,13 @@ void setup_decode(CUfunction* hfunc_decode, CUfunction* lean_attn_reduction) { } }; -void prefill(int64_t b, int64_t s_qo, int64_t max_s_kv, Tensor q, Tensor k_cache, Tensor v_cache, - double scale, Tensor workspace_buffer, Tensor actual_seq_lens_q, - Tensor actual_seq_lens_kv, Tensor actual_seq_lens_q_gpu, Tensor actual_seq_lens_kv_gpu, - Tensor block_tables, bool causal, bool return_lse, Tensor out, Tensor lse, - Optional batch_offset_q_array, Optional batch_offset_o_array, - Optional batch_offset_k_array, Optional batch_offset_v_array, +void prefill(int64_t b, int64_t s_qo, int64_t max_s_kv, TensorView q, TensorView k_cache, + TensorView v_cache, double scale, TensorView workspace_buffer, + TensorView actual_seq_lens_q, TensorView actual_seq_lens_kv, + TensorView actual_seq_lens_q_gpu, TensorView actual_seq_lens_kv_gpu, + TensorView block_tables, bool causal, bool return_lse, TensorView out, TensorView lse, + Optional batch_offset_q_array, Optional batch_offset_o_array, + Optional batch_offset_k_array, Optional batch_offset_v_array, bool is_cuda_graph_compatible) { constexpr size_t SMEM_SIZE = 227 * 1024; // All smem constexpr int64_t TILE_M_1 = 128; @@ -833,9 +837,9 @@ int32_t get_kernel_id(int32_t q_heads_per_kv) { } void setup_tma_desc_decode(int64_t b, int64_t s_kv, int64_t h_qo, int64_t h_kv, int64_t d, - int64_t total_num_pages, Tensor q, Tensor out, Tensor k_cache, - Tensor v_cache, int32_t split_factor, int64_t page_size, - int8_t* partial_o_dev, tma::cudaTmaDesc* tma_desc_q, + int64_t total_num_pages, TensorView q, TensorView out, + TensorView k_cache, TensorView v_cache, int32_t split_factor, + int64_t page_size, int8_t* partial_o_dev, tma::cudaTmaDesc* tma_desc_q, tma::cudaTmaDesc* tma_desc_o, tma::cudaTmaDesc* tma_desc_partial_o, tma::cudaTmaDesc* tma_desc_k, tma::cudaTmaDesc* tma_desc_v) { auto kid = get_kernel_id(h_qo / h_kv); @@ -914,10 +918,11 @@ void setup_tma_desc_decode(int64_t b, int64_t s_kv, int64_t h_qo, int64_t h_kv, tma::cudaTmaDescSwizzle::SWIZZLE_128B); } -void decode(int64_t max_s_kv, Tensor q, Tensor k_cache, Tensor v_cache, double scale, - Tensor workspace_buffer, Tensor actual_seq_lens_kv, Tensor actual_seq_lens_kv_gpu, - Tensor block_tables, Tensor out, Optional batch_offset_q_array, - Optional batch_offset_o_array, bool is_cuda_graph_compatible) { +void decode(int64_t max_s_kv, TensorView q, TensorView k_cache, TensorView v_cache, double scale, + TensorView workspace_buffer, TensorView actual_seq_lens_kv, + TensorView actual_seq_lens_kv_gpu, TensorView block_tables, TensorView out, + Optional batch_offset_q_array, Optional batch_offset_o_array, + bool is_cuda_graph_compatible) { constexpr size_t SMEM_SIZE = 227 * 1024; // All smem constexpr size_t 
REDUCTION_MEM_SIZE = 128 * 1024; constexpr int64_t TILE_N_1 = 128; diff --git a/csrc/cutlass_mla.cu b/csrc/cutlass_mla.cu index d9a96aa573..ecc528de40 100644 --- a/csrc/cutlass_mla.cu +++ b/csrc/cutlass_mla.cu @@ -20,9 +20,9 @@ using namespace flashinfer; using namespace flashinfer::attention; -void CutlassMLAPagedAttention(ffi::Tensor workspace, ffi::Tensor out, ffi::Tensor lse, - ffi::Tensor q_nope_pe, ffi::Tensor ckv_kpe_cache, ffi::Tensor kv_lens, - ffi::Tensor page_table) { +void CutlassMLAPagedAttention(ffi::TensorView workspace, ffi::TensorView out, ffi::TensorView lse, + ffi::TensorView q_nope_pe, ffi::TensorView ckv_kpe_cache, + ffi::TensorView kv_lens, ffi::TensorView page_table) { cudaSetDevice(q_nope_pe->device.device_id); const cudaStream_t stream = get_stream(q_nope_pe->device); diff --git a/csrc/flashinfer_cascade_binding.cu b/csrc/flashinfer_cascade_binding.cu index e5887c0239..0fe7e7c06c 100644 --- a/csrc/flashinfer_cascade_binding.cu +++ b/csrc/flashinfer_cascade_binding.cu @@ -18,12 +18,13 @@ using tvm::ffi::Optional; -void merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor v_merged, Tensor s_merged); +void merge_state(TensorView v_a, TensorView s_a, TensorView v_b, TensorView s_b, + TensorView v_merged, TensorView s_merged); -void merge_state_in_place(Tensor v, Tensor s, Tensor v_other, Tensor s_other, - Optional mask); +void merge_state_in_place(TensorView v, TensorView s, TensorView v_other, TensorView s_other, + Optional mask); -void merge_states(Tensor v, Tensor s, Tensor v_merged, Tensor s_merged); +void merge_states(TensorView v, TensorView s, TensorView v_merged, TensorView s_merged); // Merge two self-attention states TVM_FFI_DLL_EXPORT_TYPED_FUNC(merge_state, merge_state); diff --git a/csrc/flashinfer_gemm_binding.cu b/csrc/flashinfer_gemm_binding.cu index ed6bc88ae5..52d0551413 100644 --- a/csrc/flashinfer_gemm_binding.cu +++ b/csrc/flashinfer_gemm_binding.cu @@ -16,12 +16,12 @@ #include "tvm_ffi_utils.h" -void bmm_fp8(Tensor A, Tensor B, Tensor D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, - int64_t cublas_handle); +void bmm_fp8(TensorView A, TensorView B, TensorView D, TensorView A_scale, TensorView B_scale, + TensorView workspace_buffer, int64_t cublas_handle); -void CutlassSegmentGEMM(Tensor workspace_buffer, Tensor all_problems, Tensor x_ptr, Tensor w_ptr, - Tensor y_ptr, Tensor x_ld, Tensor w_ld, Tensor y_ld, Tensor empty_x_data, - bool weight_column_major); +void CutlassSegmentGEMM(TensorView workspace_buffer, TensorView all_problems, TensorView x_ptr, + TensorView w_ptr, TensorView y_ptr, TensorView x_ld, TensorView w_ld, + TensorView y_ld, TensorView empty_x_data, bool weight_column_major); TVM_FFI_DLL_EXPORT_TYPED_FUNC(cutlass_segment_gemm, CutlassSegmentGEMM); TVM_FFI_DLL_EXPORT_TYPED_FUNC(bmm_fp8, bmm_fp8); diff --git a/csrc/flashinfer_gemm_sm90_binding.cu b/csrc/flashinfer_gemm_sm90_binding.cu index b994ac3a06..d5d86abd69 100644 --- a/csrc/flashinfer_gemm_sm90_binding.cu +++ b/csrc/flashinfer_gemm_sm90_binding.cu @@ -15,10 +15,11 @@ */ #include "tvm_ffi_utils.h" -void CutlassSegmentGEMMSM90(Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Tensor all_problems, Tensor x_ptr, Tensor w_ptr, Tensor y_ptr, - Tensor x_stride, Tensor weight_stride, Tensor y_stride, - Tensor empty_x_data, Tensor empty_y_data, bool weight_column_major); +void CutlassSegmentGEMMSM90(TensorView float_workspace_buffer, TensorView int_workspace_buffer, + TensorView all_problems, TensorView x_ptr, TensorView w_ptr, + TensorView y_ptr, 
TensorView x_stride, TensorView weight_stride, + TensorView y_stride, TensorView empty_x_data, TensorView empty_y_data, + bool weight_column_major); // "Cutlass Segment GEMM operator for SM90" TVM_FFI_DLL_EXPORT_TYPED_FUNC(cutlass_segment_gemm_sm90, CutlassSegmentGEMMSM90); diff --git a/csrc/flashinfer_mla_binding.cu b/csrc/flashinfer_mla_binding.cu index 2eb06102a2..c8fc1a61e3 100644 --- a/csrc/flashinfer_mla_binding.cu +++ b/csrc/flashinfer_mla_binding.cu @@ -15,7 +15,8 @@ */ #include "tvm_ffi_utils.h" -void CutlassMLAPagedAttention(Tensor workspace, Tensor out, Tensor lse, Tensor q_nope_pe, - Tensor ckv_kpe_cache, Tensor kv_lens, Tensor page_table); +void CutlassMLAPagedAttention(TensorView workspace, TensorView out, TensorView lse, + TensorView q_nope_pe, TensorView ckv_kpe_cache, TensorView kv_lens, + TensorView page_table); TVM_FFI_DLL_EXPORT_TYPED_FUNC(cutlass_mla_paged_attention, CutlassMLAPagedAttention); diff --git a/csrc/flashinfer_norm_binding.cu b/csrc/flashinfer_norm_binding.cu index a14944e266..d647eb7810 100644 --- a/csrc/flashinfer_norm_binding.cu +++ b/csrc/flashinfer_norm_binding.cu @@ -15,15 +15,15 @@ */ #include "tvm_ffi_utils.h" -using tvm::ffi::Tensor; +void rmsnorm(TensorView out, TensorView input, TensorView weight, double eps, bool enable_pdl); -void rmsnorm(Tensor out, Tensor input, Tensor weight, double eps, bool enable_pdl); +void fused_add_rmsnorm(TensorView input, TensorView residual, TensorView weight, double eps, + bool enable_pdl); -void fused_add_rmsnorm(Tensor input, Tensor residual, Tensor weight, double eps, bool enable_pdl); +void gemma_rmsnorm(TensorView out, TensorView input, TensorView weight, double eps, + bool enable_pdl); -void gemma_rmsnorm(Tensor out, Tensor input, Tensor weight, double eps, bool enable_pdl); - -void gemma_fused_add_rmsnorm(Tensor input, Tensor residual, Tensor weight, double eps, +void gemma_fused_add_rmsnorm(TensorView input, TensorView residual, TensorView weight, double eps, bool enable_pdl); TVM_FFI_DLL_EXPORT_TYPED_FUNC(rmsnorm, rmsnorm); diff --git a/csrc/flashinfer_page_binding.cu b/csrc/flashinfer_page_binding.cu index 9026ff442b..dbab4f5cb8 100644 --- a/csrc/flashinfer_page_binding.cu +++ b/csrc/flashinfer_page_binding.cu @@ -17,21 +17,20 @@ using tvm::ffi::Tensor; -void append_paged_kv_cache(Tensor append_key, Tensor append_value, Tensor batch_indices, - Tensor positions, Tensor paged_k_cache, Tensor paged_v_cache, - Tensor kv_indices, Tensor kv_indptr, Tensor kv_last_page_len, +void append_paged_kv_cache(TensorView append_key, TensorView append_value, TensorView batch_indices, + TensorView positions, TensorView paged_k_cache, TensorView paged_v_cache, + TensorView kv_indices, TensorView kv_indptr, TensorView kv_last_page_len, int64_t layout); -void append_paged_mla_kv_cache(Tensor append_ckv, Tensor append_kpe, Tensor batch_indices, - Tensor positions, Tensor ckv_cache, Tensor kpe_cache, - Tensor kv_indices, Tensor kv_indptr, Tensor kv_last_page_len); +void append_paged_mla_kv_cache(TensorView append_ckv, TensorView append_kpe, + TensorView batch_indices, TensorView positions, TensorView ckv_cache, + TensorView kpe_cache, TensorView kv_indices, TensorView kv_indptr, + TensorView kv_last_page_len); -void block_sparse_indices_to_vector_sparse_offsets(Tensor block_sparse_indices, - Tensor block_sparse_indptr, - Tensor vector_sparse_offsets, - Tensor vector_sparse_indptr, Tensor kv_len_arr, - int64_t stride_block, int64_t stride_n, - int64_t batch_size, int64_t block_size); +void 
block_sparse_indices_to_vector_sparse_offsets( + TensorView block_sparse_indices, TensorView block_sparse_indptr, + TensorView vector_sparse_offsets, TensorView vector_sparse_indptr, TensorView kv_len_arr, + int64_t stride_block, int64_t stride_n, int64_t batch_size, int64_t block_size); TVM_FFI_DLL_EXPORT_TYPED_FUNC(append_paged_kv_cache, append_paged_kv_cache); TVM_FFI_DLL_EXPORT_TYPED_FUNC(append_paged_mla_kv_cache, append_paged_mla_kv_cache); diff --git a/csrc/flashinfer_quantization_binding.cu b/csrc/flashinfer_quantization_binding.cu index 09f939980a..9e22fccc23 100644 --- a/csrc/flashinfer_quantization_binding.cu +++ b/csrc/flashinfer_quantization_binding.cu @@ -15,10 +15,10 @@ */ #include "tvm_ffi_utils.h" -void packbits(Tensor x, const std::string& bitorder, Tensor y); +void packbits(TensorView x, const std::string& bitorder, TensorView y); -void segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, - const std::string& bitorder, Tensor y); +void segment_packbits(TensorView x, TensorView input_indptr, TensorView output_indptr, + const std::string& bitorder, TensorView y); TVM_FFI_DLL_EXPORT_TYPED_FUNC(packbits, packbits); TVM_FFI_DLL_EXPORT_TYPED_FUNC(segment_packbits, segment_packbits); diff --git a/csrc/flashinfer_rope_binding.cu b/csrc/flashinfer_rope_binding.cu index bfea53c275..ed954082f1 100644 --- a/csrc/flashinfer_rope_binding.cu +++ b/csrc/flashinfer_rope_binding.cu @@ -17,29 +17,33 @@ using tvm::ffi::Tensor; -void apply_rope(Tensor q, Tensor k, Tensor q_rope, Tensor k_rope, Tensor indptr, Tensor offsets, - int64_t rotary_dim, bool interleave, double rope_scale, double rope_theta); - -void apply_llama31_rope(Tensor q, Tensor k, Tensor q_rope, Tensor k_rope, Tensor indptr, - Tensor offsets, int64_t rotary_dim, bool interleave, double rope_scale, - double rope_theta, double low_freq_factor, double high_freq_factor, - double old_context_length); - -void apply_rope_pos_ids(Tensor q, Tensor k, Tensor q_rope, Tensor k_rope, Tensor pos_ids, - int64_t rotary_dim, bool interleave, double rope_scale, double rope_theta); - -void apply_llama31_rope_pos_ids(Tensor q, Tensor k, Tensor q_rope, Tensor k_rope, Tensor pos_ids, - int64_t rotary_dim, bool interleave, double rope_scale, - double rope_theta, double low_freq_factor, double high_freq_factor, - double old_context_length); - -void apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor q_rope, Tensor k_rope, - Tensor cos_sin_cache, Tensor pos_ids, bool interleave); - -void mla_rope_quantize(Tensor q_rope_in, Tensor k_rope_in, Tensor q_nope_in, Tensor k_nope_in, - Tensor q_rope_out, Tensor k_rope_out, Tensor q_nope_out, Tensor k_nope_out, - Tensor cos_sin_cache, Tensor pos_ids, double quant_scale_q, - double quant_scale_kv, bool interleave); +void apply_rope(TensorView q, TensorView k, TensorView q_rope, TensorView k_rope, TensorView indptr, + TensorView offsets, int64_t rotary_dim, bool interleave, double rope_scale, + double rope_theta); + +void apply_llama31_rope(TensorView q, TensorView k, TensorView q_rope, TensorView k_rope, + TensorView indptr, TensorView offsets, int64_t rotary_dim, bool interleave, + double rope_scale, double rope_theta, double low_freq_factor, + double high_freq_factor, double old_context_length); + +void apply_rope_pos_ids(TensorView q, TensorView k, TensorView q_rope, TensorView k_rope, + TensorView pos_ids, int64_t rotary_dim, bool interleave, double rope_scale, + double rope_theta); + +void apply_llama31_rope_pos_ids(TensorView q, TensorView k, TensorView q_rope, TensorView k_rope, 
+ TensorView pos_ids, int64_t rotary_dim, bool interleave, + double rope_scale, double rope_theta, double low_freq_factor, + double high_freq_factor, double old_context_length); + +void apply_rope_pos_ids_cos_sin_cache(TensorView q, TensorView k, TensorView q_rope, + TensorView k_rope, TensorView cos_sin_cache, + TensorView pos_ids, bool interleave); + +void mla_rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in, + TensorView k_nope_in, TensorView q_rope_out, TensorView k_rope_out, + TensorView q_nope_out, TensorView k_nope_out, TensorView cos_sin_cache, + TensorView pos_ids, double quant_scale_q, double quant_scale_kv, + bool interleave); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope, apply_rope); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_llama31_rope, apply_llama31_rope); diff --git a/csrc/flashinfer_sampling_binding.cu b/csrc/flashinfer_sampling_binding.cu index 88af546e4a..8e4bbb98b8 100644 --- a/csrc/flashinfer_sampling_binding.cu +++ b/csrc/flashinfer_sampling_binding.cu @@ -17,45 +17,50 @@ using tvm::ffi::Optional; -void softmax(Tensor workspace_buffer, Tensor logits, Tensor output, - Optional maybe_temperature_arr, double temperature_val, bool enable_pdl); +void softmax(TensorView workspace_buffer, TensorView logits, TensorView output, + Optional maybe_temperature_arr, double temperature_val, bool enable_pdl); -void sampling_from_probs(Tensor probs, Tensor output, Optional maybe_indices, +void sampling_from_probs(TensorView probs, TensorView output, Optional maybe_indices, bool deterministic, uint64_t philox_seed, uint64_t philox_offset); -void sampling_from_logits(Tensor logits, Tensor output, Optional maybe_indices, +void sampling_from_logits(TensorView logits, TensorView output, Optional maybe_indices, bool deterministic, uint64_t philox_seed, uint64_t philox_offset); -void top_p_sampling_from_probs(Tensor probs, Tensor output, Optional maybe_indices, - Optional maybe_top_p_arr, double top_p_val, +void top_p_sampling_from_probs(TensorView probs, TensorView output, + Optional maybe_indices, + Optional maybe_top_p_arr, double top_p_val, bool deterministic, uint64_t philox_seed, uint64_t philox_offset); -void top_k_sampling_from_probs(Tensor probs, Tensor output, Optional maybe_indices, - Optional maybe_top_k_arr, int64_t top_k_val, +void top_k_sampling_from_probs(TensorView probs, TensorView output, + Optional maybe_indices, + Optional maybe_top_k_arr, int64_t top_k_val, bool deterministic, uint64_t philox_seed, uint64_t philox_offset); -void min_p_sampling_from_probs(Tensor probs, Tensor output, Optional maybe_indices, - Optional maybe_min_p_arr, double min_p_val, +void min_p_sampling_from_probs(TensorView probs, TensorView output, + Optional maybe_indices, + Optional maybe_min_p_arr, double min_p_val, bool deterministic, uint64_t philox_seed, uint64_t philox_offset); -void top_k_top_p_sampling_from_probs(Tensor probs, Tensor output, Optional maybe_indices, - Optional maybe_top_k_arr, double top_k_val, - Optional maybe_top_p_arr, double top_p_val, +void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output, + Optional maybe_indices, + Optional maybe_top_k_arr, double top_k_val, + Optional maybe_top_p_arr, double top_p_val, bool deterministic, uint64_t philox_seed, uint64_t philox_offset); -void top_p_renorm_probs(Tensor probs, Tensor renorm_probs, Optional maybe_top_p_arr, - double top_p_val); +void top_p_renorm_probs(TensorView probs, TensorView renorm_probs, + Optional maybe_top_p_arr, double top_p_val); -void top_k_renorm_probs(Tensor probs, 
Tensor renorm_probs, Optional maybe_top_k_arr, - int64_t top_k_val); +void top_k_renorm_probs(TensorView probs, TensorView renorm_probs, + Optional maybe_top_k_arr, int64_t top_k_val); -void top_k_mask_logits(Tensor logits, Tensor mask_logits, Optional maybe_top_k_arr, - int64_t top_k_val); +void top_k_mask_logits(TensorView logits, TensorView mask_logits, + Optional maybe_top_k_arr, int64_t top_k_val); -void chain_speculative_sampling(Tensor draft_probs, Tensor draft_token_ids, Tensor target_probs, - Tensor output_token_ids, Tensor output_accepted_token_num, - Tensor output_emitted_draft_token_num, bool deterministic, +void chain_speculative_sampling(TensorView draft_probs, TensorView draft_token_ids, + TensorView target_probs, TensorView output_token_ids, + TensorView output_accepted_token_num, + TensorView output_emitted_draft_token_num, bool deterministic, uint64_t philox_seed, uint64_t philox_offset); // Softmax diff --git a/csrc/flashinfer_xqa_binding.cu b/csrc/flashinfer_xqa_binding.cu index 3400362982..c812739c76 100644 --- a/csrc/flashinfer_xqa_binding.cu +++ b/csrc/flashinfer_xqa_binding.cu @@ -17,15 +17,16 @@ #include "tvm_ffi_utils.h" void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingWinSize, - double qScale, Tensor output, + double qScale, TensorView output, #if LOW_PREC_OUTPUT - Tensor rcpOutScale, + TensorView rcpOutScale, #endif - Tensor q, Tensor attentionSinks, Tensor pool, Tensor kvCachePageList, - int64_t maxSeqLen, Tensor seqLen, int64_t batchSize, Tensor kvCacheScale, + TensorView q, TensorView attentionSinks, TensorView pool, + TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, + int64_t batchSize, TensorView kvCacheScale, #if SPEC_DEC - int64_t qSeqLen, Tensor qCuSeqLens, Tensor mask, + int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask, #endif - Tensor semaphores, Tensor scratch); + TensorView semaphores, TensorView scratch); TVM_FFI_DLL_EXPORT_TYPED_FUNC(xqa_wrapper, xqa_wrapper); diff --git a/csrc/fmha_cutlass_sm100.cu b/csrc/fmha_cutlass_sm100.cu index efdf252ce4..82dc5e697c 100644 --- a/csrc/fmha_cutlass_sm100.cu +++ b/csrc/fmha_cutlass_sm100.cu @@ -70,13 +70,14 @@ using tvm::ffi::Optional; using namespace flashinfer; -void FMHACutlassSM100Run(ffi::Tensor workspace_buffer, ffi::Tensor q, ffi::Tensor k, ffi::Tensor v, - ffi::Tensor qo_segment_offsets, ffi::Tensor kv_segment_offsets, - ffi::Tensor work_indptr, ffi::Tensor qo_tile_indices, - ffi::Tensor qo_head_indices, ffi::Tensor batch_indices, ffi::Tensor o, - Optional maybe_lse, int64_t mask_mode_code, double sm_scale, - int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk, - int64_t head_dim_vo, int64_t max_qo_len) { +void FMHACutlassSM100Run(ffi::TensorView workspace_buffer, ffi::TensorView q, ffi::TensorView k, + ffi::TensorView v, ffi::TensorView qo_segment_offsets, + ffi::TensorView kv_segment_offsets, ffi::TensorView work_indptr, + ffi::TensorView qo_tile_indices, ffi::TensorView qo_head_indices, + ffi::TensorView batch_indices, ffi::TensorView o, + Optional maybe_lse, int64_t mask_mode_code, + double sm_scale, int64_t num_qo_heads, int64_t num_kv_heads, + int64_t head_dim_qk, int64_t head_dim_vo, int64_t max_qo_len) { TVM_FFI_ICHECK_EQ(q->dtype, k->dtype); auto scalar_type_in = q->dtype; auto scalar_type_out = o->dtype; diff --git a/csrc/fmha_cutlass_sm100_binding.cu b/csrc/fmha_cutlass_sm100_binding.cu index 6a3370e0cc..ddb3b8d9cd 100644 --- a/csrc/fmha_cutlass_sm100_binding.cu +++ b/csrc/fmha_cutlass_sm100_binding.cu @@ -17,16 +17,18 @@ using 
tvm::ffi::Optional; -void FMHACutlassSM100Run(Tensor workspace_buffer, Tensor q, Tensor k, Tensor v, - Tensor qo_segment_offsets, Tensor kv_segment_offsets, Tensor work_indptr, - Tensor qo_tile_indices, Tensor qo_head_indices, Tensor batch_indices, - Tensor o, Optional maybe_lse, int64_t mask_mode_code, - double sm_scale, int64_t num_qo_heads, int64_t num_kv_heads, - int64_t head_dim_qk, int64_t head_dim_vo, int64_t max_qo_len); +void FMHACutlassSM100Run(TensorView workspace_buffer, TensorView q, TensorView k, TensorView v, + TensorView qo_segment_offsets, TensorView kv_segment_offsets, + TensorView work_indptr, TensorView qo_tile_indices, + TensorView qo_head_indices, TensorView batch_indices, TensorView o, + Optional maybe_lse, int64_t mask_mode_code, double sm_scale, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk, + int64_t head_dim_vo, int64_t max_qo_len); -void blackwell_fmha_plan(Tensor qo_segment_offsets, Tensor kv_segment_offsets, Tensor work_indptr, - Tensor qo_tile_indices, Tensor head_indices, Tensor batch_indices, - int64_t qo_tile_size, int64_t num_heads, int64_t num_buckets, bool causal); +void blackwell_fmha_plan(TensorView qo_segment_offsets, TensorView kv_segment_offsets, + TensorView work_indptr, TensorView qo_tile_indices, + TensorView head_indices, TensorView batch_indices, int64_t qo_tile_size, + int64_t num_heads, int64_t num_buckets, bool causal); TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, FMHACutlassSM100Run); TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, blackwell_fmha_plan); diff --git a/csrc/fp4_gemm_cutlass.cu b/csrc/fp4_gemm_cutlass.cu index f20885b5c2..e55f6e6ed4 100644 --- a/csrc/fp4_gemm_cutlass.cu +++ b/csrc/fp4_gemm_cutlass.cu @@ -58,14 +58,15 @@ CutlassGemmConfig getFp4GemmConfig(int64_t m, int64_t n, int64_t k, int64_t tact } template -void runGemm(Tensor& out, Tensor const& mat1, Tensor const& mat2, Tensor const& mat1Scale, - Tensor const& mat2Scale, Tensor const& globalScale, int64_t m, int64_t n, int64_t k, - int64_t batch_count, CutlassGemmConfig const& gemmConfig, Tensor workspace_buffer) { +void runGemm(TensorView out, TensorView mat1, TensorView mat2, TensorView mat1Scale, + TensorView mat2Scale, TensorView globalScale, int64_t m, int64_t n, int64_t k, + int64_t batch_count, CutlassGemmConfig const& gemmConfig, + TensorView workspace_buffer) { CutlassFp4GemmRunner gemmRunner; int64_t const required_workspace_size = gemmRunner.getWorkspaceSize(m, n, k, batch_count); int64_t const provided_workspace_size = - get_numel(workspace_buffer) * get_element_size(workspace_buffer); + workspace_buffer.numel() * get_element_size(workspace_buffer); auto runKernel = [&](void* workspace) { gemmRunner.gemm(out->data, mat1->data, mat2->data, mat1Scale->data, mat2Scale->data, @@ -93,9 +94,9 @@ constexpr auto SF_DTYPE = dl_uint8; // uint8_t // mat2Scale: ceil(N / 128) * 128 * ceil(K / sfVecSize / 4) * 4, SF_DTYPE (UE4M3 or UE8M0) // globalScale: [1], 1 / (((448 * 6) / mat1.abs().max()) * ((448 * 6) / mat2.abs().max())) // B = 1 for GEMM op as a special case -Tensor fp4_bmm_impl(Tensor const& mat1, Tensor const& mat2, Tensor const& mat1Scale, - Tensor const& mat2Scale, Tensor const& globalScale, Tensor out, - Tensor workspace_buffer, int64_t tactic) { +void fp4_bmm_impl(TensorView mat1, TensorView mat2, TensorView mat1Scale, TensorView mat2Scale, + TensorView globalScale, TensorView out, TensorView workspace_buffer, + int64_t tactic) { CHECK_INPUT_AND_TYPE(mat1, FLOAT4_E2M1X2); CHECK_INPUT_AND_TYPE(mat2, FLOAT4_E2M1X2); @@ -169,15 +170,13 @@ Tensor 
fp4_bmm_impl(Tensor const& mat1, Tensor const& mat2, Tensor const& mat1Sc default: TVM_FFI_ICHECK(false) << "out_dtype must be one of fp16/bf16."; } - return out; } } // namespace -Tensor fp4_gemm(Tensor const& mat1, Tensor const& mat2, Tensor const& mat1Scale, - Tensor const& mat2Scale, Tensor const& globalScale, Tensor out, - Tensor workspace_buffer, int64_t tactic) { - return fp4_bmm_impl(mat1, mat2, mat1Scale, mat2Scale, globalScale, out, workspace_buffer, tactic); +void fp4_gemm(TensorView mat1, TensorView mat2, TensorView mat1Scale, TensorView mat2Scale, + TensorView globalScale, TensorView out, TensorView workspace_buffer, int64_t tactic) { + fp4_bmm_impl(mat1, mat2, mat1Scale, mat2Scale, globalScale, out, workspace_buffer, tactic); } int64_t fp4_gemm_tactic_num() { diff --git a/csrc/fp4_gemm_cutlass_sm120.cu b/csrc/fp4_gemm_cutlass_sm120.cu index c7c3f513e2..3848d55f85 100644 --- a/csrc/fp4_gemm_cutlass_sm120.cu +++ b/csrc/fp4_gemm_cutlass_sm120.cu @@ -50,14 +50,15 @@ CutlassGemmConfig getFp4GemmConfig(int64_t m, int64_t n, int64_t k, int64_t tact } template -void runGemm(Tensor& out, Tensor const& mat1, Tensor const& mat2, Tensor const& mat1Scale, - Tensor const& mat2Scale, Tensor const& globalScale, int64_t m, int64_t n, int64_t k, - int64_t batch_count, CutlassGemmConfig const& gemmConfig, Tensor workspace_buffer) { +void runGemm(TensorView out, TensorView mat1, TensorView mat2, TensorView mat1Scale, + TensorView mat2Scale, TensorView globalScale, int64_t m, int64_t n, int64_t k, + int64_t batch_count, CutlassGemmConfig const& gemmConfig, + TensorView workspace_buffer) { CutlassFp4GemmRunner gemmRunner; int64_t const required_workspace_size = gemmRunner.getWorkspaceSize(m, n, k, batch_count); int64_t const provided_workspace_size = - get_numel(workspace_buffer) * get_element_size(workspace_buffer); + workspace_buffer.numel() * get_element_size(workspace_buffer); auto runKernel = [&](void* workspace) { gemmRunner.gemm(out->data, mat1->data, mat2->data, mat1Scale->data, mat2Scale->data, @@ -78,9 +79,9 @@ void runGemm(Tensor& out, Tensor const& mat1, Tensor const& mat2, Tensor const& constexpr auto FLOAT4_E2M1X2 = dl_uint8; // uint8_t constexpr auto SF_DTYPE = dl_uint8; // uint8_t -Tensor fp4_bmm_impl(Tensor const& mat1, Tensor const& mat2, Tensor const& mat1Scale, - Tensor const& mat2Scale, Tensor const& globalScale, Tensor out, - Tensor workspace_buffer, int64_t tactic) { +void fp4_bmm_impl(TensorView mat1, TensorView mat2, TensorView mat1Scale, TensorView mat2Scale, + TensorView globalScale, TensorView out, TensorView workspace_buffer, + int64_t tactic) { // Validate inputs TVM_FFI_ICHECK_EQ(mat1->dtype, FLOAT4_E2M1X2) << "mat1 must be FLOAT4_E2M1X2 (uint8)"; TVM_FFI_ICHECK_EQ(mat2->dtype, FLOAT4_E2M1X2) << "mat2 must be FLOAT4_E2M1X2 (uint8)"; @@ -134,7 +135,7 @@ Tensor fp4_bmm_impl(Tensor const& mat1, Tensor const& mat2, Tensor const& mat1Sc // k_packed stores 2 FP4 values per byte int64_t k = k_packed * 2; - TVM_FFI_ICHECK_EQ(get_numel(globalScale), 1) << "globalScale must be a scalar tensor"; + TVM_FFI_ICHECK_EQ(globalScale.numel(), 1) << "globalScale must be a scalar tensor"; // Configure the kernel CutlassGemmConfig config = @@ -165,15 +166,13 @@ Tensor fp4_bmm_impl(Tensor const& mat1, Tensor const& mat2, Tensor const& mat1Sc default: TVM_FFI_ICHECK(false) << "out_dtype must be one of fp16/bf16."; } - return out; } } // namespace -Tensor fp4_gemm(Tensor const& mat1, Tensor const& mat2, Tensor const& mat1Scale, - Tensor const& mat2Scale, Tensor const& globalScale, Tensor 
out, - Tensor workspace_buffer, int64_t tactic) { - return fp4_bmm_impl(mat1, mat2, mat1Scale, mat2Scale, globalScale, out, workspace_buffer, tactic); +void fp4_gemm(TensorView mat1, TensorView mat2, TensorView mat1Scale, TensorView mat2Scale, + TensorView globalScale, TensorView out, TensorView workspace_buffer, int64_t tactic) { + fp4_bmm_impl(mat1, mat2, mat1Scale, mat2Scale, globalScale, out, workspace_buffer, tactic); } int64_t fp4_gemm_tactic_num() { diff --git a/csrc/fp8_gemm_cutlass.cu b/csrc/fp8_gemm_cutlass.cu index 25b28bf17e..f8cc28ab38 100644 --- a/csrc/fp8_gemm_cutlass.cu +++ b/csrc/fp8_gemm_cutlass.cu @@ -57,14 +57,14 @@ CutlassGemmConfig getFp8GemmConfig(int64_t m, int64_t n, int64_t k, int64_t tact } template -void runGemm(Tensor out, Tensor mat1, Tensor mat2, Tensor scale_a, Tensor scale_b, int64_t m, - int64_t n, int64_t k, int64_t b, CutlassGemmConfig const& gemmConfig, - Tensor workspace_buffer) { +void runGemm(TensorView out, TensorView mat1, TensorView mat2, TensorView scale_a, + TensorView scale_b, int64_t m, int64_t n, int64_t k, int64_t b, + CutlassGemmConfig const& gemmConfig, TensorView workspace_buffer) { CutlassFp8GemmRunner gemmRunner; int64_t const required_workspace_size = gemmRunner.getWorkspaceSize(m, n, k); int64_t const provided_workspace_size = - get_numel(workspace_buffer) * get_element_size(workspace_buffer); + workspace_buffer.numel() * get_element_size(workspace_buffer); auto runKernel = [&](void* workspace) { gemmRunner.gemm(static_cast<__nv_fp8_e4m3*>(mat1->data), @@ -84,8 +84,8 @@ void runGemm(Tensor out, Tensor mat1, Tensor mat2, Tensor scale_a, Tensor scale_ } } -Tensor fp8_bmm_impl(Tensor mat1, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor out, - Tensor workspace_buffer, int64_t tactic) { +void fp8_bmm_impl(TensorView mat1, TensorView mat2, TensorView scale_a, TensorView scale_b, + TensorView out, TensorView workspace_buffer, int64_t tactic) { CHECK_INPUT(mat1); CHECK_INPUT(mat2); CHECK_INPUT(scale_a); @@ -147,14 +147,13 @@ Tensor fp8_bmm_impl(Tensor mat1, Tensor mat2, Tensor scale_a, Tensor scale_b, Te default: TVM_FFI_LOG_AND_THROW(NotImplementedError) << "out_dtype must be one of fp16/bf16."; } - return out; } } // namespace -Tensor fp8_gemm(Tensor mat1, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor out, - Tensor workspace_buffer, int64_t tactic) { - return fp8_bmm_impl(mat1, mat2, scale_a, scale_b, out, workspace_buffer, tactic); +void fp8_gemm(TensorView mat1, TensorView mat2, TensorView scale_a, TensorView scale_b, + TensorView out, TensorView workspace_buffer, int64_t tactic) { + fp8_bmm_impl(mat1, mat2, scale_a, scale_b, out, workspace_buffer, tactic); } int64_t fp8_gemm_tactic_num() { diff --git a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu index 96b511d54a..23e027717b 100644 --- a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu +++ b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu @@ -1043,14 +1043,12 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { return kernels::QuantParams::GroupWise( group_size, static_cast(fc1_weight_scales->data), static_cast(fc2_weight_scales->data), - static_cast(get_numel(fc1_act_scales) > 0 ? fc1_act_scales->data : nullptr), - static_cast(get_numel(fc2_act_scales) > 0 ? fc2_act_scales->data : nullptr), - static_cast(get_numel(fc1_weight_zeros) > 0 ? 
fc1_weight_zeros->data - : nullptr), - static_cast(get_numel(fc2_weight_zeros) > 0 ? fc2_weight_zeros->data - : nullptr), - static_cast(get_numel(fc1_alpha) > 0 ? fc1_alpha->data : nullptr), - static_cast(get_numel(fc2_alpha) > 0 ? fc2_alpha->data : nullptr)); + static_cast(fc1_act_scales.numel() > 0 ? fc1_act_scales->data : nullptr), + static_cast(fc2_act_scales.numel() > 0 ? fc2_act_scales->data : nullptr), + static_cast(fc1_weight_zeros.numel() > 0 ? fc1_weight_zeros->data : nullptr), + static_cast(fc2_weight_zeros.numel() > 0 ? fc2_weight_zeros->data : nullptr), + static_cast(fc1_alpha.numel() > 0 ? fc1_alpha->data : nullptr), + static_cast(fc2_alpha.numel() > 0 ? fc2_alpha->data : nullptr)); } else { return kernels::QuantParams{}; } diff --git a/csrc/gemm_groupwise_sm100.cu b/csrc/gemm_groupwise_sm100.cu index ca2158fba7..56a89a23e0 100644 --- a/csrc/gemm_groupwise_sm100.cu +++ b/csrc/gemm_groupwise_sm100.cu @@ -86,10 +86,11 @@ cudaError_t CutlassGroupwiseScaledGEMMSM100(void* float_buffer, size_t float_buf } // namespace gemm } // namespace flashinfer -void CutlassGemmGroupwiseScaledSM100(Tensor float_workspace_buffer, Tensor A, Tensor B, Tensor SFA, - Tensor SFB, Tensor C, int64_t scale_granularity_m, - int64_t scale_granularity_n, int64_t scale_granularity_k, - std::string scale_major_mode, int64_t mma_sm) { +void CutlassGemmGroupwiseScaledSM100(TensorView float_workspace_buffer, TensorView A, TensorView B, + TensorView SFA, TensorView SFB, TensorView C, + int64_t scale_granularity_m, int64_t scale_granularity_n, + int64_t scale_granularity_k, std::string scale_major_mode, + int64_t mma_sm) { cudaSetDevice(float_workspace_buffer->device.device_id); const cudaStream_t stream = get_stream(C->device); DISPATCH_SCALE_MAJOR_K(scale_major_mode, SCALE_MAJOR_K, [&] { diff --git a/csrc/gemm_groupwise_sm120.cu b/csrc/gemm_groupwise_sm120.cu index d9033f0e5f..a434325a1a 100644 --- a/csrc/gemm_groupwise_sm120.cu +++ b/csrc/gemm_groupwise_sm120.cu @@ -82,10 +82,10 @@ cudaError_t CutlassGroupwiseScaledGEMMSM120(void* float_buffer, size_t float_buf } // namespace gemm } // namespace flashinfer -void CutlassGemmGroupwiseScaledSM120(Tensor float_workspace_buffer, Tensor A, Tensor B, Tensor SFA, - Tensor SFB, Tensor C, int64_t scale_granularity_m, - int64_t scale_granularity_n, int64_t scale_granularity_k, - std::string scale_major_mode) { +void CutlassGemmGroupwiseScaledSM120(TensorView float_workspace_buffer, TensorView A, TensorView B, + TensorView SFA, TensorView SFB, TensorView C, + int64_t scale_granularity_m, int64_t scale_granularity_n, + int64_t scale_granularity_k, std::string scale_major_mode) { cudaSetDevice(float_workspace_buffer->device.device_id); auto stream = get_stream(C->device); @@ -123,7 +123,7 @@ void CutlassGemmGroupwiseScaledSM120(Tensor float_workspace_buffer, Tensor A, Te auto status = flashinfer::gemm::CutlassGroupwiseScaledGEMMSM120< SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, SCALE_MAJOR_K>( static_cast(float_workspace_buffer->data), - get_element_size(float_workspace_buffer) * get_numel(float_workspace_buffer), + get_element_size(float_workspace_buffer) * float_workspace_buffer.numel(), static_cast(A->data), static_cast(B->data), static_cast(SFA->data), static_cast(SFB->data), static_cast(C->data), m, n, k, l, diff --git a/csrc/gemm_sm100_binding.cu b/csrc/gemm_sm100_binding.cu index f489c761a5..7bbdc770cf 100644 --- a/csrc/gemm_sm100_binding.cu +++ b/csrc/gemm_sm100_binding.cu @@ -15,9 +15,10 @@ */ #include "tvm_ffi_utils.h" -void 
CutlassGemmGroupwiseScaledSM100(Tensor float_workspace_buffer, Tensor A, Tensor B, Tensor SFA, - Tensor SFB, Tensor C, int64_t scale_granularity_m, - int64_t scale_granularity_n, int64_t scale_granularity_k, - std::string scale_major_mode, int64_t mma_sm); +void CutlassGemmGroupwiseScaledSM100(TensorView float_workspace_buffer, TensorView A, TensorView B, + TensorView SFA, TensorView SFB, TensorView C, + int64_t scale_granularity_m, int64_t scale_granularity_n, + int64_t scale_granularity_k, std::string scale_major_mode, + int64_t mma_sm); TVM_FFI_DLL_EXPORT_TYPED_FUNC(gemm_fp8_nt_groupwise, CutlassGemmGroupwiseScaledSM100); diff --git a/csrc/gemm_sm120_binding.cu b/csrc/gemm_sm120_binding.cu index 13a94e7d56..2f25ae720e 100644 --- a/csrc/gemm_sm120_binding.cu +++ b/csrc/gemm_sm120_binding.cu @@ -15,9 +15,9 @@ */ #include "tvm_ffi_utils.h" -void CutlassGemmGroupwiseScaledSM120(Tensor float_workspace_buffer, Tensor A, Tensor B, Tensor SFA, - Tensor SFB, Tensor C, int64_t scale_granularity_m, - int64_t scale_granularity_n, int64_t scale_granularity_k, - std::string scale_major_mode); +void CutlassGemmGroupwiseScaledSM120(TensorView float_workspace_buffer, TensorView A, TensorView B, + TensorView SFA, TensorView SFB, TensorView C, + int64_t scale_granularity_m, int64_t scale_granularity_n, + int64_t scale_granularity_k, std::string scale_major_mode); TVM_FFI_DLL_EXPORT_TYPED_FUNC(gemm_fp8_nt_groupwise, CutlassGemmGroupwiseScaledSM120); diff --git a/csrc/group_gemm.cu b/csrc/group_gemm.cu index a857a791b3..73684a3fcf 100644 --- a/csrc/group_gemm.cu +++ b/csrc/group_gemm.cu @@ -20,9 +20,9 @@ using namespace flashinfer; using namespace flashinfer::group_gemm; -void CutlassSegmentGEMM(Tensor workspace_buffer, Tensor all_problems, Tensor x_ptr, Tensor w_ptr, - Tensor y_ptr, Tensor x_ld, Tensor w_ld, Tensor y_ld, Tensor empty_x_data, - bool weight_column_major) { +void CutlassSegmentGEMM(TensorView workspace_buffer, TensorView all_problems, TensorView x_ptr, + TensorView w_ptr, TensorView y_ptr, TensorView x_ld, TensorView w_ld, + TensorView y_ld, TensorView empty_x_data, bool weight_column_major) { unsigned int batch_size = x_ptr->shape[0]; cudaSetDevice(workspace_buffer->device.device_id); diff --git a/csrc/group_gemm_fp8_groupwise_sm100.cu b/csrc/group_gemm_fp8_groupwise_sm100.cu index 331b43ab16..a3abef08a4 100644 --- a/csrc/group_gemm_fp8_groupwise_sm100.cu +++ b/csrc/group_gemm_fp8_groupwise_sm100.cu @@ -86,13 +86,11 @@ cudaError_t CutlassFP8GroupwiseScaledGroupGEMMSM100( } // namespace group_gemm } // namespace flashinfer -void CutlassGroupGemmFP8GroupwiseScaledSM100(Tensor int_workspace_buffer, - Tensor float_workspace_buffer, Tensor A, Tensor B, - Tensor SFA, Tensor SFB, Tensor D, Tensor m_indptr, - int64_t n, int64_t k, int64_t scale_granularity_m, - int64_t scale_granularity_n, - int64_t scale_granularity_k, - std::string scale_major_mode, int64_t mma_sm) { +void CutlassGroupGemmFP8GroupwiseScaledSM100( + TensorView int_workspace_buffer, TensorView float_workspace_buffer, TensorView A, TensorView B, + TensorView SFA, TensorView SFB, TensorView D, TensorView m_indptr, int64_t n, int64_t k, + int64_t scale_granularity_m, int64_t scale_granularity_n, int64_t scale_granularity_k, + std::string scale_major_mode, int64_t mma_sm) { cudaSetDevice(float_workspace_buffer->device.device_id); auto stream = get_stream(D->device); int num_groups = m_indptr->shape[0] - 1; diff --git a/csrc/group_gemm_fp8_groupwise_sm120.cu b/csrc/group_gemm_fp8_groupwise_sm120.cu index b367a6db5c..f19aecd15e 100644 --- 
a/csrc/group_gemm_fp8_groupwise_sm120.cu +++ b/csrc/group_gemm_fp8_groupwise_sm120.cu @@ -81,9 +81,10 @@ cudaError_t CutlassFP8GroupwiseScaledGroupGEMMSM120( } // namespace flashinfer void CutlassGroupGemmFP8GroupwiseScaledSM120( - Tensor int_workspace_buffer, Tensor float_workspace_buffer, Tensor A, Tensor B, Tensor SFA, - Tensor SFB, Tensor D, Tensor m_indptr, int64_t n, int64_t k, int64_t scale_granularity_m, - int64_t scale_granularity_n, int64_t scale_granularity_k, std::string scale_major_mode) { + TensorView int_workspace_buffer, TensorView float_workspace_buffer, TensorView A, TensorView B, + TensorView SFA, TensorView SFB, TensorView D, TensorView m_indptr, int64_t n, int64_t k, + int64_t scale_granularity_m, int64_t scale_granularity_n, int64_t scale_granularity_k, + std::string scale_major_mode) { cudaSetDevice(float_workspace_buffer->device.device_id); auto stream = get_stream(D->device); int num_groups = m_indptr->shape[0] - 1; diff --git a/csrc/group_gemm_mxfp4_groupwise_sm100.cu b/csrc/group_gemm_mxfp4_groupwise_sm100.cu index 023a55a2ea..1403420602 100644 --- a/csrc/group_gemm_mxfp4_groupwise_sm100.cu +++ b/csrc/group_gemm_mxfp4_groupwise_sm100.cu @@ -127,10 +127,11 @@ cudaError_t CutlassMXFP4GroupwiseScaledGroupGEMMSM100( } // namespace group_gemm } // namespace flashinfer -void CutlassGroupGemmMXFP4GroupwiseScaledSM100(Tensor int_workspace_buffer, - Tensor float_workspace_buffer, Tensor A, Tensor B, - Tensor SFA, Tensor SFB, Tensor D, Tensor m_indptr, - int64_t n, int64_t k, int64_t mma_sm, int64_t tile_m, +void CutlassGroupGemmMXFP4GroupwiseScaledSM100(TensorView int_workspace_buffer, + TensorView float_workspace_buffer, TensorView A, + TensorView B, TensorView SFA, TensorView SFB, + TensorView D, TensorView m_indptr, int64_t n, + int64_t k, int64_t mma_sm, int64_t tile_m, int64_t tile_n, int64_t tile_k, bool swap_ab) { cudaSetDevice(float_workspace_buffer->device.device_id); auto stream = get_stream(A->device); diff --git a/csrc/group_gemm_sm100_binding.cu b/csrc/group_gemm_sm100_binding.cu index 1391eac619..f607e424d3 100644 --- a/csrc/group_gemm_sm100_binding.cu +++ b/csrc/group_gemm_sm100_binding.cu @@ -17,18 +17,17 @@ #include "tvm_ffi_utils.h" -void CutlassGroupGemmFP8GroupwiseScaledSM100(Tensor int_workspace_buffer, - Tensor float_workspace_buffer, Tensor A, Tensor B, - Tensor SFA, Tensor SFB, Tensor D, Tensor m_indptr, - int64_t n, int64_t k, int64_t scale_granularity_m, - int64_t scale_granularity_n, - int64_t scale_granularity_k, - std::string scale_major_mode, int64_t mma_sm); +void CutlassGroupGemmFP8GroupwiseScaledSM100( + TensorView int_workspace_buffer, TensorView float_workspace_buffer, TensorView A, TensorView B, + TensorView SFA, TensorView SFB, TensorView D, TensorView m_indptr, int64_t n, int64_t k, + int64_t scale_granularity_m, int64_t scale_granularity_n, int64_t scale_granularity_k, + std::string scale_major_mode, int64_t mma_sm); -void CutlassGroupGemmMXFP4GroupwiseScaledSM100(Tensor int_workspace_buffer, - Tensor float_workspace_buffer, Tensor A, Tensor B, - Tensor SFA, Tensor SFB, Tensor D, Tensor m_indptr, - int64_t n, int64_t k, int64_t mma_sm, int64_t tile_m, +void CutlassGroupGemmMXFP4GroupwiseScaledSM100(TensorView int_workspace_buffer, + TensorView float_workspace_buffer, TensorView A, + TensorView B, TensorView SFA, TensorView SFB, + TensorView D, TensorView m_indptr, int64_t n, + int64_t k, int64_t mma_sm, int64_t tile_m, int64_t tile_n, int64_t tile_k, bool swap_ab); TVM_FFI_DLL_EXPORT_TYPED_FUNC(group_gemm_fp8_nt_groupwise, 
CutlassGroupGemmFP8GroupwiseScaledSM100); diff --git a/csrc/group_gemm_sm120_binding.cu b/csrc/group_gemm_sm120_binding.cu index d3dadd4d44..7c014fca68 100644 --- a/csrc/group_gemm_sm120_binding.cu +++ b/csrc/group_gemm_sm120_binding.cu @@ -18,8 +18,9 @@ #include "tvm_ffi_utils.h" void CutlassGroupGemmFP8GroupwiseScaledSM120( - Tensor int_workspace_buffer, Tensor float_workspace_buffer, Tensor A, Tensor B, Tensor SFA, - Tensor SFB, Tensor D, Tensor m_indptr, int64_t n, int64_t k, int64_t scale_granularity_m, - int64_t scale_granularity_n, int64_t scale_granularity_k, std::string scale_major_mode); + TensorView int_workspace_buffer, TensorView float_workspace_buffer, TensorView A, TensorView B, + TensorView SFA, TensorView SFB, TensorView D, TensorView m_indptr, int64_t n, int64_t k, + int64_t scale_granularity_m, int64_t scale_granularity_n, int64_t scale_granularity_k, + std::string scale_major_mode); TVM_FFI_DLL_EXPORT_TYPED_FUNC(group_gemm_fp8_nt_groupwise, CutlassGroupGemmFP8GroupwiseScaledSM120); diff --git a/csrc/group_gemm_sm90.cu b/csrc/group_gemm_sm90.cu index 171d6f9902..ab6ddd5e6c 100644 --- a/csrc/group_gemm_sm90.cu +++ b/csrc/group_gemm_sm90.cu @@ -47,10 +47,11 @@ cudaError_t CutlassSegmentGEMMSM90Run(void* float_buffer, size_t float_buffer_si } // namespace group_gemm } // namespace flashinfer -void CutlassSegmentGEMMSM90(Tensor float_workspace_buffer, Tensor int_workspace_buffer, - Tensor all_problems, Tensor x_ptr, Tensor w_ptr, Tensor y_ptr, - Tensor x_stride, Tensor weight_stride, Tensor y_stride, - Tensor empty_x_data, Tensor empty_y_data, bool weight_column_major) { +void CutlassSegmentGEMMSM90(TensorView float_workspace_buffer, TensorView int_workspace_buffer, + TensorView all_problems, TensorView x_ptr, TensorView w_ptr, + TensorView y_ptr, TensorView x_stride, TensorView weight_stride, + TensorView y_stride, TensorView empty_x_data, TensorView empty_y_data, + bool weight_column_major) { unsigned int batch_size = x_ptr->shape[0]; cudaSetDevice(float_workspace_buffer->device.device_id); const cudaStream_t stream = get_stream(float_workspace_buffer->device); diff --git a/csrc/norm.cu b/csrc/norm.cu index 54b2edd5ce..43c807d753 100644 --- a/csrc/norm.cu +++ b/csrc/norm.cu @@ -19,9 +19,7 @@ using namespace flashinfer; -using tvm::ffi::Tensor; - -void rmsnorm(Tensor output, Tensor input, Tensor weight, double eps, bool enable_pdl) { +void rmsnorm(TensorView output, TensorView input, TensorView weight, double eps, bool enable_pdl) { CHECK_LAST_DIM_CONTIGUOUS_INPUT(input); CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight); CHECK_DEVICE(input, weight); @@ -46,7 +44,8 @@ void rmsnorm(Tensor output, Tensor input, Tensor weight, double eps, bool enable }); } -void fused_add_rmsnorm(Tensor input, Tensor residual, Tensor weight, double eps, bool enable_pdl) { +void fused_add_rmsnorm(TensorView input, TensorView residual, TensorView weight, double eps, + bool enable_pdl) { CHECK_LAST_DIM_CONTIGUOUS_INPUT(input); CHECK_LAST_DIM_CONTIGUOUS_INPUT(residual); CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight); @@ -75,7 +74,8 @@ void fused_add_rmsnorm(Tensor input, Tensor residual, Tensor weight, double eps, }); } -void gemma_rmsnorm(Tensor output, Tensor input, Tensor weight, double eps, bool enable_pdl) { +void gemma_rmsnorm(TensorView output, TensorView input, TensorView weight, double eps, + bool enable_pdl) { CHECK_LAST_DIM_CONTIGUOUS_INPUT(input); CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight); CHECK_DEVICE(input, weight); @@ -100,7 +100,7 @@ void gemma_rmsnorm(Tensor output, Tensor input, Tensor weight, 
double eps, bool }); } -void gemma_fused_add_rmsnorm(Tensor input, Tensor residual, Tensor weight, double eps, +void gemma_fused_add_rmsnorm(TensorView input, TensorView residual, TensorView weight, double eps, bool enable_pdl) { CHECK_LAST_DIM_CONTIGUOUS_INPUT(input); CHECK_LAST_DIM_CONTIGUOUS_INPUT(residual); diff --git a/csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp b/csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp index 2f94e56585..c14d77a606 100644 --- a/csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp +++ b/csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp @@ -32,7 +32,6 @@ using tvm::ffi::Array; using tvm::ffi::Map; using tvm::ffi::Optional; -using tvm::ffi::Tensor; static int getExp(float v) { int vIntRepr; @@ -141,7 +140,7 @@ int computeSFIndex(int rowIdx, int colIdx, int totalRow, int totalColumn, // Interleave (and possibly pad) the weights block scaling factor. // blockScale: [num_experts, rows, cols] or [rows, cols] // Return: num_experts * pad_up(rows, 128) * pad_up(cols, 4) -void BlockScaleInterleave(Tensor blockScale, Tensor interleavedBlockScale) { +void BlockScaleInterleave(TensorView blockScale, TensorView interleavedBlockScale) { bool is_cuda = (blockScale->device.device_type == kDLCUDA); if (is_cuda) { CHECK_CUDA(blockScale); @@ -195,7 +194,7 @@ void BlockScaleInterleave(Tensor blockScale, Tensor interleavedBlockScale) { // blockScale: [num_experts, rows, cols] or [rows, cols] // Note: rows and cols are the dimensions of the original unswizzled SFMatrix, so reshape input // before passing into this function! Return: The same shape as blockScale -void BlockScaleInterleaveReverse(Tensor const& blockScale, Tensor reversedBlockScale) { +void BlockScaleInterleaveReverse(TensorView const& blockScale, TensorView reversedBlockScale) { bool is_cuda = (blockScale->device.device_type == kDLCUDA); if (is_cuda) { CHECK_CUDA(blockScale); @@ -245,8 +244,9 @@ void BlockScaleInterleaveReverse(Tensor const& blockScale, Tensor reversedBlockS } // Used by the (fp16 -> int4) quant layer + int4 gemm network. 
-void E2M1AndUFP8SFScaleToFloatV2(Tensor valueE2M1, Tensor scaleFP8SF, Optional globalScale, - Tensor floatTensor, int64_t sfVecSize, int64_t sfType, +void E2M1AndUFP8SFScaleToFloatV2(TensorView valueE2M1, TensorView scaleFP8SF, + Optional globalScale, TensorView floatTensorView, + int64_t sfVecSize, int64_t sfType, bool isSfSwizzledLayout = true) { CHECK_CPU_INPUT(valueE2M1, dl_uint8); CHECK_CPU_INPUT(scaleFP8SF, dl_uint8); @@ -271,7 +271,7 @@ void E2M1AndUFP8SFScaleToFloatV2(Tensor valueE2M1, Tensor scaleFP8SF, Optional(packedShape[0]); ++vIdx) { for (int group = 0; group < groupsPerHiddenDim; ++group) { float* floatPtr = - static_cast(floatTensor->data) + vIdx * hiddenDim + group * sfVecSize; + static_cast(floatTensorView->data) + vIdx * hiddenDim + group * sfVecSize; uint8_t* packedFp4Ptr = static_cast(valueE2M1->data) + vIdx * packedFp4HiddenDim + group * sfVecSize / 2; uint8_t* scaleFP8SFPtr = static_cast(scaleFP8SF->data); @@ -298,7 +298,8 @@ void E2M1AndUFP8SFScaleToFloatV2(Tensor valueE2M1, Tensor scaleFP8SF, Optional const& globalScale, Tensor valueE2M1, - Tensor scaleFP8SF, int64_t sfVecSize, bool sfUseUE8M0, bool isSfSwizzledLayout, - bool isSf8x4Layout, bool enable_pdl) { +void fp4_quantize(TensorView self, Optional const& globalScale, TensorView valueE2M1, + TensorView scaleFP8SF, int64_t sfVecSize, bool sfUseUE8M0, + bool isSfSwizzledLayout, bool isSf8x4Layout, bool enable_pdl) { CHECK_CUDA(self); CHECK_CONTIGUOUS(self); if (sfUseUE8M0) { @@ -131,8 +131,9 @@ void fp4_quantize(Tensor self, Optional const& globalScale, Tensor value // self_fp4: [B, M, K / 2], FLOAT4_E2M1X2 // self_block_scale_factors: // [B, ceil(M / 128) * 128 * ceil(K / sfVecSize / 4) * 4], SF_DTYPE (UE4M3 or UE8M0) -void fp4_batched_quantize(Tensor self, Optional const& mask, Tensor globalScale, - Tensor valueE2M1, Tensor scaleFP8SF, int64_t sfVecSize, bool sfUseUE8M0) { +void fp4_batched_quantize(TensorView self, Optional const& mask, TensorView globalScale, + TensorView valueE2M1, TensorView scaleFP8SF, int64_t sfVecSize, + bool sfUseUE8M0) { CHECK_CUDA(self); CHECK_CONTIGUOUS(self); auto fp32_dtype = DLDataType{kDLFloat, 32, 1}; @@ -194,9 +195,9 @@ void fp4_batched_quantize(Tensor self, Optional const& mask, Tensor glob #undef LAUNCH_FP4_QUANTIZE_KERNEL } -void silu_and_mul_nvfp4_batched_quantize(Tensor const& self, Tensor const& mask, - Tensor const& globalScale, Tensor valueE2M1, - Tensor scaleFP8SF, int64_t sfVecSize) { +void silu_and_mul_nvfp4_batched_quantize(TensorView const& self, TensorView const& mask, + TensorView const& globalScale, TensorView valueE2M1, + TensorView scaleFP8SF, int64_t sfVecSize) { // TODO(shuw): mask can be none CHECK_CUDA(self); CHECK_CONTIGUOUS(self); diff --git a/csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h b/csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h index d603faa388..411f5991b5 100644 --- a/csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h +++ b/csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h @@ -24,16 +24,16 @@ #include "tensorrt_llm/thop/utils.h" using tvm::ffi::Optional; -using tvm::ffi::Tensor; using tvm::ffi::Tuple; -void fp4_quantize(Tensor self, Optional const& globalScale, Tensor valueE2M1, - Tensor scaleFP8SF, int64_t sfVecSize, bool sfUseUE8M0, bool isSfSwizzledLayout, - bool isSf8x4Layout, bool enable_pdl); +void fp4_quantize(TensorView self, Optional const& globalScale, TensorView valueE2M1, + TensorView scaleFP8SF, int64_t sfVecSize, bool sfUseUE8M0, + bool isSfSwizzledLayout, bool isSf8x4Layout, bool enable_pdl); -void fp4_batched_quantize(Tensor 
self, Optional const& mask, Tensor globalScale, - Tensor valueE2M1, Tensor scaleFP8SF, int64_t sfVecSize, bool sfUseUE8M0); +void fp4_batched_quantize(TensorView self, Optional const& mask, TensorView globalScale, + TensorView valueE2M1, TensorView scaleFP8SF, int64_t sfVecSize, + bool sfUseUE8M0); -void silu_and_mul_nvfp4_batched_quantize(Tensor const& self, Tensor const& mask, - Tensor const& globalScale, Tensor valueE2M1, - Tensor scaleFP8SF, int64_t sfVecSize); +void silu_and_mul_nvfp4_batched_quantize(TensorView const& self, TensorView const& mask, + TensorView const& globalScale, TensorView valueE2M1, + TensorView scaleFP8SF, int64_t sfVecSize); diff --git a/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp b/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp index 5191f81229..f943dbcfb4 100644 --- a/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp +++ b/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp @@ -26,8 +26,8 @@ // isSfSwizzledLayout: bool, if true, the scale factors are stored in swizzled layout, otherwise in // linear layout. See QuantizationSFLayout enum for more details about the two layouts. // returns -void mxfp8_quantize(Tensor input, Tensor valMxFP8, Tensor scaleFP8SF, bool isSfSwizzledLayout, - int64_t alignment, bool enable_pdl) { +void mxfp8_quantize(TensorView input, TensorView valMxFP8, TensorView scaleFP8SF, + bool isSfSwizzledLayout, int64_t alignment, bool enable_pdl) { CHECK_CUDA(input); CHECK_CONTIGUOUS(input); @@ -92,7 +92,7 @@ inline uint8_t float_to_ue8m0(float value) { } // Used in tests to quantize mxe4m3 tensors on host. -void mxfp8_quantize_host(Tensor x_fp32, Tensor fp8_tensor, Tensor scale_tensor, +void mxfp8_quantize_host(TensorView x_fp32, TensorView fp8_tensor, TensorView scale_tensor, bool is_sf_swizzled_layout) { int32_t const sf_vec_size = 32; auto fp32_dtype = DLDataType{kDLFloat, 32, 1}; @@ -138,8 +138,8 @@ void mxfp8_quantize_host(Tensor x_fp32, Tensor fp8_tensor, Tensor scale_tensor, } // Used in tests to dequantize mxe4m3 tensors on host. -void mxfp8_dequantize_host(Tensor value_e4m3, Tensor scale_ue8m08sf, Tensor float_tensor, - bool is_sf_swizzled_layout) { +void mxfp8_dequantize_host(TensorView value_e4m3, TensorView scale_ue8m08sf, + TensorView float_tensor, bool is_sf_swizzled_layout) { int32_t const sf_vec_size = 32; CHECK_INPUT_TYPE(value_e4m3, dl_uint8); CHECK_INPUT_TYPE(scale_ue8m08sf, dl_uint8); diff --git a/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.h b/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.h index a8c48e9fb0..15d587e8bf 100644 --- a/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.h +++ b/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.h @@ -66,15 +66,15 @@ inline int computeSFIndex(int rowIdx, int colIdx, int totalRow, int totalColumn, // linear layout. See QuantizationSFLayout enum for more details about the two layouts. // alignment: sfVecSize // returns fp8_quantized and block_scale_factors. -void mxfp8_quantize(Tensor input, Tensor valMxFP8, Tensor scaleFP8SF, bool is_sf_swizzled_layout, - int64_t alignment, bool enable_pdl); +void mxfp8_quantize(TensorView input, TensorView valMxFP8, TensorView scaleFP8SF, + bool is_sf_swizzled_layout, int64_t alignment, bool enable_pdl); // x_fp32: [M, K], fp32_quantized (on the host) // isSfSwizzledLayout: bool, if true, the scale factors are stored in swizzled layout, otherwise in // linear layout. See QuantizationSFLayout enum for more details about the two layouts. // returns fp8_quantized and block_scale_factors (on the host). 
-void mxfp8_quantize_host(Tensor x_fp32, Tensor fp8_tensor, Tensor scale_tensor, +void mxfp8_quantize_host(TensorView x_fp32, TensorView fp8_tensor, TensorView scale_tensor, bool is_sf_swizzled_layout = true); -void mxfp8_dequantize_host(Tensor value_e4m3, Tensor scale_ue8m08sf, Tensor float_tensor, - bool is_sf_swizzled_layout = true); +void mxfp8_dequantize_host(TensorView value_e4m3, TensorView scale_ue8m08sf, + TensorView float_tensor, bool is_sf_swizzled_layout = true); diff --git a/csrc/nvshmem_binding.cu b/csrc/nvshmem_binding.cu index d551c8dc1f..6d7ecacacb 100644 --- a/csrc/nvshmem_binding.cu +++ b/csrc/nvshmem_binding.cu @@ -28,9 +28,9 @@ constexpr int nvshmemx_uniqueid_t_size = sizeof(nvshmemx_uniqueid_t); using tvm::ffi::Array; using tvm::ffi::Shape; -void get_unique_id(Tensor uid) { +void get_unique_id(TensorView uid) { CHECK_CONTIGUOUS(uid); - TVM_FFI_ICHECK_EQ(get_numel(uid) * get_element_size(uid), nvshmemx_uniqueid_t_size); + TVM_FFI_ICHECK_EQ(uid.numel() * get_element_size(uid), nvshmemx_uniqueid_t_size); TVM_FFI_ICHECK_EQ(uid->device.device_type, kDLCPU); nvshmemx_uniqueid_t* uid_ptr = reinterpret_cast(uid->data); *uid_ptr = NVSHMEMX_UNIQUEID_INITIALIZER; @@ -39,9 +39,9 @@ void get_unique_id(Tensor uid) { int64_t unique_id_size() { return nvshmemx_uniqueid_t_size; } -int64_t init(Tensor uid, int64_t rank, int64_t world_size) { +int64_t init(TensorView uid, int64_t rank, int64_t world_size) { CHECK_CONTIGUOUS(uid); - TVM_FFI_ICHECK_EQ(get_numel(uid) * get_element_size(uid), nvshmemx_uniqueid_t_size); + TVM_FFI_ICHECK_EQ(uid.numel() * get_element_size(uid), nvshmemx_uniqueid_t_size); TVM_FFI_ICHECK_EQ(uid->device.device_type, kDLCPU); nvshmemx_uniqueid_t* uid_ptr = reinterpret_cast(uid->data); nvshmemx_init_attr_t attr = NVSHMEMX_INIT_ATTR_INITIALIZER; @@ -76,20 +76,20 @@ void barrier_all_on_current_stream() { nvshmemx_barrier_all_on_stream(stream); } -void alltoall(Tensor dest, Tensor source) { +void alltoall(TensorView dest, TensorView source) { CHECK_CONTIGUOUS(dest); CHECK_CONTIGUOUS(source); TVM_FFI_ICHECK_EQ(dest->dtype, source->dtype) << "dest and source must have the same dtype"; - size_t nbytes = get_numel(dest) * get_element_size(dest) / dest->shape[0]; + size_t nbytes = dest.numel() * get_element_size(dest) / dest->shape[0]; cudaStream_t stream = get_stream(dest->device); NVSHMEMCHECK(nvshmemx_alltoallmem_on_stream(NVSHMEM_TEAM_WORLD, static_cast(dest->data), static_cast(source->data), nbytes, stream)); } -void fake_alltoall(Tensor dest, Tensor source) {} +void fake_alltoall(TensorView dest, TensorView source) {} -void sum_reduce(Tensor dest, Tensor source, int64_t nelems) { +void sum_reduce(TensorView dest, TensorView source, int64_t nelems) { CHECK_CONTIGUOUS(dest); CHECK_CONTIGUOUS(source); TVM_FFI_ICHECK_EQ(dest->dtype, source->dtype) << "dest and source must have the same dtype"; @@ -124,10 +124,10 @@ void sum_reduce(Tensor dest, Tensor source, int64_t nelems) { } } -void fake_sum_reduce(Tensor dest, Tensor source, int64_t nelems) {} +void fake_sum_reduce(TensorView dest, TensorView source, int64_t nelems) {} -void allreduce_on_stream_with_copy(Tensor dest_symm, Tensor source_symm, Tensor dest_local, - Tensor source_local, int64_t nelems) { +void allreduce_on_stream_with_copy(TensorView dest_symm, TensorView source_symm, + TensorView dest_local, TensorView source_local, int64_t nelems) { CHECK_CONTIGUOUS(dest_symm); CHECK_CONTIGUOUS(source_symm); CHECK_CONTIGUOUS(dest_local); @@ -150,8 +150,9 @@ void allreduce_on_stream_with_copy(Tensor dest_symm, Tensor 
source_symm, Tensor cudaStreamSynchronize(stream); } -void fake_allreduce_on_stream_with_copy(Tensor dest_symm, Tensor source_symm, Tensor dest_local, - Tensor source_local, int64_t nelems) {} +void fake_allreduce_on_stream_with_copy(TensorView dest_symm, TensorView source_symm, + TensorView dest_local, TensorView source_local, + int64_t nelems) {} TVM_FFI_DLL_EXPORT_TYPED_FUNC(nvshmem_get_unique_id, get_unique_id); TVM_FFI_DLL_EXPORT_TYPED_FUNC(nvshmem_unique_id_size, unique_id_size); diff --git a/csrc/page.cu b/csrc/page.cu index e4058b3883..e6397f6150 100644 --- a/csrc/page.cu +++ b/csrc/page.cu @@ -21,9 +21,9 @@ using namespace flashinfer; using tvm::ffi::Tensor; -void append_paged_kv_cache(Tensor append_key, Tensor append_value, Tensor batch_indices, - Tensor positions, Tensor paged_k_cache, Tensor paged_v_cache, - Tensor kv_indices, Tensor kv_indptr, Tensor kv_last_page_len, +void append_paged_kv_cache(TensorView append_key, TensorView append_value, TensorView batch_indices, + TensorView positions, TensorView paged_k_cache, TensorView paged_v_cache, + TensorView kv_indices, TensorView kv_indptr, TensorView kv_last_page_len, int64_t layout) { CHECK_LAST_DIM_CONTIGUOUS(append_key); CHECK_LAST_DIM_CONTIGUOUS(append_value); @@ -109,12 +109,10 @@ void append_paged_kv_cache(Tensor append_key, Tensor append_value, Tensor batch_ << paged_k_cache->dtype; } -void block_sparse_indices_to_vector_sparse_offsets(Tensor block_sparse_indices, - Tensor block_sparse_indptr, - Tensor vector_sparse_offsets, - Tensor vector_sparse_indptr, Tensor kv_len_arr, - int64_t stride_block, int64_t stride_n, - int64_t batch_size, int64_t block_size) { +void block_sparse_indices_to_vector_sparse_offsets( + TensorView block_sparse_indices, TensorView block_sparse_indptr, + TensorView vector_sparse_offsets, TensorView vector_sparse_indptr, TensorView kv_len_arr, + int64_t stride_block, int64_t stride_n, int64_t batch_size, int64_t block_size) { CHECK_INPUT(block_sparse_indices); CHECK_INPUT(block_sparse_indptr); CHECK_INPUT(vector_sparse_offsets); @@ -135,9 +133,10 @@ void block_sparse_indices_to_vector_sparse_offsets(Tensor block_sparse_indices, << "BlockSparseIndicesToVectorSparseOffset failed with error: " << cudaGetErrorString(status); } -void append_paged_mla_kv_cache(Tensor append_ckv, Tensor append_kpe, Tensor batch_indices, - Tensor positions, Tensor ckv_cache, Tensor kpe_cache, - Tensor kv_indices, Tensor kv_indptr, Tensor kv_last_page_len) { +void append_paged_mla_kv_cache(TensorView append_ckv, TensorView append_kpe, + TensorView batch_indices, TensorView positions, TensorView ckv_cache, + TensorView kpe_cache, TensorView kv_indices, TensorView kv_indptr, + TensorView kv_last_page_len) { CHECK_LAST_DIM_CONTIGUOUS(append_ckv); CHECK_LAST_DIM_CONTIGUOUS(append_kpe); CHECK_INPUT(batch_indices); diff --git a/csrc/pod.cu b/csrc/pod.cu index 7e7325908f..b9036c7532 100644 --- a/csrc/pod.cu +++ b/csrc/pod.cu @@ -39,17 +39,19 @@ using tvm::ffi::Optional; void pod_with_kv_cache_tensor( // Prefill params - Tensor q_p, Tensor k_p, Tensor v_p, Tensor tmp_p, Tensor o_p, Optional maybe_lse_p, - int64_t mask_mode_code_p, int64_t layout_p, int64_t window_left_p, - Optional maybe_custom_mask_p, Optional maybe_alibi_slopes_p, - double logits_soft_cap_p, double sm_scale_p, double rope_rcp_scale_p, double rope_rcp_theta_p, + TensorView q_p, TensorView k_p, TensorView v_p, TensorView tmp_p, TensorView o_p, + Optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, + int64_t window_left_p, Optional maybe_custom_mask_p, 
+ Optional maybe_alibi_slopes_p, double logits_soft_cap_p, double sm_scale_p, + double rope_rcp_scale_p, double rope_rcp_theta_p, // Decode params - Tensor float_workspace_buffer_d, Tensor int_workspace_buffer_d, Array plan_info_vec, - Tensor q_d, Tensor paged_k_cache_d, Tensor paged_v_cache_d, Tensor qo_indptr_d, - Tensor paged_kv_indptr_d, Tensor paged_kv_indices_d, Tensor paged_kv_last_page_len_d, - Tensor o_d, Optional maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d, - int64_t window_left_d, Optional maybe_custom_mask_d, - Optional maybe_mask_indptr_d, Optional maybe_alibi_slopes_d, + TensorView float_workspace_buffer_d, TensorView int_workspace_buffer_d, + Array plan_info_vec, TensorView q_d, TensorView paged_k_cache_d, + TensorView paged_v_cache_d, TensorView qo_indptr_d, TensorView paged_kv_indptr_d, + TensorView paged_kv_indices_d, TensorView paged_kv_last_page_len_d, TensorView o_d, + Optional maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d, + int64_t window_left_d, Optional maybe_custom_mask_d, + Optional maybe_mask_indptr_d, Optional maybe_alibi_slopes_d, double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d, bool enable_pdl) { // Prefill setup @@ -83,7 +85,7 @@ void pod_with_kv_cache_tensor( const MaskMode mask_mode_p = static_cast(mask_mode_code_p); - // Decode setup (Tensor decode = batched prefill) + // Decode setup (TensorView decode = batched prefill) PrefillPlanInfo plan_info; plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); QKVLayout kv_layout_d = static_cast(layout_d); diff --git a/csrc/pod_jit_binding.cu b/csrc/pod_jit_binding.cu index 86c0caff60..915e4bcdbf 100644 --- a/csrc/pod_jit_binding.cu +++ b/csrc/pod_jit_binding.cu @@ -21,17 +21,19 @@ using tvm::ffi::Optional; void pod_with_kv_cache_tensor( // Prefill params - Tensor q_p, Tensor k_p, Tensor v_p, Tensor tmp_p, Tensor o_p, Optional maybe_lse_p, - int64_t mask_mode_code_p, int64_t layout_p, int64_t window_left_p, - Optional maybe_custom_mask_p, Optional maybe_alibi_slopes_p, - double logits_soft_cap_p, double sm_scale_p, double rope_rcp_scale_p, double rope_rcp_theta_p, + TensorView q_p, TensorView k_p, TensorView v_p, TensorView tmp_p, TensorView o_p, + Optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, + int64_t window_left_p, Optional maybe_custom_mask_p, + Optional maybe_alibi_slopes_p, double logits_soft_cap_p, double sm_scale_p, + double rope_rcp_scale_p, double rope_rcp_theta_p, // Decode params - Tensor float_workspace_buffer_d, Tensor int_workspace_buffer_d, Array plan_info_vec, - Tensor q_d, Tensor paged_k_cache_d, Tensor paged_v_cache_d, Tensor qo_indptr_d, - Tensor paged_kv_indptr_d, Tensor paged_kv_indices_d, Tensor paged_kv_last_page_len_d, - Tensor o_d, Optional maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d, - int64_t window_left_d, Optional maybe_custom_mask_d, - Optional maybe_mask_indptr_d, Optional maybe_alibi_slopes_d, + TensorView float_workspace_buffer_d, TensorView int_workspace_buffer_d, + Array plan_info_vec, TensorView q_d, TensorView paged_k_cache_d, + TensorView paged_v_cache_d, TensorView qo_indptr_d, TensorView paged_kv_indptr_d, + TensorView paged_kv_indices_d, TensorView paged_kv_last_page_len_d, TensorView o_d, + Optional maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d, + int64_t window_left_d, Optional maybe_custom_mask_d, + Optional maybe_mask_indptr_d, Optional maybe_alibi_slopes_d, double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double 
rope_rcp_theta_d, bool enable_pdl); diff --git a/csrc/quantization.cu b/csrc/quantization.cu index b00de23843..557d79af79 100644 --- a/csrc/quantization.cu +++ b/csrc/quantization.cu @@ -19,12 +19,12 @@ using namespace flashinfer; -void packbits(Tensor x, const std::string& bitorder, Tensor y) { +void packbits(TensorView x, const std::string& bitorder, TensorView y) { CHECK_INPUT(x); auto device = x->device; TVM_FFI_ICHECK(bitorder == "big" || bitorder == "little") << "bitorder must be 'big' or 'little'"; - int64_t num_elements = get_numel(x); + int64_t num_elements = x.numel(); auto stream = get_stream(x->device); cudaError_t status = quantization::PackBits( static_cast(x->data), static_cast(y->data), num_elements, @@ -34,8 +34,8 @@ void packbits(Tensor x, const std::string& bitorder, Tensor y) { << "PackBits failed with error code " << cudaGetErrorString(status); } -void segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, - const std::string& bitorder, Tensor y) { +void segment_packbits(TensorView x, TensorView input_indptr, TensorView output_indptr, + const std::string& bitorder, TensorView y) { CHECK_INPUT(x); CHECK_INPUT(input_indptr); CHECK_INPUT(output_indptr); diff --git a/csrc/renorm.cu b/csrc/renorm.cu index 07d190be2e..d186b40eae 100644 --- a/csrc/renorm.cu +++ b/csrc/renorm.cu @@ -21,8 +21,8 @@ using namespace flashinfer; using tvm::ffi::Optional; -void top_p_renorm_probs(Tensor probs, Tensor renorm_probs, Optional maybe_top_p_arr, - double top_p_val) { +void top_p_renorm_probs(TensorView probs, TensorView renorm_probs, + Optional maybe_top_p_arr, double top_p_val) { CHECK_INPUT(probs); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) unsigned int batch_size = probs->shape[0]; @@ -39,8 +39,8 @@ void top_p_renorm_probs(Tensor probs, Tensor renorm_probs, Optional mayb << "TopPRenormProb failed with error code " << cudaGetErrorString(status); } -void top_k_renorm_probs(Tensor probs, Tensor renorm_probs, Optional maybe_top_k_arr, - int64_t top_k_val) { +void top_k_renorm_probs(TensorView probs, TensorView renorm_probs, + Optional maybe_top_k_arr, int64_t top_k_val) { CHECK_INPUT(probs); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) unsigned int batch_size = probs->shape[0]; @@ -58,8 +58,8 @@ void top_k_renorm_probs(Tensor probs, Tensor renorm_probs, Optional mayb << "TopKRenormProb failed with error code " << cudaGetErrorString(status); } -void top_k_mask_logits(Tensor logits, Tensor mask_logits, Optional maybe_top_k_arr, - int64_t top_k_val) { +void top_k_mask_logits(TensorView logits, TensorView mask_logits, + Optional maybe_top_k_arr, int64_t top_k_val) { CHECK_INPUT(logits); CHECK_DIM(2, logits); // logits: (batch_size, vocab_size) unsigned int batch_size = logits->shape[0]; diff --git a/csrc/rope.cu b/csrc/rope.cu index ed1d26806f..d435cc7377 100644 --- a/csrc/rope.cu +++ b/csrc/rope.cu @@ -21,8 +21,9 @@ using namespace flashinfer; using tvm::ffi::Tensor; -void apply_rope(Tensor q, Tensor k, Tensor q_rope, Tensor k_rope, Tensor indptr, Tensor offsets, - int64_t rotary_dim, bool interleave, double rope_scale, double rope_theta) { +void apply_rope(TensorView q, TensorView k, TensorView q_rope, TensorView k_rope, TensorView indptr, + TensorView offsets, int64_t rotary_dim, bool interleave, double rope_scale, + double rope_theta) { CHECK_LAST_DIM_CONTIGUOUS_INPUT(q); CHECK_LAST_DIM_CONTIGUOUS_INPUT(k); CHECK_INPUT(indptr); @@ -68,8 +69,9 @@ void apply_rope(Tensor q, Tensor k, Tensor q_rope, Tensor k_rope, Tensor indptr, }); } -void 
apply_rope_pos_ids(Tensor q, Tensor k, Tensor q_rope, Tensor k_rope, Tensor pos_ids, - int64_t rotary_dim, bool interleave, double rope_scale, double rope_theta) { +void apply_rope_pos_ids(TensorView q, TensorView k, TensorView q_rope, TensorView k_rope, + TensorView pos_ids, int64_t rotary_dim, bool interleave, double rope_scale, + double rope_theta) { CHECK_LAST_DIM_CONTIGUOUS_INPUT(q); CHECK_LAST_DIM_CONTIGUOUS_INPUT(k); CHECK_INPUT(pos_ids); @@ -111,8 +113,9 @@ void apply_rope_pos_ids(Tensor q, Tensor k, Tensor q_rope, Tensor k_rope, Tensor }); } -void apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor q_rope, Tensor k_rope, - Tensor cos_sin_cache, Tensor pos_ids, bool interleave) { +void apply_rope_pos_ids_cos_sin_cache(TensorView q, TensorView k, TensorView q_rope, + TensorView k_rope, TensorView cos_sin_cache, + TensorView pos_ids, bool interleave) { CHECK_LAST_DIM_CONTIGUOUS_INPUT(q); CHECK_LAST_DIM_CONTIGUOUS_INPUT(k); CHECK_INPUT(cos_sin_cache); @@ -161,10 +164,10 @@ void apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor q_rope, Tensor }); } -void apply_llama31_rope(Tensor q, Tensor k, Tensor q_rope, Tensor k_rope, Tensor indptr, - Tensor offsets, int64_t rotary_dim, bool interleave, double rope_scale, - double rope_theta, double low_freq_factor, double high_freq_factor, - double old_context_length) { +void apply_llama31_rope(TensorView q, TensorView k, TensorView q_rope, TensorView k_rope, + TensorView indptr, TensorView offsets, int64_t rotary_dim, bool interleave, + double rope_scale, double rope_theta, double low_freq_factor, + double high_freq_factor, double old_context_length) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous CHECK_INPUT(indptr); @@ -213,10 +216,10 @@ void apply_llama31_rope(Tensor q, Tensor k, Tensor q_rope, Tensor k_rope, Tensor }); } -void apply_llama31_rope_pos_ids(Tensor q, Tensor k, Tensor q_rope, Tensor k_rope, Tensor pos_ids, - int64_t rotary_dim, bool interleave, double rope_scale, - double rope_theta, double low_freq_factor, double high_freq_factor, - double old_context_length) { +void apply_llama31_rope_pos_ids(TensorView q, TensorView k, TensorView q_rope, TensorView k_rope, + TensorView pos_ids, int64_t rotary_dim, bool interleave, + double rope_scale, double rope_theta, double low_freq_factor, + double high_freq_factor, double old_context_length) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous CHECK_INPUT(pos_ids); @@ -259,10 +262,11 @@ void apply_llama31_rope_pos_ids(Tensor q, Tensor k, Tensor q_rope, Tensor k_rope }); } -void mla_rope_quantize(Tensor q_rope_in, Tensor k_rope_in, Tensor q_nope_in, Tensor k_nope_in, - Tensor q_rope_out, Tensor k_rope_out, Tensor q_nope_out, Tensor k_nope_out, - Tensor cos_sin_cache, Tensor pos_ids, double quant_scale_q, - double quant_scale_kv, bool interleave) { +void mla_rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in, + TensorView k_nope_in, TensorView q_rope_out, TensorView k_rope_out, + TensorView q_nope_out, TensorView k_nope_out, TensorView cos_sin_cache, + TensorView pos_ids, double quant_scale_q, double quant_scale_kv, + bool interleave) { CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_rope_in); CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_rope_in); CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_nope_in); diff --git a/csrc/sampling.cu b/csrc/sampling.cu index 9795e779ab..d17295d091 100644 --- a/csrc/sampling.cu +++ b/csrc/sampling.cu @@ -21,8 +21,8 @@ using namespace flashinfer; using tvm::ffi::Optional; 
-void softmax(Tensor workspace_buffer, Tensor logits, Tensor output, - Optional maybe_temperature_arr, double temperature_val, bool enable_pdl) { +void softmax(TensorView workspace_buffer, TensorView logits, TensorView output, + Optional maybe_temperature_arr, double temperature_val, bool enable_pdl) { CHECK_INPUT(workspace_buffer); CHECK_INPUT(logits); CHECK_INPUT(output); @@ -43,7 +43,7 @@ void softmax(Tensor workspace_buffer, Tensor logits, Tensor output, << "OnlineSoftmax failed with error code " << cudaGetErrorString(status); } -void sampling_from_logits(Tensor logits, Tensor output, Optional maybe_indices, +void sampling_from_logits(TensorView logits, TensorView output, Optional maybe_indices, bool deterministic, uint64_t philox_seed, uint64_t philox_offset) { CHECK_INPUT(logits); CHECK_DIM(2, logits); // logits: (batch_size, vocab_size) @@ -60,7 +60,7 @@ void sampling_from_logits(Tensor logits, Tensor output, Optional maybe_i << "SamplingFromLogits failed with error code " << cudaGetErrorString(status); } -void sampling_from_probs(Tensor probs, Tensor output, Optional maybe_indices, +void sampling_from_probs(TensorView probs, TensorView output, Optional maybe_indices, bool deterministic, uint64_t philox_seed, uint64_t philox_offset) { CHECK_INPUT(probs); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) @@ -77,8 +77,9 @@ void sampling_from_probs(Tensor probs, Tensor output, Optional maybe_ind << "SamplingFromProbs failed with error code " << cudaGetErrorString(status); } -void top_p_sampling_from_probs(Tensor probs, Tensor output, Optional maybe_indices, - Optional maybe_top_p_arr, double top_p_val, +void top_p_sampling_from_probs(TensorView probs, TensorView output, + Optional maybe_indices, + Optional maybe_top_p_arr, double top_p_val, bool deterministic, uint64_t philox_seed, uint64_t philox_offset) { CHECK_INPUT(probs); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) @@ -97,8 +98,9 @@ void top_p_sampling_from_probs(Tensor probs, Tensor output, Optional may << "TopPSamplingFromProbs failed with error code " << cudaGetErrorString(status); } -void top_k_sampling_from_probs(Tensor probs, Tensor output, Optional maybe_indices, - Optional maybe_top_k_arr, int64_t top_k_val, +void top_k_sampling_from_probs(TensorView probs, TensorView output, + Optional maybe_indices, + Optional maybe_top_k_arr, int64_t top_k_val, bool deterministic, uint64_t philox_seed, uint64_t philox_offset) { CHECK_INPUT(probs); CHECK_INPUT(output); @@ -120,8 +122,9 @@ void top_k_sampling_from_probs(Tensor probs, Tensor output, Optional may << "TopKSamplingFromProbs failed with error code " << cudaGetErrorString(status); } -void min_p_sampling_from_probs(Tensor probs, Tensor output, Optional maybe_indices, - Optional maybe_min_p_arr, double min_p_val, +void min_p_sampling_from_probs(TensorView probs, TensorView output, + Optional maybe_indices, + Optional maybe_min_p_arr, double min_p_val, bool deterministic, uint64_t philox_seed, uint64_t philox_offset) { CHECK_INPUT(probs); CHECK_INPUT(output); @@ -144,9 +147,10 @@ void min_p_sampling_from_probs(Tensor probs, Tensor output, Optional may << "MinPSamplingFromProb failed with error code " << cudaGetErrorString(status); } -void top_k_top_p_sampling_from_probs(Tensor probs, Tensor output, Optional maybe_indices, - Optional maybe_top_k_arr, double top_k_val, - Optional maybe_top_p_arr, double top_p_val, +void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output, + Optional maybe_indices, + Optional maybe_top_k_arr, double top_k_val, + 
Optional maybe_top_p_arr, double top_p_val, bool deterministic, uint64_t philox_seed, uint64_t philox_offset) { CHECK_INPUT(probs); @@ -173,9 +177,10 @@ void top_k_top_p_sampling_from_probs(Tensor probs, Tensor output, Optional maybe_lse, int64_t layout, +void single_decode_with_kv_cache(TensorView q, TensorView k, TensorView v, TensorView tmp, + TensorView o, Optional maybe_lse, int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS) { CHECK_INPUT(q); CHECK_INPUT(k); diff --git a/csrc/single_decode_jit_binding.cu b/csrc/single_decode_jit_binding.cu index d6ed939800..f1a7772028 100644 --- a/csrc/single_decode_jit_binding.cu +++ b/csrc/single_decode_jit_binding.cu @@ -19,8 +19,8 @@ using tvm::ffi::Optional; -void single_decode_with_kv_cache(Tensor q, Tensor k, Tensor v, Tensor tmp, Tensor o, - Optional maybe_lse, int64_t layout, +void single_decode_with_kv_cache(TensorView q, TensorView k, TensorView v, TensorView tmp, + TensorView o, Optional maybe_lse, int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS); // Single-request decode with KV-Cache operator diff --git a/csrc/single_prefill.cu b/csrc/single_prefill.cu index a1d864209f..46d24a71ea 100644 --- a/csrc/single_prefill.cu +++ b/csrc/single_prefill.cu @@ -34,10 +34,10 @@ cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::D using namespace flashinfer; -void single_prefill_with_kv_cache(ffi::Tensor q, ffi::Tensor k, ffi::Tensor v, ffi::Tensor tmp, - ffi::Tensor o, Optional maybe_lse, - int64_t mask_mode_code, int64_t layout, - int64_t window_left ADDITIONAL_FUNC_PARAMS) { +void single_prefill_with_kv_cache(ffi::TensorView q, ffi::TensorView k, ffi::TensorView v, + ffi::TensorView tmp, ffi::TensorView o, + Optional maybe_lse, int64_t mask_mode_code, + int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS) { unsigned int head_dim_qk = q->shape[2]; unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads; QKVLayout kv_layout = static_cast(layout); diff --git a/csrc/single_prefill_fp8_sm90.cu b/csrc/single_prefill_fp8_sm90.cu index 23c3f8ad21..6051d2c737 100644 --- a/csrc/single_prefill_fp8_sm90.cu +++ b/csrc/single_prefill_fp8_sm90.cu @@ -32,10 +32,10 @@ using namespace flashinfer; using tvm::ffi::Optional; -void single_prefill_with_kv_cache_sm90(ffi::Tensor q, ffi::Tensor k, ffi::Tensor v, ffi::Tensor tmp, - ffi::Tensor o, Optional maybe_lse, - int64_t mask_mode_code, int64_t layout, - int64_t window_left ADDITIONAL_FUNC_PARAMS) { +void single_prefill_with_kv_cache_sm90(ffi::TensorView q, ffi::TensorView k, ffi::TensorView v, + ffi::TensorView tmp, ffi::TensorView o, + Optional maybe_lse, int64_t mask_mode_code, + int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS) { unsigned int head_dim_qk = q->shape[2]; unsigned int head_dim_vo = v->shape[2]; unsigned int num_qo_heads = q->shape[1]; diff --git a/csrc/single_prefill_jit_binding.cu b/csrc/single_prefill_jit_binding.cu index a1ed64cc8f..178e32dd84 100644 --- a/csrc/single_prefill_jit_binding.cu +++ b/csrc/single_prefill_jit_binding.cu @@ -18,10 +18,10 @@ using tvm::ffi::Optional; -void single_prefill_with_kv_cache(ffi::Tensor q, ffi::Tensor k, ffi::Tensor v, ffi::Tensor tmp, - ffi::Tensor o, Optional maybe_lse, - int64_t mask_mode_code, int64_t layout, - int64_t window_left ADDITIONAL_FUNC_PARAMS); +void single_prefill_with_kv_cache(ffi::TensorView q, ffi::TensorView k, ffi::TensorView v, + ffi::TensorView tmp, ffi::TensorView o, + Optional maybe_lse, int64_t mask_mode_code, + int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS); // 
Single-request prefill attention with KV-Cache operator TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, single_prefill_with_kv_cache); diff --git a/csrc/single_prefill_sm90.cu b/csrc/single_prefill_sm90.cu index d440022dea..8cfe8e3713 100644 --- a/csrc/single_prefill_sm90.cu +++ b/csrc/single_prefill_sm90.cu @@ -32,10 +32,10 @@ using namespace flashinfer; using tvm::ffi::Optional; -void single_prefill_with_kv_cache_sm90(ffi::Tensor q, ffi::Tensor k, ffi::Tensor v, ffi::Tensor tmp, - ffi::Tensor o, Optional maybe_lse, - int64_t mask_mode_code, int64_t layout, - int64_t window_left ADDITIONAL_FUNC_PARAMS) { +void single_prefill_with_kv_cache_sm90(ffi::TensorView q, ffi::TensorView k, ffi::TensorView v, + ffi::TensorView tmp, ffi::TensorView o, + Optional maybe_lse, int64_t mask_mode_code, + int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS) { unsigned int head_dim_qk = q->shape[2]; unsigned int head_dim_vo = v->shape[2]; unsigned int num_qo_heads = q->shape[1]; diff --git a/csrc/single_prefill_sm90_jit_binding.cu b/csrc/single_prefill_sm90_jit_binding.cu index 5a1546bd83..f8871e5299 100644 --- a/csrc/single_prefill_sm90_jit_binding.cu +++ b/csrc/single_prefill_sm90_jit_binding.cu @@ -18,10 +18,10 @@ using tvm::ffi::Optional; -void single_prefill_with_kv_cache_sm90(ffi::Tensor q, ffi::Tensor k, ffi::Tensor v, ffi::Tensor tmp, - ffi::Tensor o, Optional maybe_lse, - int64_t mask_mode_code, int64_t layout, - int64_t window_left ADDITIONAL_FUNC_PARAMS); +void single_prefill_with_kv_cache_sm90(ffi::TensorView q, ffi::TensorView k, ffi::TensorView v, + ffi::TensorView tmp, ffi::TensorView o, + Optional maybe_lse, int64_t mask_mode_code, + int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS); // Single-request prefill attention with KV-Cache operator TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, single_prefill_with_kv_cache_sm90); diff --git a/csrc/tgv_gemm.cu b/csrc/tgv_gemm.cu index 7090656dcc..a8d9cb334a 100644 --- a/csrc/tgv_gemm.cu +++ b/csrc/tgv_gemm.cu @@ -116,8 +116,8 @@ void tgv_gemm_impl(input_type* mat1_ptr, input_type* mat2_ptr, output_type* outp } // namespace -Tensor tgv_gemm(Tensor const& mat1, Tensor const& mat2, Optional bias, int64_t tactic, - bool pdl) { +void tgv_gemm(TensorView mat1, TensorView mat2, Optional bias, int64_t tactic, + TensorView out, bool pdl) { // Input validation TVM_FFI_ICHECK_EQ(mat1->device.device_type, kDLCUDA) << "mat1 tensor must be on CUDA"; TVM_FFI_ICHECK_EQ(mat2->device.device_type, kDLCUDA) << "mat2 tensor must be on CUDA"; @@ -159,7 +159,9 @@ Tensor tgv_gemm(Tensor const& mat1, Tensor const& mat2, Optional bias, i } // Create output tensor [N, M] row major - Tensor C = alloc_tensor({N, M}, mat1->dtype, mat1->device); + TVM_FFI_ICHECK_EQ(out->shape[0], N); + TVM_FFI_ICHECK_EQ(out->shape[1], M); + TVM_FFI_ICHECK_EQ(out->dtype, mat1->dtype); // manually calculate the L stride // A [M, K] row major @@ -171,12 +173,12 @@ Tensor tgv_gemm(Tensor const& mat1, Tensor const& mat2, Optional bias, i int stride_B_K = mat2->strides[0]; int stride_B_L = N * K; // original C [N, M] row major - int stride_C_M = C->strides[1]; - int stride_C_N = C->strides[0]; + int stride_C_M = out->strides[1]; + int stride_C_N = out->strides[0]; int stride_C_L = M * N; // Get CUDA stream - cudaStream_t stream = get_stream(C->device); + cudaStream_t stream = get_stream(out->device); // Dispatch based on dtype DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(mat1->dtype, c_type, [&] { @@ -185,7 +187,7 @@ Tensor tgv_gemm(Tensor const& mat1, Tensor const& mat2, Optional bias, i cutlass_input_type* mat1_ptr = 
static_cast(mat1->data); cutlass_input_type* mat2_ptr = static_cast(mat2->data); - cutlass_output_type* output_ptr = static_cast(C->data); + cutlass_output_type* output_ptr = static_cast(out->data); cutlass_output_type* bias_ptr = bias.has_value() ? static_cast(bias.value()->data) : nullptr; @@ -199,18 +201,17 @@ Tensor tgv_gemm(Tensor const& mat1, Tensor const& mat2, Optional bias, i // original C is [N, M] row major // after transpose, it's [M, N] column major // the storage is unchanged, only the logical coordinates are changed - std::swap(C->shape[0], C->shape[1]); - std::swap(C->strides[0], C->strides[1]); - return C; + std::swap(out->shape[0], out->shape[1]); + std::swap(out->strides[0], out->strides[1]); } // Keep backward compatibility functions -Tensor bf16_gemm(Tensor const& mat1, Tensor const& mat2, std::optional bias, int64_t tactic, - bool pdl) { +void bf16_gemm(TensorView mat1, TensorView mat2, std::optional bias, int64_t tactic, + TensorView out, bool pdl) { // Check that inputs are bfloat16 for backward compatibility TVM_FFI_ICHECK_EQ(mat1->dtype, dl_bfloat16) << "mat1 tensor must be bfloat16"; TVM_FFI_ICHECK_EQ(mat2->dtype, dl_bfloat16) << "mat2 tensor must be bfloat16"; - return tgv_gemm(mat1, mat2, bias, tactic, pdl); + tgv_gemm(mat1, mat2, bias, tactic, out, pdl); } int64_t tgv_gemm_tactic_num() { diff --git a/csrc/trtllm_allreduce.cu b/csrc/trtllm_allreduce.cu index e05238dd28..7546658282 100644 --- a/csrc/trtllm_allreduce.cu +++ b/csrc/trtllm_allreduce.cu @@ -84,17 +84,17 @@ void trtllm_lamport_initialize_all(int64_t buffer_0_ptr, int64_t buffer_1_ptr, i } // refer to cpp/tests/unit_tests/kernels/allReduce/allReduceFusionTest.cu:L268 -void trtllm_custom_all_reduce(Tensor in, Tensor out, int64_t tp_size, int64_t tp_rank, +void trtllm_custom_all_reduce(TensorView in, TensorView out, int64_t tp_size, int64_t tp_rank, int64_t token_num, int64_t fusion_op_code, int64_t strategy_code, int64_t config_code, bool launch_with_pdl, int64_t flag_value, - Tensor peer_comm_buffer_ptrs, Tensor peer_barrier_ptrs_in, - Tensor peer_barrier_ptrs_out, Optional bias, - Optional residual, Optional weight, - Optional weight_pre_residual_norm, Optional eps, - Optional intermediate_buffer, - Optional lamport_peer_comm_buffer_ptrs_0, - Optional lamport_peer_comm_buffer_ptrs_1, - Optional lamport_peer_comm_buffer_ptrs_2) { + TensorView peer_comm_buffer_ptrs, TensorView peer_barrier_ptrs_in, + TensorView peer_barrier_ptrs_out, Optional bias, + Optional residual, Optional weight, + Optional weight_pre_residual_norm, Optional eps, + Optional intermediate_buffer, + Optional lamport_peer_comm_buffer_ptrs_0, + Optional lamport_peer_comm_buffer_ptrs_1, + Optional lamport_peer_comm_buffer_ptrs_2) { AllReduceFusionOp fusion_op = static_cast(fusion_op_code); cudaSetDevice(in->device.device_id); auto stream = get_stream(in->device); @@ -102,8 +102,8 @@ void trtllm_custom_all_reduce(Tensor in, Tensor out, int64_t tp_size, int64_t tp // TODO(zihao): review dispatch type - support fp16, bf16 only DISPATCH_FLOATING_TYPES_FOR_ALLREDUCE(in->dtype, c_type, [&] { // TODO(yingyi): remove type template here (used to check if lamport is supported) - int64_t message_size = get_numel(in); - int64_t hidden_size = get_numel(in) / token_num; + int64_t message_size = in.numel(); + int64_t hidden_size = in.numel() / token_num; AllReduceParams params; params.elts_total = message_size; diff --git a/csrc/trtllm_allreduce_fusion.cu b/csrc/trtllm_allreduce_fusion.cu index 16f1a25546..4b465c7ffc 100644 --- 
a/csrc/trtllm_allreduce_fusion.cu +++ b/csrc/trtllm_allreduce_fusion.cu @@ -28,15 +28,15 @@ using tvm::ffi::Optional; } \ }() -void trtllm_allreduce_fusion(Tensor allreduce_in, int64_t world_size, int64_t world_rank, - int64_t token_num, int64_t hidden_size, Tensor workspace_ptrs, +void trtllm_allreduce_fusion(TensorView allreduce_in, int64_t world_size, int64_t world_rank, + int64_t token_num, int64_t hidden_size, TensorView workspace_ptrs, bool launch_with_pdl, bool use_oneshot, bool trigger_completion_at_end, - bool fp32_acc, int64_t pattern_code, Optional allreduce_out, - Optional residual_in, Optional residual_out, - Optional norm_out, Optional quant_out, - Optional scale_out, Optional rms_gamma, - Optional rms_eps, Optional scale_factor, - Optional layout_code) { + bool fp32_acc, int64_t pattern_code, + Optional allreduce_out, Optional residual_in, + Optional residual_out, Optional norm_out, + Optional quant_out, Optional scale_out, + Optional rms_gamma, Optional rms_eps, + Optional scale_factor, Optional layout_code) { cudaSetDevice(allreduce_in->device.device_id); // todo(Yingyi): add dispatch for float and bfloat16 diff --git a/csrc/trtllm_alltoall.cu b/csrc/trtllm_alltoall.cu index 4376922835..51e0b2c2d1 100644 --- a/csrc/trtllm_alltoall.cu +++ b/csrc/trtllm_alltoall.cu @@ -26,11 +26,12 @@ using namespace flashinfer::trtllm_alltoall; using tvm::ffi::Optional; using tvm::ffi::Tuple; -void moeCommPrepareIndicesOp(Tensor gatheredTargetRankIds, - Optional realRankTokenCountCumSum, Tensor localGatherIndices, - Tensor sendRankCountCumSum, Tensor sendRankLocalIndices, - Tensor recvRankCountCumSum, Tensor recvRankLocalIndices, - Tensor backwardRecvRankLocalIndices, int64_t maxTokenCountPerRank, +void moeCommPrepareIndicesOp(TensorView gatheredTargetRankIds, + Optional realRankTokenCountCumSum, + TensorView localGatherIndices, TensorView sendRankCountCumSum, + TensorView sendRankLocalIndices, TensorView recvRankCountCumSum, + TensorView recvRankLocalIndices, + TensorView backwardRecvRankLocalIndices, int64_t maxTokenCountPerRank, int64_t expertCount, int64_t topK, int64_t epRank, int64_t epSize) { CHECK_INPUT_TYPE(gatheredTargetRankIds, dl_int32); TVM_FFI_ICHECK_EQ(gatheredTargetRankIds->ndim, 2) << "gatheredTargetRankIds must be a 2D tensor"; @@ -102,8 +103,9 @@ void moeCommPrepareIndicesOp(Tensor gatheredTargetRankIds, << "CUDA error in moeAllToAllPrepareIndices: " << cudaGetErrorString(cudaResult); } -void moeLocalGatherOp(Tensor recvRankCumSum, Tensor localGatherIndices, Tensor gatheredExpertIds, - Tensor gatheredScales, Tensor localExpertIds, Tensor localScales, +void moeLocalGatherOp(TensorView recvRankCumSum, TensorView localGatherIndices, + TensorView gatheredExpertIds, TensorView gatheredScales, + TensorView localExpertIds, TensorView localScales, int64_t maxTokenCountPerRank, int64_t expertCount, int64_t topK, int64_t epRank, int64_t epSize) { CHECK_INPUT_TYPE(recvRankCumSum, dl_int32); @@ -153,9 +155,9 @@ void moeLocalGatherOp(Tensor recvRankCumSum, Tensor localGatherIndices, Tensor g static_cast(localExpertIds->data), static_cast(localScales->data), stream); } -void moeCommOp(Tensor input, Tensor sendRankCumSum, Tensor sendIndices, Tensor output, - Tensor recvRankCumSum, Tensor recvIndices, Tensor allWorkspaces, int64_t epRank, - int64_t epSize) { +void moeCommOp(TensorView input, TensorView sendRankCumSum, TensorView sendIndices, + TensorView output, TensorView recvRankCumSum, TensorView recvIndices, + TensorView allWorkspaces, int64_t epRank, int64_t epSize) { 
CHECK_INPUT_TYPE(sendRankCumSum, dl_int32); CHECK_INPUT_TYPE(sendIndices, dl_int32); CHECK_INPUT_TYPE(recvRankCumSum, dl_int32); @@ -221,14 +223,16 @@ int64_t getPrepareWorkspaceSizePerRank(int64_t epSize) { return flashinfer::trtllm_alltoall::moe_prepare::getMoePrepareWorkspaceSize(epSize32); } -void moePrepareOp(Tensor expertsIds, Optional scales, Optional expertsStatics, - Tensor allWorkspaces, Tensor preparedLocalExpertIds, Tensor sendRankCountCumSum, - Tensor recvRankCountCumSum, Tensor gatherRecvRankIndices, Tensor recvRankIndices, - Tensor gatherBackwardRecvRankIndices, Tensor backwardRecvRankIndices, - Tensor gatherSendRankIndices, Tensor sendRankIndices, - Optional preparedLocalScales, Optional gatheredExpertStatics, - int64_t maxTokenCountPerRank, int64_t epRank, int64_t epSize, int64_t expertCount, - int64_t slotCount, int64_t topK) { +void moePrepareOp(TensorView expertsIds, Optional scales, + Optional expertsStatics, TensorView allWorkspaces, + TensorView preparedLocalExpertIds, TensorView sendRankCountCumSum, + TensorView recvRankCountCumSum, TensorView gatherRecvRankIndices, + TensorView recvRankIndices, TensorView gatherBackwardRecvRankIndices, + TensorView backwardRecvRankIndices, TensorView gatherSendRankIndices, + TensorView sendRankIndices, Optional preparedLocalScales, + Optional gatheredExpertStatics, int64_t maxTokenCountPerRank, + int64_t epRank, int64_t epSize, int64_t expertCount, int64_t slotCount, + int64_t topK) { CHECK_INPUT_TYPE(expertsIds, dl_int32); TVM_FFI_ICHECK_EQ(expertCount % 4, 0) << "expertCount must be divisible by 4"; TVM_FFI_ICHECK_EQ(slotCount % 4, 0) << "slotCount must be divisible by 4"; diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 317101208f..05e92a1721 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -194,13 +194,14 @@ inline Data_type dl_dtype_to_tllm_data_type(const DLDataType dtype) { inline bool is_4bit(Data_type data_type) { return data_type == Data_type::DATA_TYPE_E2M1; } -void trtllm_paged_attention_decode(Tensor out, Optional out_scale_factor, Tensor query, - Tensor key_cache, Tensor value_cache, Tensor workspace_buffer, - Tensor block_tables, Tensor seq_lens, int64_t max_kv_len, - double bmm1_scale, double bmm2_scale, double o_sf_scale, - int64_t o_sf_vec_size, int64_t o_sf_start_index, - int64_t window_left, int64_t sm_count, bool enable_pdl, - int64_t workspace_size, Optional attention_sinks) { +void trtllm_paged_attention_decode(TensorView out, Optional out_scale_factor, + TensorView query, TensorView key_cache, TensorView value_cache, + TensorView workspace_buffer, TensorView block_tables, + TensorView seq_lens, int64_t max_kv_len, double bmm1_scale, + double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size, + int64_t o_sf_start_index, int64_t window_left, int64_t sm_count, + bool enable_pdl, int64_t workspace_size, + Optional attention_sinks) { auto q_data_type = dl_dtype_to_tllm_data_type(query->dtype); auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache->dtype); TVM_FFI_ICHECK_EQ(key_cache->ndim, value_cache->ndim); @@ -263,15 +264,16 @@ void trtllm_paged_attention_decode(Tensor out, Optional out_scale_factor enable_pdl, workspace_size, stream); } -void trtllm_paged_attention_context(Tensor out, Optional out_scale_factor, Tensor query, - Tensor key_cache, Tensor value_cache, Tensor workspace_buffer, - Tensor block_tables, Tensor seq_lens, int64_t max_q_len, - int64_t max_kv_len, double bmm1_scale, double bmm2_scale, - double 
o_sf_scale, int64_t o_sf_vec_size, - int64_t o_sf_start_index, int64_t batch_size, - int64_t window_left, Tensor cum_seq_lens_q, - Tensor cum_seq_lens_kv, int64_t sm_count, bool enable_pdl, - int64_t workspace_size, Optional attention_sinks) { +void trtllm_paged_attention_context(TensorView out, Optional out_scale_factor, + TensorView query, TensorView key_cache, TensorView value_cache, + TensorView workspace_buffer, TensorView block_tables, + TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len, + double bmm1_scale, double bmm2_scale, double o_sf_scale, + int64_t o_sf_vec_size, int64_t o_sf_start_index, + int64_t batch_size, int64_t window_left, + TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, + int64_t sm_count, bool enable_pdl, int64_t workspace_size, + Optional attention_sinks) { auto q_data_type = dl_dtype_to_tllm_data_type(query->dtype); auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache->dtype); auto o_data_type = dl_dtype_to_tllm_data_type(out->dtype); @@ -411,13 +413,14 @@ void trtllm_ragged_attention_launcher( fmha_runner->run(runner_params); } -void trtllm_ragged_attention(Tensor out, Tensor query, Tensor key, Tensor value, - Tensor workspace_buffer, Tensor seq_lens, int64_t max_q_len, +void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, TensorView value, + TensorView workspace_buffer, TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len, double bmm1_scale, double bmm2_scale, double o_sf_scale, int64_t batch_size, int64_t window_left, - Tensor cum_seq_lens_q, Tensor cum_seq_lens_kv, int64_t sm_count, - bool enable_pdl, bool is_causal, int64_t workspace_size, - Optional attention_sinks, Optional lse) { + TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, + int64_t sm_count, bool enable_pdl, bool is_causal, + int64_t workspace_size, Optional attention_sinks, + Optional lse) { float* attention_sinks_ptr = nullptr; if (attention_sinks.has_value()) { TVM_FFI_ICHECK_EQ(attention_sinks.value()->dtype, dl_float32) @@ -446,10 +449,10 @@ void trtllm_ragged_attention(Tensor out, Tensor query, Tensor key, Tensor value, int head_dim_v = value->shape[2]; int k_stride_keys_values = key->strides[0]; int k_stride_heads = key->strides[1]; - int k_stride_batch = get_numel(key); + int k_stride_batch = key.numel(); int v_stride_keys_values = value->strides[0]; int v_stride_heads = value->strides[1]; - int v_stride_batch = get_numel(value); + int v_stride_batch = value.numel(); trtllm_ragged_attention_launcher( out->data, query->data, key->data, value->data, workspace_buffer->data, diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 5741611644..e599d658d7 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -37,15 +37,15 @@ using tensorrt_llm::kernels::trtllmgen_moe::Routing::RoutingMethodType; using tvm::ffi::Array; using tvm::ffi::Optional; -Tensor trtllm_fp8_per_tensor_scale_moe_launcher( - Tensor routing_logits, Optional routing_bias, Tensor hidden_states, - Tensor gemm1_weights, Tensor output1_scales_scalar, Tensor output1_scales_gate_scalar, - Tensor gemm2_weights, Tensor output2_scales_scalar, int64_t const num_experts, - int64_t const top_k, int64_t const n_group, int64_t const topk_group, - int64_t const intermediate_size, int64_t const local_expert_offset, - int64_t const local_num_experts, double const routed_scaling_factor, - bool const use_routing_scales_on_input, int64_t const tile_tokens_dim, - int64_t const routing_method_type, 
bool enable_pdl) { +TensorView trtllm_fp8_per_tensor_scale_moe_launcher( + TensorView routing_logits, Optional routing_bias, TensorView hidden_states, + TensorView gemm1_weights, TensorView output1_scales_scalar, + TensorView output1_scales_gate_scalar, TensorView gemm2_weights, + TensorView output2_scales_scalar, int64_t const num_experts, int64_t const top_k, + int64_t const n_group, int64_t const topk_group, int64_t const intermediate_size, + int64_t const local_expert_offset, int64_t const local_num_experts, + double const routed_scaling_factor, bool const use_routing_scales_on_input, + int64_t const tile_tokens_dim, int64_t const routing_method_type, bool enable_pdl) { static const std::tuple device_props = [hidden_states] { int major, minor; cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, @@ -275,11 +275,12 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher( return output; } -Tensor trtllm_fp8_per_tensor_scale_moe( - Tensor routing_logits, Optional routing_bias, Tensor hidden_states, - Tensor gemm1_weights, Tensor output1_scales_scalar, Tensor output1_scales_gate_scalar, - Tensor gemm2_weights, Tensor output2_scales_scalar, int64_t num_experts, int64_t top_k, - int64_t n_group, int64_t topk_group, int64_t intermediate_size, int64_t local_expert_offset, +TensorView trtllm_fp8_per_tensor_scale_moe( + TensorView routing_logits, Optional routing_bias, TensorView hidden_states, + TensorView gemm1_weights, TensorView output1_scales_scalar, + TensorView output1_scales_gate_scalar, TensorView gemm2_weights, + TensorView output2_scales_scalar, int64_t num_experts, int64_t top_k, int64_t n_group, + int64_t topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, double routed_scaling_factor, bool use_routing_scales_on_input, int64_t tile_tokens_dim, int64_t routing_method_type, bool enable_pdl) { auto dtype = hidden_states->dtype; @@ -296,10 +297,10 @@ Tensor trtllm_fp8_per_tensor_scale_moe( } void trtllm_fp8_block_scale_moe_launcher( - Tensor routing_logits, Optional routing_bias, Tensor hidden_states, - Tensor hidden_states_scale, Tensor gemm1_weights, Tensor gemm1_weights_scale, - Tensor gemm2_weights, Tensor gemm2_weights_scale, Tensor output, int64_t const num_experts, - int64_t const top_k, int64_t const n_group, int64_t const topk_group, + TensorView routing_logits, Optional routing_bias, TensorView hidden_states, + TensorView hidden_states_scale, TensorView gemm1_weights, TensorView gemm1_weights_scale, + TensorView gemm2_weights, TensorView gemm2_weights_scale, TensorView output, + int64_t const num_experts, int64_t const top_k, int64_t const n_group, int64_t const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset, int64_t const local_num_experts, double const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type, @@ -558,12 +559,12 @@ void trtllm_fp8_block_scale_moe_launcher( enable_pdl); } -void trtllm_fp8_block_scale_moe(Tensor routing_logits, Optional routing_bias, - Tensor hidden_states, Tensor hidden_states_scale, - Tensor gemm1_weights, Tensor gemm1_weights_scale, - Tensor gemm2_weights, Tensor gemm2_weights_scale, Tensor output, - int64_t num_experts, int64_t top_k, int64_t n_group, - int64_t topk_group, int64_t intermediate_size, +void trtllm_fp8_block_scale_moe(TensorView routing_logits, Optional routing_bias, + TensorView hidden_states, TensorView hidden_states_scale, + TensorView gemm1_weights, TensorView gemm1_weights_scale, + TensorView gemm2_weights, 
TensorView gemm2_weights_scale, + TensorView output, int64_t num_experts, int64_t top_k, + int64_t n_group, int64_t topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, double routed_scaling_factor, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, @@ -603,20 +604,22 @@ void trtllm_fp8_block_scale_moe(Tensor routing_logits, Optional routing_ // TODO(siyuan): This launcher supports flexible weight and activation types. // We should cleanup other launchers and only use this one in the future. -Array trtllm_fp4_block_scale_moe_launcher( - Optional routing_logits, Tensor expert_indices, Tensor expert_weights, - Optional routing_bias, Tensor hidden_states, Optional hidden_states_scale, - Tensor gemm1_weights, Tensor gemm1_weights_scale, Optional gemm1_bias, - Optional gemm1_alpha, Optional gemm1_beta, Optional gemm1_clamp_limit, - Tensor gemm2_weights, Tensor gemm2_weights_scale, Optional gemm2_bias, - Optional output1_scales_scalar, Optional output1_scales_gate_scalar, - Optional output2_scales_scalar, int64_t const num_experts, int64_t const top_k, +Array trtllm_fp4_block_scale_moe_launcher( + Optional routing_logits, TensorView expert_indices, TensorView expert_weights, + Optional routing_bias, TensorView hidden_states, + Optional hidden_states_scale, TensorView gemm1_weights, + TensorView gemm1_weights_scale, Optional gemm1_bias, + Optional gemm1_alpha, Optional gemm1_beta, + Optional gemm1_clamp_limit, TensorView gemm2_weights, + TensorView gemm2_weights_scale, Optional gemm2_bias, + Optional output1_scales_scalar, Optional output1_scales_gate_scalar, + Optional output2_scales_scalar, int64_t const num_experts, int64_t const top_k, Optional const n_group, Optional const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset, int64_t const local_num_experts, Optional const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type, bool const do_finalize, tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner& moe_runner, btg::Dtype dtype_act, - btg::Dtype dtype_weights, int64_t const moeConfigIndex, bool enable_pdl, Tensor output) { + btg::Dtype dtype_weights, int64_t const moeConfigIndex, bool enable_pdl, TensorView output) { static const std::tuple device_props = [hidden_states] { int major, minor; cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, @@ -850,7 +853,7 @@ Array trtllm_fp4_block_scale_moe_launcher( << "hidden_states_scale must be fp8."; TVM_FFI_ICHECK_EQ( - get_numel(hidden_states_scale.value()), + hidden_states_scale.value().numel(), tensorrt_llm::computeLinearLayoutSFSize(args.num_tokens, args.hidden_size / sf_vec_size)) << "hidden_states_scale has incorrect size"; } @@ -1015,18 +1018,20 @@ Array trtllm_fp4_block_scale_moe_launcher( return {output}; } -Array trtllm_fp4_block_scale_moe( - Optional routing_logits, Tensor topk_ids, Tensor expert_weights, - Optional routing_bias, Tensor hidden_states, Optional hidden_states_scale, - Tensor gemm1_weights, Tensor gemm1_weights_scale, Optional gemm1_bias, - Optional gemm1_alpha, Optional gemm1_beta, Optional gemm1_clamp_limit, - Tensor gemm2_weights, Tensor gemm2_weights_scale, Optional gemm2_bias, - Optional output1_scales_scalar, Optional output1_scales_gate_scalar, - Optional output2_scales_scalar, int64_t num_experts, int64_t top_k, +Array trtllm_fp4_block_scale_moe( + Optional routing_logits, TensorView topk_ids, TensorView expert_weights, + Optional routing_bias, TensorView hidden_states, + 
Optional hidden_states_scale, TensorView gemm1_weights, + TensorView gemm1_weights_scale, Optional gemm1_bias, + Optional gemm1_alpha, Optional gemm1_beta, + Optional gemm1_clamp_limit, TensorView gemm2_weights, + TensorView gemm2_weights_scale, Optional gemm2_bias, + Optional output1_scales_scalar, Optional output1_scales_gate_scalar, + Optional output2_scales_scalar, int64_t num_experts, int64_t top_k, Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, int64_t tile_tokens_dim, int64_t routing_method_type, bool do_finalize, bool enable_pdl, - int64_t gated_act_type, Tensor output, int64_t config_index) { + int64_t gated_act_type, TensorView output, int64_t config_index) { using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; int const num_tokens = hidden_states->shape[0]; @@ -1034,11 +1039,10 @@ Array trtllm_fp4_block_scale_moe( if (hidden_states->dtype == dl_uint8) hidden_size *= 2; int hidden_states_scale_vec_size = -1; if (hidden_states_scale.has_value()) { - hidden_states_scale_vec_size = - (num_tokens * hidden_size) / get_numel(hidden_states_scale.value()); + hidden_states_scale_vec_size = (num_tokens * hidden_size) / hidden_states_scale.value().numel(); } int weight_scale_vec_size = - (local_num_experts * intermediate_size * 2 * hidden_size) / get_numel(gemm1_weights_scale); + (local_num_experts * intermediate_size * 2 * hidden_size) / gemm1_weights_scale.numel(); TVM_FFI_ICHECK(weight_scale_vec_size == 16 || weight_scale_vec_size == 32) << "unsupported weight_scale_vec_size."; auto mDtypeWeights = weight_scale_vec_size == 16 ? btg::Dtype::E2m1 : btg::Dtype::MxE2m1; diff --git a/csrc/trtllm_gemm_runner.cu b/csrc/trtllm_gemm_runner.cu index c4ec9d5cff..9b644d63fe 100644 --- a/csrc/trtllm_gemm_runner.cu +++ b/csrc/trtllm_gemm_runner.cu @@ -256,8 +256,9 @@ class TrtllmGenGemmRunner { using tvm::ffi::Array; using tvm::ffi::Optional; -void trtllm_gemm(Tensor workspace_buffer, Tensor a, Tensor b, Tensor a_scale, Tensor b_scale, - Optional globalScale, Tensor out, bool use_8x4_sf_layout, int64_t tactic) { +void trtllm_gemm(TensorView workspace_buffer, TensorView a, TensorView b, TensorView a_scale, + TensorView b_scale, Optional globalScale, TensorView out, + bool use_8x4_sf_layout, int64_t tactic) { CHECK_DEVICE(a, b); CHECK_DEVICE(a, out); CHECK_INPUT(a); @@ -309,7 +310,7 @@ void trtllm_gemm(Tensor workspace_buffer, Tensor a, Tensor b, Tensor a_scale, Te int64_t const required_workspace_size = runner.getWorkspaceSizeInBytes(m, n, k, tactic); int64_t const provided_workspace_size = - get_numel(workspace_buffer) * get_element_size(workspace_buffer); + workspace_buffer.numel() * get_element_size(workspace_buffer); if (provided_workspace_size < required_workspace_size) { Tensor new_workspace = alloc_tensor({required_workspace_size}, dl_int8, a->device); runKernel(new_workspace->data); diff --git a/csrc/trtllm_mnnvl_allreduce.cu b/csrc/trtllm_mnnvl_allreduce.cu index f09f87e111..76d11d8624 100644 --- a/csrc/trtllm_mnnvl_allreduce.cu +++ b/csrc/trtllm_mnnvl_allreduce.cu @@ -26,10 +26,10 @@ using tvm::ffi::Optional; } \ }() -void trtllm_mnnvl_all_reduce(Tensor in, int64_t multicast_buffer_ptr, int64_t buffer_ptrs_dev, - int64_t buffer_M, Tensor buffer_flags_mnnvl, int64_t nranks, +void trtllm_mnnvl_all_reduce(TensorView in, int64_t multicast_buffer_ptr, int64_t buffer_ptrs_dev, + int64_t buffer_M, TensorView buffer_flags_mnnvl, int64_t nranks, int64_t rank, bool 
wait_for_results, bool launch_with_pdl, - Optional out) { + Optional out) { cudaSetDevice(in->device.device_id); auto stream = get_stream(in->device); @@ -71,9 +71,9 @@ void trtllm_mnnvl_all_reduce(Tensor in, int64_t multicast_buffer_ptr, int64_t bu }); } -void trtllm_mnnvl_rmsnorm(int64_t multicast_buffer_ptr, Tensor prenorm_output, Tensor normed_output, - Tensor gamma, double epsilon, Tensor residual, Tensor buffer_flags, - bool launch_with_pdl) { +void trtllm_mnnvl_rmsnorm(int64_t multicast_buffer_ptr, TensorView prenorm_output, + TensorView normed_output, TensorView gamma, double epsilon, + TensorView residual, TensorView buffer_flags, bool launch_with_pdl) { cudaSetDevice(prenorm_output->device.device_id); auto stream = get_stream(prenorm_output->device); diff --git a/csrc/trtllm_moe_allreduce_fusion.cu b/csrc/trtllm_moe_allreduce_fusion.cu index 46884282ff..a7ee3fc0c4 100644 --- a/csrc/trtllm_moe_allreduce_fusion.cu +++ b/csrc/trtllm_moe_allreduce_fusion.cu @@ -24,16 +24,14 @@ using tvm::ffi::Optional; } \ }() -void trtllm_moe_allreduce_fusion(int64_t world_size, int64_t world_rank, int64_t token_num, - int64_t hidden_size, Tensor workspace_ptrs, bool launch_with_pdl, - Tensor residual_in, Tensor rms_gamma, double rms_eps, - double scale_factor, int64_t moe_reduction_device_num_experts, - Tensor moe_reduction_scale_input, - Tensor moe_reduction_active_experts_token_input, - Tensor moe_reduction_token_input, Optional layout_code, - Optional moe_allreduce_out, Optional residual_out, - Optional norm_out, Optional quant_out, - Optional scale_out) { +void trtllm_moe_allreduce_fusion( + int64_t world_size, int64_t world_rank, int64_t token_num, int64_t hidden_size, + TensorView workspace_ptrs, bool launch_with_pdl, TensorView residual_in, TensorView rms_gamma, + double rms_eps, double scale_factor, int64_t moe_reduction_device_num_experts, + TensorView moe_reduction_scale_input, TensorView moe_reduction_active_experts_token_input, + TensorView moe_reduction_token_input, Optional layout_code, + Optional moe_allreduce_out, Optional residual_out, + Optional norm_out, Optional quant_out, Optional scale_out) { cudaSetDevice(moe_reduction_active_experts_token_input->device.device_id); auto stream = get_stream(moe_reduction_active_experts_token_input->device); @@ -81,13 +79,12 @@ void trtllm_moe_allreduce_fusion(int64_t world_size, int64_t world_rank, int64_t }); } -void trtllm_moe_finalize_allreduce_fusion(Tensor allreduce_in, Tensor residual_in, - Tensor norm_weight, Tensor expanded_idx_to_permuted_idx, - Tensor norm_out, Tensor residual_out, - bool launch_with_pdl, Tensor workspace, - int64_t const world_rank, int64_t const world_size, - double const eps, Optional shared_expert_output, - Optional expert_scale_factor) { +void trtllm_moe_finalize_allreduce_fusion( + TensorView allreduce_in, TensorView residual_in, TensorView norm_weight, + TensorView expanded_idx_to_permuted_idx, TensorView norm_out, TensorView residual_out, + bool launch_with_pdl, TensorView workspace, int64_t const world_rank, int64_t const world_size, + double const eps, Optional shared_expert_output, + Optional expert_scale_factor) { DISPATCH_FLOATING_TYPES_FOR_ALLREDUCE(residual_in->dtype, c_type, [&] { MoeFinalizeAllReduceFusionParams params; @@ -100,7 +97,7 @@ void trtllm_moe_finalize_allreduce_fusion(Tensor allreduce_in, Tensor residual_i params.nranks = static_cast(world_size); params.rank = static_cast(world_rank); // size: num_token * hidden_dim - params.size = get_numel(residual_in); + params.size = 
residual_in.numel(); params.hidden_dim = hidden_dim; // workspace: AR scratch space diff --git a/csrc/tvm_ffi_utils.h b/csrc/tvm_ffi_utils.h index 99556d2e5d..1c3f7d4952 100644 --- a/csrc/tvm_ffi_utils.h +++ b/csrc/tvm_ffi_utils.h @@ -23,6 +23,7 @@ #include "dlpack/dlpack.h" using tvm::ffi::Tensor; +using tvm::ffi::TensorView; namespace ffi = tvm::ffi; inline constexpr int64_t encode_dlpack_dtype(DLDataType dtype) { @@ -227,6 +228,15 @@ inline void check_shape(const tvm::ffi::Tensor& a, const tvm::ffi::Tensor& b, co } } +inline void check_shape(const tvm::ffi::TensorView& a, const tvm::ffi::TensorView& b, + const char* a_name, const char* b_name) { + TVM_FFI_ICHECK_EQ(a->ndim, b->ndim) << a_name << "->ndim and " << b_name << "->ndim mismatch"; + for (int i = 0; i < a->ndim; ++i) { + TVM_FFI_ICHECK_EQ(a->shape[i], b->shape[i]) + << a_name << "->shape[" << i << "] and " << b_name << "->shape[" << i << "] mismatch"; + } +} + #define CHECK_CUDA(x) \ TVM_FFI_ICHECK_EQ(x->device.device_type, kDLCUDA) << #x " must be a CUDA tensor"; #define CHECK_CPU(x) \ @@ -265,7 +275,7 @@ inline cudaStream_t get_stream(DLDevice device) { inline int64_t get_element_size(ffi::Tensor x) { return (x->dtype.bits * x->dtype.lanes) / 8; } -inline int64_t get_numel(ffi::Tensor x) { return x.shape().Product(); } +inline int64_t get_element_size(ffi::TensorView x) { return (x->dtype.bits * x->dtype.lanes) / 8; } inline ffi::Tensor alloc_tensor(tvm::ffi::Shape shape, DLDataType dtype, DLDevice device) { return ffi::Tensor::FromDLPackAlloc(TVMFFIEnvGetTensorAllocator(), shape, dtype, device); diff --git a/csrc/vllm_custom_all_reduce.cu b/csrc/vllm_custom_all_reduce.cu index 897061038f..14b1425e5c 100644 --- a/csrc/vllm_custom_all_reduce.cu +++ b/csrc/vllm_custom_all_reduce.cu @@ -17,7 +17,7 @@ static_assert(sizeof(void*) == sizeof(fptr_t)); using tvm::ffi::Array; using tvm::ffi::Tuple; -fptr_t init_custom_ar(Array fake_ipc_ptrs, Tensor rank_data, int64_t rank, +fptr_t init_custom_ar(Array fake_ipc_ptrs, TensorView rank_data, int64_t rank, bool full_nvlink) { int world_size = fake_ipc_ptrs.size(); if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported"); @@ -28,7 +28,7 @@ fptr_t init_custom_ar(Array fake_ipc_ptrs, Tensor rank_data, int64_t ran for (int i = 0; i < world_size; i++) { ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); } - return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data->data, get_numel(rank_data), rank, + return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data->data, rank_data.numel(), rank, world_size, full_nvlink); } @@ -48,8 +48,8 @@ fptr_t init_custom_ar(Array fake_ipc_ptrs, Tensor rank_data, int64_t ran * 5. A[None].expand(2, -1, -1, -1): Not OK * 6. A[:, 1:, 1:]: Not OK */ -bool _is_weak_contiguous(Tensor t) { - auto numel = get_numel(t); +bool _is_weak_contiguous(TensorView t) { + auto numel = t.numel(); auto element_size = get_element_size(t); return t.IsContiguous() || (tvm::ffi::GetDataSize(numel, t->dtype) - t->byte_offset * element_size == @@ -63,17 +63,17 @@ bool _is_weak_contiguous(Tensor t) { * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first * copied into _reg_buffer. 
*/ -void all_reduce(fptr_t _fa, Tensor inp, Tensor out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes, - int64_t num_ctas) { +void all_reduce(fptr_t _fa, TensorView inp, TensorView out, fptr_t _reg_buffer, + int64_t reg_buffer_sz_bytes, int64_t num_ctas) { auto fa = reinterpret_cast(_fa); cudaSetDevice(inp->device.device_id); auto stream = get_stream(inp->device); TVM_FFI_ICHECK_EQ(inp->dtype, out->dtype); - TVM_FFI_ICHECK_EQ(get_numel(inp), get_numel(out)); + TVM_FFI_ICHECK_EQ(inp.numel(), out.numel()); TVM_FFI_ICHECK(_is_weak_contiguous(out)); TVM_FFI_ICHECK(_is_weak_contiguous(inp)); - auto input_size = get_numel(inp) * get_element_size(inp); + auto input_size = inp.numel() * get_element_size(inp); auto reg_buffer = reinterpret_cast(_reg_buffer); if (reg_buffer) { TVM_FFI_ICHECK_LE(input_size, reg_buffer_sz_bytes); @@ -86,19 +86,18 @@ void all_reduce(fptr_t _fa, Tensor inp, Tensor out, fptr_t _reg_buffer, int64_t switch (encode_dlpack_dtype(out->dtype)) { case float32_code: { fa->allreduce(stream, reinterpret_cast(reg_buffer), - reinterpret_cast(out->data), get_numel(out), num_ctas); + reinterpret_cast(out->data), out.numel(), num_ctas); break; } case float16_code: { fa->allreduce(stream, reinterpret_cast(reg_buffer), - reinterpret_cast(out->data), get_numel(out), num_ctas); + reinterpret_cast(out->data), out.numel(), num_ctas); break; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) case bfloat16_code: { fa->allreduce(stream, reinterpret_cast(reg_buffer), - reinterpret_cast(out->data), get_numel(out), - num_ctas); + reinterpret_cast(out->data), out.numel(), num_ctas); break; } #endif diff --git a/csrc/xqa/xqa_wrapper.cu b/csrc/xqa/xqa_wrapper.cu index 9f8cb34f4b..dad2ef92a5 100644 --- a/csrc/xqa/xqa_wrapper.cu +++ b/csrc/xqa/xqa_wrapper.cu @@ -18,16 +18,17 @@ #include "mha.h" void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingWinSize, - double qScale, Tensor output, + double qScale, TensorView output, #if LOW_PREC_OUTPUT - Tensor rcpOutScale, + TensorView rcpOutScale, #endif - Tensor q, Tensor attentionSinks, Tensor pool, Tensor kvCachePageList, - int64_t maxSeqLen, Tensor seqLen, int64_t batchSize, Tensor kvCacheScale, + TensorView q, TensorView attentionSinks, TensorView pool, + TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, + int64_t batchSize, TensorView kvCacheScale, #if SPEC_DEC - int64_t qSeqLen, Tensor qCuSeqLens, Tensor mask, + int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask, #endif - Tensor semaphores, Tensor scratch) { + TensorView semaphores, TensorView scratch) { auto stream = get_stream(output->device); float const* attentionSinksPtr = attentionSinks.defined() ? 
reinterpret_cast(attentionSinks->data) : nullptr; diff --git a/flashinfer/deep_gemm.py b/flashinfer/deep_gemm.py index 9dfe4c5709..6341330d2d 100644 --- a/flashinfer/deep_gemm.py +++ b/flashinfer/deep_gemm.py @@ -948,8 +948,9 @@ def load(name: str, code: str) -> SM100FP8GemmRuntime: if cubin_name in RUNTIME_CACHE: return RUNTIME_CACHE[cubin_name] symbol, sha256 = KERNEL_MAP[cubin_name] - get_cubin(ArtifactPath.DEEPGEMM + cubin_name, sha256) - path = FLASHINFER_CUBIN_DIR / f"{ArtifactPath.DEEPGEMM + cubin_name}.cubin" + path = f"{ArtifactPath.DEEPGEMM}/{cubin_name}.cubin" + assert get_cubin(path, sha256) + path = FLASHINFER_CUBIN_DIR / path assert path.exists() RUNTIME_CACHE[cubin_name] = SM100FP8GemmRuntime(str(path), symbol) return RUNTIME_CACHE[cubin_name] @@ -1490,11 +1491,11 @@ def __init__(self, sha256: str): self.indice = None def init_indices(self): - indice_path = ArtifactPath.DEEPGEMM + "kernel_map.json" + indice_path = ArtifactPath.DEEPGEMM + "/kernel_map.json" assert get_cubin(indice_path, self.sha256), ( "cubin kernel map file not found, nor downloaded with matched sha256" ) - path = FLASHINFER_CUBIN_DIR / f"{indice_path}.json" + path = FLASHINFER_CUBIN_DIR / indice_path assert path.exists() with open(path, "r") as f: self.indice = json.load(f) diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index 6ae34ef3ba..ea344a6ef6 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -982,8 +982,9 @@ def forward( # tgv_gemm takes mat1 as weights and mat2 as input tensor # from [m,k]x[k,n]+[n,] to [n,k]x[k,m]+[n,] gemm_fn = module.tgv_gemm - out = gemm_fn(b.t(), a.t(), bias, tactic, pdl) - return out.t() + c = torch.empty((a.shape[0], b.shape[1]), dtype=a.dtype, device=a.device) + gemm_fn(b.t(), a.t(), bias, tactic, c, pdl) + return c.t() return TGVGemmRunner() From 32ec089cb62e6fcadf8eb70a8acecf1154d17a0d Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Thu, 2 Oct 2025 03:44:54 +0000 Subject: [PATCH 02/11] fix --- flashinfer/gemm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index ea344a6ef6..7f2f90f0db 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -982,7 +982,9 @@ def forward( # tgv_gemm takes mat1 as weights and mat2 as input tensor # from [m,k]x[k,n]+[n,] to [n,k]x[k,m]+[n,] gemm_fn = module.tgv_gemm - c = torch.empty((a.shape[0], b.shape[1]), dtype=a.dtype, device=a.device) + c = torch.empty( + (a.shape[0], b.shape[1]), dtype=a.dtype, device=a.device + ) gemm_fn(b.t(), a.t(), bias, tactic, c, pdl) return c.t() From 1378fbfeb7127279950ca69f0195238aa4165cd7 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Thu, 2 Oct 2025 03:49:07 +0000 Subject: [PATCH 03/11] fix --- ...shinfer_cutlass_fused_moe_sm100_binding.cu | 144 +++++++++--------- csrc/trtllm_fused_moe_kernel_launcher.cu | 2 +- 2 files changed, 74 insertions(+), 72 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu index 23e027717b..d96b16ec05 100644 --- a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu +++ b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu @@ -222,12 +222,12 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { mAllProfiles = mKernelRunner->getTactics(); } - void runMoe(Tensor output, Tensor input, Tensor token_selected_experts, - Optional token_final_scales, Tensor fc1_expert_weights, - Optional fc1_expert_biases, Tensor fc2_expert_weights, 
- Optional fc2_expert_biases, Optional> quant_scales, - Optional input_sf, Optional swiglu_alpha, - Optional swiglu_beta, Optional swiglu_limit, int64_t tp_size, + void runMoe(TensorView output, TensorView input, TensorView token_selected_experts, + Optional token_final_scales, TensorView fc1_expert_weights, + Optional fc1_expert_biases, TensorView fc2_expert_weights, + Optional fc2_expert_biases, Optional> quant_scales, + Optional input_sf, Optional swiglu_alpha, + Optional swiglu_beta, Optional swiglu_limit, int64_t tp_size, int64_t tp_rank, int64_t ep_size, int64_t ep_rank, int64_t cluster_size, int64_t cluster_rank, bool enable_alltoall, bool min_latency_mode, Optional> profile_ids, bool enable_pdl) { @@ -390,17 +390,18 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { #endif } - void runMoeMinLantency(Tensor output, Tensor input, Tensor token_selected_experts, - Optional token_final_scales, Tensor fc1_expert_weights, - Optional fc1_expert_biases, Tensor fc2_expert_weights, - Optional fc2_expert_biases, Optional> quant_scales, - Optional input_sf, Optional swiglu_alpha, - Optional swiglu_beta, Optional swiglu_limit, - Tensor num_active_experts_per_node, Tensor experts_to_token_score, - Tensor active_expert_global_ids, int64_t tp_size, int64_t tp_rank, - int64_t ep_size, int64_t ep_rank, int64_t cluster_size, - int64_t cluster_rank, bool enable_alltoall, bool min_latency_mode, - Optional> profile_ids, bool enable_pdl) { + void runMoeMinLantency(TensorView output, TensorView input, TensorView token_selected_experts, + Optional token_final_scales, TensorView fc1_expert_weights, + Optional fc1_expert_biases, TensorView fc2_expert_weights, + Optional fc2_expert_biases, + Optional> quant_scales, Optional input_sf, + Optional swiglu_alpha, Optional swiglu_beta, + Optional swiglu_limit, TensorView num_active_experts_per_node, + TensorView experts_to_token_score, TensorView active_expert_global_ids, + int64_t tp_size, int64_t tp_rank, int64_t ep_size, int64_t ep_rank, + int64_t cluster_size, int64_t cluster_rank, bool enable_alltoall, + bool min_latency_mode, Optional> profile_ids, + bool enable_pdl) { std::lock_guard lock(mMutex); CHECK_INPUT_TYPE(input, mActivationDtype) @@ -564,12 +565,12 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { return mAllProfiles.size(); } - void runGemmProfile(Tensor input, Tensor fc1_expert_weights, Optional fc1_expert_biases, - Tensor fc2_expert_weights, Optional fc2_expert_biases, int64_t top_k, - int64_t tp_size, int64_t tp_rank, int64_t ep_size, int64_t ep_rank, - int64_t cluster_size, int64_t cluster_rank, bool enable_alltoall, - bool min_latency_mode, int64_t gemm_idx, int64_t profile_id, - bool do_preparation, bool enable_pdl) { + void runGemmProfile(TensorView input, TensorView fc1_expert_weights, + Optional fc1_expert_biases, TensorView fc2_expert_weights, + Optional fc2_expert_biases, int64_t top_k, int64_t tp_size, + int64_t tp_rank, int64_t ep_size, int64_t ep_rank, int64_t cluster_size, + int64_t cluster_rank, bool enable_alltoall, bool min_latency_mode, + int64_t gemm_idx, int64_t profile_id, bool do_preparation, bool enable_pdl) { std::lock_guard lock(mMutex); // TODO: support profiling under fp8 block scaling in the future @@ -651,12 +652,12 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { Optional GetFunction(const tvm::ffi::String& name) final { if (name == "run_gemm_profile") { return Function::FromTyped( - [this](Tensor input, Tensor fc1_expert_weights, Optional fc1_expert_biases, - Tensor fc2_expert_weights, Optional 
fc2_expert_biases, int64_t top_k, - int64_t tp_size, int64_t tp_rank, int64_t ep_size, int64_t ep_rank, - int64_t cluster_size, int64_t cluster_rank, bool enable_alltoall, - bool min_latency_mode, int64_t gemm_idx, int64_t profile_id, bool do_preparation, - bool enable_pdl) { + [this](TensorView input, TensorView fc1_expert_weights, + Optional fc1_expert_biases, TensorView fc2_expert_weights, + Optional fc2_expert_biases, int64_t top_k, int64_t tp_size, + int64_t tp_rank, int64_t ep_size, int64_t ep_rank, int64_t cluster_size, + int64_t cluster_rank, bool enable_alltoall, bool min_latency_mode, + int64_t gemm_idx, int64_t profile_id, bool do_preparation, bool enable_pdl) { runGemmProfile(input, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases, top_k, tp_size, tp_rank, ep_size, ep_rank, cluster_size, cluster_rank, enable_alltoall, min_latency_mode, gemm_idx, @@ -666,15 +667,15 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { return Function::FromTyped([this]() -> int64_t { return getTacticNum(); }); } else if (name == "run_moe") { return Function::FromTyped( - [this](Tensor output, Tensor input, Tensor token_selected_experts, - Optional token_final_scales, Tensor fc1_expert_weights, - Optional fc1_expert_biases, Tensor fc2_expert_weights, - Optional fc2_expert_biases, Optional> quant_scales, - Optional input_sf, Optional swiglu_alpha, - Optional swiglu_beta, Optional swiglu_limit, int64_t tp_size, - int64_t tp_rank, int64_t ep_size, int64_t ep_rank, int64_t cluster_size, - int64_t cluster_rank, bool enable_alltoall, bool min_latency_mode, - Optional> profile_ids, bool enable_pdl) { + [this](TensorView output, TensorView input, TensorView token_selected_experts, + Optional token_final_scales, TensorView fc1_expert_weights, + Optional fc1_expert_biases, TensorView fc2_expert_weights, + Optional fc2_expert_biases, Optional> quant_scales, + Optional input_sf, Optional swiglu_alpha, + Optional swiglu_beta, Optional swiglu_limit, + int64_t tp_size, int64_t tp_rank, int64_t ep_size, int64_t ep_rank, + int64_t cluster_size, int64_t cluster_rank, bool enable_alltoall, + bool min_latency_mode, Optional> profile_ids, bool enable_pdl) { runMoe(output, input, token_selected_experts, token_final_scales, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases, quant_scales, input_sf, swiglu_alpha, swiglu_beta, swiglu_limit, tp_size, tp_rank, ep_size, ep_rank, @@ -683,16 +684,17 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { }); } else if (name == "run_moe_min_latency") { return Function::FromTyped( - [this](Tensor output, Tensor input, Tensor token_selected_experts, - Optional token_final_scales, Tensor fc1_expert_weights, - Optional fc1_expert_biases, Tensor fc2_expert_weights, - Optional fc2_expert_biases, Optional> quant_scales, - Optional input_sf, Optional swiglu_alpha, - Optional swiglu_beta, Optional swiglu_limit, - Tensor num_active_experts_per_node, Tensor experts_to_token_score, - Tensor active_expert_global_ids, int64_t tp_size, int64_t tp_rank, int64_t ep_size, - int64_t ep_rank, int64_t cluster_size, int64_t cluster_rank, bool enable_alltoall, - bool min_latency_mode, Optional> profile_ids, bool enable_pdl) { + [this](TensorView output, TensorView input, TensorView token_selected_experts, + Optional token_final_scales, TensorView fc1_expert_weights, + Optional fc1_expert_biases, TensorView fc2_expert_weights, + Optional fc2_expert_biases, Optional> quant_scales, + Optional input_sf, Optional swiglu_alpha, + Optional swiglu_beta, 
Optional swiglu_limit, + TensorView num_active_experts_per_node, TensorView experts_to_token_score, + TensorView active_expert_global_ids, int64_t tp_size, int64_t tp_rank, + int64_t ep_size, int64_t ep_rank, int64_t cluster_size, int64_t cluster_rank, + bool enable_alltoall, bool min_latency_mode, Optional> profile_ids, + bool enable_pdl) { runMoeMinLantency(output, input, token_selected_experts, token_final_scales, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases, quant_scales, input_sf, swiglu_alpha, swiglu_beta, @@ -782,7 +784,7 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { kernels::QuantParams getQuantParams(int64_t num_experts_on_rank, int64_t hidden_size, int64_t inter_size, - Optional> quant_scales) const { + Optional> quant_scales) const { if (isFp8Quant()) { TVM_FFI_ICHECK(quant_scales.has_value()) << "Expecting quant scales for fp8 quantization"; TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 4) @@ -887,10 +889,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 4) "Expecting 4 quant scales for W4A8_MXFP4_MXFP8 quantization"; - Tensor fc1_weight_block = quant_scales.value()[0]; - Tensor fc1_global = quant_scales.value()[1]; - Tensor fc2_weight_block = quant_scales.value()[2]; - Tensor fc2_global = quant_scales.value()[3]; + TensorView fc1_weight_block = quant_scales.value()[0]; + TensorView fc1_global = quant_scales.value()[1]; + TensorView fc2_weight_block = quant_scales.value()[2]; + TensorView fc2_global = quant_scales.value()[3]; // The input for scale fc1_weight_block / fc2_weight_block is packed into INT32 constexpr int FP8_PER_INT32 = 4; @@ -946,12 +948,12 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 6) << "Expecting 6 quant scales for nvfp4 quantization"; - Tensor fc1_act_global = quant_scales.value()[0]; - Tensor fc1_weight_block = quant_scales.value()[1]; - Tensor fc1_global = quant_scales.value()[2]; - Tensor fc2_act_global = quant_scales.value()[3]; - Tensor fc2_weight_block = quant_scales.value()[4]; - Tensor fc2_global = quant_scales.value()[5]; + TensorView fc1_act_global = quant_scales.value()[0]; + TensorView fc1_weight_block = quant_scales.value()[1]; + TensorView fc1_global = quant_scales.value()[2]; + TensorView fc2_act_global = quant_scales.value()[3]; + TensorView fc2_weight_block = quant_scales.value()[4]; + TensorView fc2_global = quant_scales.value()[5]; // The input for scale fc1_weight_block / fc2_weight_block is packed into INT32 constexpr int FP8_PER_INT32 = 4; @@ -1011,8 +1013,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { static_cast(fc2_global->data), fc1_act_global->ndim == 1, fc2_act_global->ndim == 1); } else if (mUseDeepSeekFP8BlockScaling) { - Tensor fc1_scales = quant_scales.value()[0]; - Tensor fc2_scales = quant_scales.value()[1]; + TensorView fc1_scales = quant_scales.value()[0]; + TensorView fc2_scales = quant_scales.value()[1]; return kernels::QuantParams::FP8BlockScaling(static_cast(fc1_scales->data), static_cast(fc2_scales->data)); } else if (isWFP4A16Quant()) { @@ -1020,8 +1022,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 2) << "Expecting 2 quant scales for W4A16 quantization"; - Tensor fc1_weight_scales = quant_scales.value()[0]; - Tensor fc2_weight_scales = quant_scales.value()[1]; + TensorView fc1_weight_scales = quant_scales.value()[0]; + TensorView fc2_weight_scales = quant_scales.value()[1]; int group_size = 
TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::wfp4a16_group_size; return kernels::QuantParams::GroupWise(group_size, static_cast(fc1_weight_scales->data), @@ -1031,14 +1033,14 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { TVM_FFI_ICHECK(quant_scales.has_value()) << "Expecting quant scales for INT4 quantization"; TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 8) << "Expecting 8 quant scales for INT4 quantization"; - Tensor fc1_weight_scales = quant_scales.value()[0]; - Tensor fc2_weight_scales = quant_scales.value()[1]; - Tensor fc1_act_scales = quant_scales.value()[2]; - Tensor fc2_act_scales = quant_scales.value()[3]; - Tensor fc1_weight_zeros = quant_scales.value()[4]; - Tensor fc2_weight_zeros = quant_scales.value()[5]; - Tensor fc1_alpha = quant_scales.value()[6]; - Tensor fc2_alpha = quant_scales.value()[7]; + TensorView fc1_weight_scales = quant_scales.value()[0]; + TensorView fc2_weight_scales = quant_scales.value()[1]; + TensorView fc1_act_scales = quant_scales.value()[2]; + TensorView fc2_act_scales = quant_scales.value()[3]; + TensorView fc1_weight_zeros = quant_scales.value()[4]; + TensorView fc2_weight_zeros = quant_scales.value()[5]; + TensorView fc1_alpha = quant_scales.value()[6]; + TensorView fc2_alpha = quant_scales.value()[7]; int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size; return kernels::QuantParams::GroupWise( group_size, static_cast(fc1_weight_scales->data), diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index e599d658d7..ee5fafe5d3 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -795,7 +795,7 @@ Array trtllm_fp4_block_scale_moe_launcher( dtype_act == btg::Dtype::Bfloat16 ? 
dl_bfloat16 : dl_uint8, hidden_states->device); - Optional gemm1_output_scale = std::nullopt; + Optional gemm1_output_scale = std::nullopt; if (dtype_act == btg::Dtype::E2m1 || dtype_act == btg::Dtype::MxE4m3) { int64_t sf_size = tensorrt_llm::computeSwizzledLayoutSFSize(max_num_padded_tokens, intermediate_size / sf_vec_size); From 66588a7606d81ec46337032b488e1b05e3689566 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Fri, 3 Oct 2025 23:18:45 +0000 Subject: [PATCH 04/11] upd --- csrc/trtllm_fused_moe_kernel_launcher.cu | 44 +++++++++++++----------- flashinfer/fused_moe/core.py | 23 ++++++++----- 2 files changed, 38 insertions(+), 29 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index ee5fafe5d3..cb948f454f 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -37,15 +37,16 @@ using tensorrt_llm::kernels::trtllmgen_moe::Routing::RoutingMethodType; using tvm::ffi::Array; using tvm::ffi::Optional; -TensorView trtllm_fp8_per_tensor_scale_moe_launcher( +void trtllm_fp8_per_tensor_scale_moe_launcher( TensorView routing_logits, Optional routing_bias, TensorView hidden_states, TensorView gemm1_weights, TensorView output1_scales_scalar, TensorView output1_scales_gate_scalar, TensorView gemm2_weights, - TensorView output2_scales_scalar, int64_t const num_experts, int64_t const top_k, - int64_t const n_group, int64_t const topk_group, int64_t const intermediate_size, - int64_t const local_expert_offset, int64_t const local_num_experts, - double const routed_scaling_factor, bool const use_routing_scales_on_input, - int64_t const tile_tokens_dim, int64_t const routing_method_type, bool enable_pdl) { + TensorView output2_scales_scalar, TensorView output, int64_t const num_experts, + int64_t const top_k, int64_t const n_group, int64_t const topk_group, + int64_t const intermediate_size, int64_t const local_expert_offset, + int64_t const local_num_experts, double const routed_scaling_factor, + bool const use_routing_scales_on_input, int64_t const tile_tokens_dim, + int64_t const routing_method_type, bool enable_pdl) { static const std::tuple device_props = [hidden_states] { int major, minor; cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, @@ -224,8 +225,10 @@ TensorView trtllm_fp8_per_tensor_scale_moe_launcher( << "output2_scales_scalar has incorrect dim 0."; // allocate output - Tensor output = - alloc_tensor({args.num_tokens, args.hidden_size}, dl_bfloat16, hidden_states->device); + TVM_FFI_ICHECK_EQ(output->shape[0], args.num_tokens); + TVM_FFI_ICHECK_EQ(output->shape[1], args.hidden_size); + CHECK_INPUT_TYPE(output, dl_bfloat16); + CHECK_DEVICE(output, hidden_states); // setup workspace workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens->data); @@ -272,23 +275,22 @@ TensorView trtllm_fp8_per_tensor_scale_moe_launcher( cudaStream_t moe_stream = get_stream(hidden_states->device); moe_runner.run(args, workspace, hidden_states->device.device_id, moe_stream, moeConfigIndex, enable_pdl); - return output; } -TensorView trtllm_fp8_per_tensor_scale_moe( +void trtllm_fp8_per_tensor_scale_moe( TensorView routing_logits, Optional routing_bias, TensorView hidden_states, TensorView gemm1_weights, TensorView output1_scales_scalar, TensorView output1_scales_gate_scalar, TensorView gemm2_weights, - TensorView output2_scales_scalar, int64_t num_experts, int64_t top_k, int64_t n_group, - int64_t topk_group, int64_t intermediate_size, int64_t local_expert_offset, + 
TensorView output2_scales_scalar, TensorView output, int64_t num_experts, int64_t top_k, + int64_t n_group, int64_t topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, double routed_scaling_factor, bool use_routing_scales_on_input, int64_t tile_tokens_dim, int64_t routing_method_type, bool enable_pdl) { auto dtype = hidden_states->dtype; if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) { - return trtllm_fp8_per_tensor_scale_moe_launcher( + trtllm_fp8_per_tensor_scale_moe_launcher( routing_logits, routing_bias, hidden_states, gemm1_weights, output1_scales_scalar, - output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar, num_experts, top_k, - n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, + output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar, output, num_experts, + top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, use_routing_scales_on_input, tile_tokens_dim, routing_method_type, enable_pdl); } else { @@ -591,7 +593,7 @@ void trtllm_fp8_block_scale_moe(TensorView routing_logits, Optional int64_t moeConfigIndex = mRunner->getDefaultValidConfigIndex( top_k, hidden_size, intermediate_size, local_num_experts, num_tokens); - return trtllm_fp8_block_scale_moe_launcher( + trtllm_fp8_block_scale_moe_launcher( routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights, gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, output, num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, @@ -604,7 +606,7 @@ void trtllm_fp8_block_scale_moe(TensorView routing_logits, Optional // TODO(siyuan): This launcher supports flexible weight and activation types. // We should cleanup other launchers and only use this one in the future. -Array trtllm_fp4_block_scale_moe_launcher( +Array trtllm_fp4_block_scale_moe_launcher( Optional routing_logits, TensorView expert_indices, TensorView expert_weights, Optional routing_bias, TensorView hidden_states, Optional hidden_states_scale, TensorView gemm1_weights, @@ -795,7 +797,7 @@ Array trtllm_fp4_block_scale_moe_launcher( dtype_act == btg::Dtype::Bfloat16 ? 
dl_bfloat16 : dl_uint8, hidden_states->device); - Optional gemm1_output_scale = std::nullopt; + Optional gemm1_output_scale = std::nullopt; if (dtype_act == btg::Dtype::E2m1 || dtype_act == btg::Dtype::MxE4m3) { int64_t sf_size = tensorrt_llm::computeSwizzledLayoutSFSize(max_num_padded_tokens, intermediate_size / sf_vec_size); @@ -1013,12 +1015,12 @@ Array trtllm_fp4_block_scale_moe_launcher( enable_pdl); if (!do_finalize) { - return {gemm2_output, expert_weights, expanded_idx_to_permuted_idx}; + return {gemm2_output, expanded_idx_to_permuted_idx}; } - return {output}; + return {}; } -Array trtllm_fp4_block_scale_moe( +Array trtllm_fp4_block_scale_moe( Optional routing_logits, TensorView topk_ids, TensorView expert_weights, Optional routing_bias, TensorView hidden_states, Optional hidden_states_scale, TensorView gemm1_weights, diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index e98e47d2a0..771de4673f 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1298,8 +1298,11 @@ def trtllm_fp8_per_tensor_scale_moe_op( ) -> torch.Tensor: if enable_pdl is None: enable_pdl = device_support_pdl(hidden_states.device) + output = torch.empty( + hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device + ) # Call the C++ function - output = moe_op.trtllm_fp8_per_tensor_scale_moe( + moe_op.trtllm_fp8_per_tensor_scale_moe( routing_logits, routing_bias, hidden_states, @@ -1308,6 +1311,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar, + output, num_experts, top_k, n_group, @@ -1576,7 +1580,7 @@ def trtllm_fp4_block_scale_moe_op( ) # Call the C++ function for block scale MoE - output = moe_op.trtllm_fp4_block_scale_moe( + intermediate_output = moe_op.trtllm_fp4_block_scale_moe( routing_logits, topk_ids, expert_weights, @@ -1611,12 +1615,15 @@ def trtllm_fp4_block_scale_moe_op( output, tactic, ) - if isinstance(output, tvm_ffi.Array): - output = list(output) - for i in range(len(output)): - if isinstance(output[i], tvm_ffi.Tensor): - output[i] = torch.from_dlpack(output[i]) - return output + if do_finalize: + return [output] + else: + gemm2_output, expanded_idx_to_permuted_idx = intermediate_output + return [ + torch.from_dlpack(gemm2_output), + expert_weights, + torch.from_dlpack(expanded_idx_to_permuted_idx), + ] @register_fake_op("flashinfer::trtllm_fp4_block_scale_moe") def _fake_trtllm_fp4_block_scale_moe( From afa7c570341e13f993cfda8c5087b8ea84b69957 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Fri, 3 Oct 2025 23:44:28 +0000 Subject: [PATCH 05/11] fix --- flashinfer/fused_moe/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 771de4673f..2045e43a18 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -20,7 +20,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch -import tvm_ffi from ..artifacts import ArtifactPath, MetaInfoHash from ..autotuner import ( From 8d9e5ba9602cd0b407c733f3de33eae7300cec9c Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sat, 4 Oct 2025 03:16:01 +0000 Subject: [PATCH 06/11] fix --- .../flashinfer_cutlass_fused_moe_sm100_binding.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu index d96b16ec05..717959392f 100644 --- 
a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu +++ b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu @@ -225,7 +225,7 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { void runMoe(TensorView output, TensorView input, TensorView token_selected_experts, Optional token_final_scales, TensorView fc1_expert_weights, Optional fc1_expert_biases, TensorView fc2_expert_weights, - Optional fc2_expert_biases, Optional> quant_scales, + Optional fc2_expert_biases, Optional> quant_scales, Optional input_sf, Optional swiglu_alpha, Optional swiglu_beta, Optional swiglu_limit, int64_t tp_size, int64_t tp_rank, int64_t ep_size, int64_t ep_rank, int64_t cluster_size, @@ -394,7 +394,7 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { Optional token_final_scales, TensorView fc1_expert_weights, Optional fc1_expert_biases, TensorView fc2_expert_weights, Optional fc2_expert_biases, - Optional> quant_scales, Optional input_sf, + Optional> quant_scales, Optional input_sf, Optional swiglu_alpha, Optional swiglu_beta, Optional swiglu_limit, TensorView num_active_experts_per_node, TensorView experts_to_token_score, TensorView active_expert_global_ids, @@ -670,7 +670,7 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { [this](TensorView output, TensorView input, TensorView token_selected_experts, Optional token_final_scales, TensorView fc1_expert_weights, Optional fc1_expert_biases, TensorView fc2_expert_weights, - Optional fc2_expert_biases, Optional> quant_scales, + Optional fc2_expert_biases, Optional> quant_scales, Optional input_sf, Optional swiglu_alpha, Optional swiglu_beta, Optional swiglu_limit, int64_t tp_size, int64_t tp_rank, int64_t ep_size, int64_t ep_rank, @@ -687,7 +687,7 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { [this](TensorView output, TensorView input, TensorView token_selected_experts, Optional token_final_scales, TensorView fc1_expert_weights, Optional fc1_expert_biases, TensorView fc2_expert_weights, - Optional fc2_expert_biases, Optional> quant_scales, + Optional fc2_expert_biases, Optional> quant_scales, Optional input_sf, Optional swiglu_alpha, Optional swiglu_beta, Optional swiglu_limit, TensorView num_active_experts_per_node, TensorView experts_to_token_score, @@ -784,7 +784,7 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { kernels::QuantParams getQuantParams(int64_t num_experts_on_rank, int64_t hidden_size, int64_t inter_size, - Optional> quant_scales) const { + Optional> quant_scales) const { if (isFp8Quant()) { TVM_FFI_ICHECK(quant_scales.has_value()) << "Expecting quant scales for fp8 quantization"; TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 4) From 1ed75be343cb244ee03e4fa499775804422abbfc Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sat, 4 Oct 2025 07:15:41 +0000 Subject: [PATCH 07/11] fix --- flashinfer/jit/activation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flashinfer/jit/activation.py b/flashinfer/jit/activation.py index a2fe27e944..4407ad146f 100644 --- a/flashinfer/jit/activation.py +++ b/flashinfer/jit/activation.py @@ -33,9 +33,9 @@ {{ act_func_def }} -void {{ func_name }}(Tensor out, Tensor input, bool enable_pdl) { +void {{ func_name }}(TensorView out, TensorView input, bool enable_pdl) { int d = input->shape[input->ndim -1] / 2; - int64_t num_tokens = get_numel(input) / input->shape[input->ndim -1]; + int64_t num_tokens = input.numel() / input->shape[input->ndim -1]; dim3 grid(num_tokens); cudaSetDevice(out->device.device_id); From 
a4117e929631fec941a8b4d4e87252f16d32580b Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sat, 4 Oct 2025 08:02:35 +0000 Subject: [PATCH 08/11] fix --- csrc/xqa/xqa_wrapper.cu | 4 +++- flashinfer/xqa.py | 7 ++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/csrc/xqa/xqa_wrapper.cu b/csrc/xqa/xqa_wrapper.cu index dad2ef92a5..e5a943ae9d 100644 --- a/csrc/xqa/xqa_wrapper.cu +++ b/csrc/xqa/xqa_wrapper.cu @@ -17,12 +17,14 @@ #include "../tvm_ffi_utils.h" #include "mha.h" +using tvm::ffi::Optional; + void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingWinSize, double qScale, TensorView output, #if LOW_PREC_OUTPUT TensorView rcpOutScale, #endif - TensorView q, TensorView attentionSinks, TensorView pool, + TensorView q, Optional<TensorView> attentionSinks, TensorView pool, TensorView kvCachePageList, int64_t maxSeqLen, TensorView seqLen, int64_t batchSize, TensorView kvCacheScale, #if SPEC_DEC diff --git a/flashinfer/xqa.py b/flashinfer/xqa.py index 4e4afdd844..726002741e 100644 --- a/flashinfer/xqa.py +++ b/flashinfer/xqa.py @@ -16,6 +16,7 @@ import functools from types import SimpleNamespace +from typing import Optional import torch @@ -50,7 +51,7 @@ def xqa( qScale: float, output: torch.Tensor, q: torch.Tensor, - attentionSinks: torch.Tensor, + attentionSinks: Optional[torch.Tensor], pool: torch.Tensor, kvCachePageList: torch.Tensor, maxSeqLen: int, @@ -88,7 +89,7 @@ def _fake_xqa( qScale: float, output: torch.Tensor, q: torch.Tensor, - attentionSinks: torch.Tensor, + attentionSinks: Optional[torch.Tensor], pool: torch.Tensor, kvCachePageList: torch.Tensor, maxSeqLen: int, @@ -117,7 +118,7 @@ def xqa( qScale: float, output: torch.Tensor, q: torch.Tensor, - attentionSinks: torch.Tensor, + attentionSinks: Optional[torch.Tensor], pool: torch.Tensor, kvCachePageList: torch.Tensor, maxSeqLen: int, From d53403f3cd9905e28983b23a174b2aae7879da38 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sat, 4 Oct 2025 08:44:15 +0000 Subject: [PATCH 09/11] fix --- csrc/xqa/xqa_wrapper.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/xqa/xqa_wrapper.cu b/csrc/xqa/xqa_wrapper.cu index e5a943ae9d..03bb48412c 100644 --- a/csrc/xqa/xqa_wrapper.cu +++ b/csrc/xqa/xqa_wrapper.cu @@ -33,7 +33,7 @@ void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingW TensorView semaphores, TensorView scratch) { auto stream = get_stream(output->device); float const* attentionSinksPtr = - attentionSinks.defined() ? reinterpret_cast<float const*>(attentionSinks->data) : nullptr; + attentionSinks.has_value() ? reinterpret_cast<float const*>(attentionSinks->data) : nullptr; launchMHAFlashInfer(multiProcessorCount, nbKHeads, slidingWinSize, qScale, reinterpret_cast(output->data), From bd1d5892c8a3724348889842d0a7192025b4f718 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sat, 4 Oct 2025 04:22:03 -0700 Subject: [PATCH 10/11] Update xqa_wrapper.cu --- csrc/xqa/xqa_wrapper.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/xqa/xqa_wrapper.cu b/csrc/xqa/xqa_wrapper.cu index 03bb48412c..eaf9fcc361 100644 --- a/csrc/xqa/xqa_wrapper.cu +++ b/csrc/xqa/xqa_wrapper.cu @@ -33,7 +33,7 @@ void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingW TensorView semaphores, TensorView scratch) { auto stream = get_stream(output->device); float const* attentionSinksPtr = - attentionSinks.has_value() ? reinterpret_cast<float const*>(attentionSinks->data) : nullptr; + attentionSinks.has_value() ?
reinterpret_cast<float const*>(attentionSinks.value()->data) : nullptr; launchMHAFlashInfer(multiProcessorCount, nbKHeads, slidingWinSize, qScale, reinterpret_cast(output->data), From 15a100489b60dd33d90de70a4e407c7978673300 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sun, 5 Oct 2025 01:49:58 +0000 Subject: [PATCH 11/11] fix --- csrc/xqa/xqa_wrapper.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/xqa/xqa_wrapper.cu b/csrc/xqa/xqa_wrapper.cu index eaf9fcc361..1a5d636e10 100644 --- a/csrc/xqa/xqa_wrapper.cu +++ b/csrc/xqa/xqa_wrapper.cu @@ -33,7 +33,8 @@ void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingW TensorView semaphores, TensorView scratch) { auto stream = get_stream(output->device); float const* attentionSinksPtr = - attentionSinks.has_value() ? reinterpret_cast<float const*>(attentionSinks.value()->data) : nullptr; + attentionSinks.has_value() ? reinterpret_cast<float const*>(attentionSinks.value()->data) + : nullptr; launchMHAFlashInfer(multiProcessorCount, nbKHeads, slidingWinSize, qScale, reinterpret_cast(output->data),
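
Editor's note on the last four commits ([PATCH 08/11] through [PATCH 11/11]): they all revolve around one change — attentionSinks becomes a tvm::ffi::Optional on the C++ side (and Optional[torch.Tensor] in flashinfer/xqa.py), and xqa_wrapper turns it into a nullable float const* before calling launchMHAFlashInfer. The sketch below is not FlashInfer code; it is a minimal, self-contained illustration of that null-propagation pattern, with std::optional and a hypothetical FakeTensor struct standing in for tvm::ffi::Optional<TensorView> and the FFI tensor handle.

// Illustrative sketch only (assumed names: FakeTensor, attention_sinks_ptr).
// std::optional stands in for tvm::ffi::Optional<TensorView>; the point is
// that an absent optional becomes nullptr instead of being dereferenced.
#include <cstdio>
#include <optional>
#include <vector>

struct FakeTensor {
  void* data;  // stand-in for the tensor handle's raw data pointer
};

float const* attention_sinks_ptr(std::optional<FakeTensor> const& attentionSinks) {
  // Mirrors the guard in xqa_wrapper.cu: has_value() ? cast the data : nullptr.
  return attentionSinks.has_value()
             ? reinterpret_cast<float const*>(attentionSinks.value().data)
             : nullptr;
}

int main() {
  std::vector<float> sinks{0.25f, 0.5f};
  FakeTensor t{sinks.data()};
  std::printf("with sinks:    %p\n", (void*)attention_sinks_ptr(t));             // non-null
  std::printf("without sinks: %p\n", (void*)attention_sinks_ptr(std::nullopt));  // null
  return 0;
}

The only behavioral point is that callers may now pass no sinks at all (None on the Python side), and the wrapper must translate that absence into a null pointer for the kernel launcher rather than touching an empty optional, which is exactly what the .has_value() guard in the patch ensures.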