From 816d107d0fd6fd0f02567673c097501abf7ce1b8 Mon Sep 17 00:00:00 2001 From: Lequn Chen Date: Tue, 29 Oct 2024 19:13:21 +0000 Subject: [PATCH 1/7] refactor JIT and AOT setup.py --- .gitignore | 7 +- docs/installation.rst | 15 +- flashinfer-aot/3rdparty | 1 - flashinfer-aot/MANIFEST.in | 12 - flashinfer-aot/csrc | 1 - .../csrc_aot/flashinfer_ops_decode.cu | 45 ---- .../csrc_aot/flashinfer_ops_prefill.cu | 56 ----- .../csrc_aot/flashinfer_sm90_ops.cu | 26 -- flashinfer-aot/flashinfer | 1 - flashinfer-aot/include | 1 - flashinfer-aot/version.txt | 1 - include/flashinfer/attention/scheduler.cuh | 14 +- python/3rdparty | 1 - python/MANIFEST.in | 12 - python/_aot_build_utils/__init__.py | 0 .../generate_batch_paged_decode_inst.py | 9 +- .../generate_batch_paged_prefill_inst.py | 12 +- .../generate_batch_ragged_prefill_inst.py | 11 +- .../generate_dispatch_inc.py | 5 +- .../generate_single_decode_inst.py | 9 +- .../generate_single_prefill_inst.py | 9 +- .../_aot_build_utils}/literal_map.py | 0 python/aot_MANIFEST.in | 13 + .../setup.py => python/aot_setup.py | 188 +++++++------- .../csrc_aot/activation.cu | 0 .../csrc_aot/batch_decode.cu | 0 .../csrc_aot/batch_prefill.cu | 0 .../csrc_aot/flashinfer_ops.cu | 234 +++++++++++++----- .../csrc_aot/pytorch_extension_utils.h | 0 .../csrc_aot/single_decode.cu | 0 .../csrc_aot/single_prefill.cu | 0 python/flashinfer/decode.py | 10 +- python/flashinfer/gemm.py | 4 +- python/flashinfer/jit/env.py | 10 +- python/flashinfer/prefill.py | 16 +- python/include | 1 - python/jit_MANIFEST.in | 15 ++ python/setup.py | 44 +++- python/version.txt | 1 - scripts/run-ci-build-wheel.sh | 3 +- 40 files changed, 406 insertions(+), 381 deletions(-) delete mode 120000 flashinfer-aot/3rdparty delete mode 100644 flashinfer-aot/MANIFEST.in delete mode 120000 flashinfer-aot/csrc delete mode 100644 flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu delete mode 100644 flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu delete mode 100644 flashinfer-aot/csrc_aot/flashinfer_sm90_ops.cu delete mode 120000 flashinfer-aot/flashinfer delete mode 120000 flashinfer-aot/include delete mode 120000 flashinfer-aot/version.txt delete mode 120000 python/3rdparty delete mode 100644 python/MANIFEST.in create mode 100644 python/_aot_build_utils/__init__.py rename {flashinfer-aot => python/_aot_build_utils}/generate_batch_paged_decode_inst.py (98%) rename {flashinfer-aot => python/_aot_build_utils}/generate_batch_paged_prefill_inst.py (98%) rename {flashinfer-aot => python/_aot_build_utils}/generate_batch_ragged_prefill_inst.py (99%) rename {flashinfer-aot => python/_aot_build_utils}/generate_dispatch_inc.py (99%) rename {flashinfer-aot => python/_aot_build_utils}/generate_single_decode_inst.py (98%) rename {flashinfer-aot => python/_aot_build_utils}/generate_single_prefill_inst.py (98%) rename {flashinfer-aot => python/_aot_build_utils}/literal_map.py (100%) create mode 100644 python/aot_MANIFEST.in rename flashinfer-aot/setup.py => python/aot_setup.py (80%) rename {flashinfer-aot => python}/csrc_aot/activation.cu (100%) rename {flashinfer-aot => python}/csrc_aot/batch_decode.cu (100%) rename {flashinfer-aot => python}/csrc_aot/batch_prefill.cu (100%) rename {flashinfer-aot => python}/csrc_aot/flashinfer_ops.cu (62%) rename {flashinfer-aot => python}/csrc_aot/pytorch_extension_utils.h (100%) rename {flashinfer-aot => python}/csrc_aot/single_decode.cu (100%) rename {flashinfer-aot => python}/csrc_aot/single_prefill.cu (100%) delete mode 120000 python/include create mode 100644 python/jit_MANIFEST.in 
delete mode 120000 python/version.txt diff --git a/.gitignore b/.gitignore index 14efeef1..fa13a77b 100644 --- a/.gitignore +++ b/.gitignore @@ -13,7 +13,12 @@ src/generated/ python/csrc/generated/ python/flashinfer/_build_meta.py python/flashinfer/jit/aot_config.py -flashinfer-aot/csrc_aot/generated/ +python/csrc_aot/generated/ + +# Package files +python/flashinfer/data/ +python/flashinfer/version.txt +python/MANIFEST.in # Generated documentation files docs/generated diff --git a/docs/installation.rst b/docs/installation.rst index 35bf84a1..4a423305 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -138,7 +138,7 @@ You can follow the steps below to install FlashInfer from source code: pip install ninja -4. Compile FlashInfer: +4. Install FlashInfer: .. tabs:: @@ -153,8 +153,17 @@ You can follow the steps below to install FlashInfer from source code: .. code-block:: bash - cd flashinfer/flashinfer-aot - pip install -e . -v + cd flashinfer/python + TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" python3 aot_setup.py bdist_wheel + pip install dist/flashinfer-*.whl + + .. tab:: Create sdist for JIT mode + + .. code-block:: bash + + cd flashinfer/python + python -m build --sdist + ls -la dist/ C++ API ------- diff --git a/flashinfer-aot/3rdparty b/flashinfer-aot/3rdparty deleted file mode 120000 index 303a6484..00000000 --- a/flashinfer-aot/3rdparty +++ /dev/null @@ -1 +0,0 @@ -../3rdparty \ No newline at end of file diff --git a/flashinfer-aot/MANIFEST.in b/flashinfer-aot/MANIFEST.in deleted file mode 100644 index b20747fe..00000000 --- a/flashinfer-aot/MANIFEST.in +++ /dev/null @@ -1,12 +0,0 @@ -# sdist & wheel -include version.txt -recursive-include include * -recursive-include csrc * -recursive-include 3rdparty/cutlass * - -# wheel-only -exclude flashinfer/_build_meta.py - -# Unneeded files -prune */__pycache__ -global-exclude *.so diff --git a/flashinfer-aot/csrc b/flashinfer-aot/csrc deleted file mode 120000 index bf562722..00000000 --- a/flashinfer-aot/csrc +++ /dev/null @@ -1 +0,0 @@ -../python/csrc \ No newline at end of file diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu b/flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu deleted file mode 100644 index fe665a1d..00000000 --- a/flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include - -torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, - torch::Tensor tmp, - std::optional alibi_slopes, - unsigned int layout, int window_left, - float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta); - -std::vector BatchDecodeWithPagedKVCachePlan( - bool use_logits_soft_cap, unsigned int head_dim, torch::Tensor empty_q_data, - torch::Tensor empty_kv_data, torch::Tensor float_workspace_buffer, - torch::Tensor int_workspace_buffer, torch::Tensor page_locked_int_workspace_buffer, - torch::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads, - unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph); - -torch::Tensor BatchDecodeWithPagedKVCacheRun( - torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, - std::vector plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache, - torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, std::optional alibi_slopes, - unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, std::optional maybe_lse); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache, - "Single-request decode with KV-Cache operator"); - m.def("batch_decode_with_paged_kv_cache_plan", &BatchDecodeWithPagedKVCachePlan); - m.def("batch_decode_with_paged_kv_cache_run", &BatchDecodeWithPagedKVCacheRun); -} diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu b/flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu deleted file mode 100644 index 2f353d02..00000000 --- a/flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include - -torch::Tensor single_prefill_with_kv_cache( - unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v, - std::optional maybe_packed_custom_mask, torch::Tensor tmp, - std::optional maybe_alibi_slopes, unsigned int layout, int32_t window_left, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - std::optional maybe_lse); - -std::vector BatchPrefillWithKVCachePlan( - unsigned int head_dim, torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, - torch::Tensor page_locked_int_workspace_buffer, torch::Tensor qo_indptr, - torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads, - unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph); - -torch::Tensor BatchPrefillWithRaggedKVCacheRun( - unsigned int mask_mode_code, torch::Tensor float_workspace_buffer, - torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, - torch::Tensor k, torch::Tensor v, std::optional maybe_custom_mask, - std::optional maybe_alibi_slopes, torch::Tensor qo_indptr, - torch::Tensor kv_indptr, std::optional maybe_qk_indptr, unsigned int layout, - int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - std::optional maybe_lse); - -torch::Tensor BatchPrefillWithPagedKVCacheRun( - unsigned int mask_mode_code, torch::Tensor float_workspace_buffer, - torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, - torch::Tensor paged_k_cache, torch::Tensor paged_v_cache, - std::optional maybe_custom_mask, std::optional maybe_alibi_slopes, - torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, std::optional maybe_qk_indptr, - unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, std::optional maybe_lse); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache, - "Single-request prefill attention with KV-Cache operator"); - m.def("batch_prefill_with_kv_cache_plan", &BatchPrefillWithKVCachePlan); - m.def("batch_prefill_with_ragged_kv_cache_run", &BatchPrefillWithRaggedKVCacheRun); - m.def("batch_prefill_with_paged_kv_cache_run", &BatchPrefillWithPagedKVCacheRun); -} diff --git a/flashinfer-aot/csrc_aot/flashinfer_sm90_ops.cu b/flashinfer-aot/csrc_aot/flashinfer_sm90_ops.cu deleted file mode 100644 index 5140982f..00000000 --- a/flashinfer-aot/csrc_aot/flashinfer_sm90_ops.cu +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include - - -torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr, - torch::Tensor weight_indices, torch::Tensor x, - torch::Tensor weight, unsigned int batch_size, - bool weight_column_major); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90"); -} \ No newline at end of file diff --git a/flashinfer-aot/flashinfer b/flashinfer-aot/flashinfer deleted file mode 120000 index c5f9b1c7..00000000 --- a/flashinfer-aot/flashinfer +++ /dev/null @@ -1 +0,0 @@ -../python/flashinfer \ No newline at end of file diff --git a/flashinfer-aot/include b/flashinfer-aot/include deleted file mode 120000 index f5030fe8..00000000 --- a/flashinfer-aot/include +++ /dev/null @@ -1 +0,0 @@ -../include \ No newline at end of file diff --git a/flashinfer-aot/version.txt b/flashinfer-aot/version.txt deleted file mode 120000 index aa4e5bec..00000000 --- a/flashinfer-aot/version.txt +++ /dev/null @@ -1 +0,0 @@ -../version.txt \ No newline at end of file diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 423c989f..ecafee1e 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -50,7 +50,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ * the new batch size after the partition. */ template -auto PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( +inline auto PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( const uint32_t max_grid_size, const uint32_t num_kv_heads, const std::vector& num_pages, const uint32_t min_num_pages_per_batch = 1) { uint32_t low = min_num_pages_per_batch, high = 0; @@ -77,7 +77,7 @@ auto PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( return std::make_tuple(low, new_batch_size); } -auto PrefillBinarySearchKVChunkSize(const uint32_t max_batch_size_if_split, +inline auto PrefillBinarySearchKVChunkSize(const uint32_t max_batch_size_if_split, const std::vector& packed_qo_len_arr, const std::vector& kv_len_arr, const uint32_t qo_chunk_size, @@ -129,7 +129,7 @@ auto PrefillBinarySearchKVChunkSize(const uint32_t max_batch_size_if_split, */ template -cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( +inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( bool& split_kv, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch, uint32_t& new_batch_size, uint32_t batch_size, typename AttentionVariant::IdType* kv_indptr_h, const uint32_t num_qo_heads, const uint32_t page_size, bool enable_cuda_graph, @@ -201,7 +201,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( * \return status Indicates whether CUDA calls are successful */ template -auto DecodeSplitKVIndptr(IdType* indptr_h, uint32_t batch_size, uint32_t kv_chunk_size) { +inline auto DecodeSplitKVIndptr(IdType* indptr_h, uint32_t batch_size, uint32_t kv_chunk_size) { std::vector request_indices, kv_tile_indices, o_indptr; o_indptr.push_back(0); @@ -277,7 +277,7 @@ struct DecodePlanInfo { }; template -cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, +inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, void* page_locked_int_buffer, size_t int_workspace_size_in_bytes, DecodePlanInfo& plan_info, typename AttentionVariant::IdType* indptr_h, uint32_t batch_size, uint32_t 
num_qo_heads, uint32_t num_kv_heads, @@ -350,7 +350,7 @@ cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in_bytes, } template -auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size, +inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size, uint32_t max_batch_size_if_split, bool enable_cuda_graph) { @@ -520,7 +520,7 @@ struct PrefillPlanInfo { }; template -cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, +inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, void* page_locked_int_buffer, size_t int_workspace_size_in_bytes, PrefillPlanInfo& plan_info, IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, diff --git a/python/3rdparty b/python/3rdparty deleted file mode 120000 index 303a6484..00000000 --- a/python/3rdparty +++ /dev/null @@ -1 +0,0 @@ -../3rdparty \ No newline at end of file diff --git a/python/MANIFEST.in b/python/MANIFEST.in deleted file mode 100644 index b20747fe..00000000 --- a/python/MANIFEST.in +++ /dev/null @@ -1,12 +0,0 @@ -# sdist & wheel -include version.txt -recursive-include include * -recursive-include csrc * -recursive-include 3rdparty/cutlass * - -# wheel-only -exclude flashinfer/_build_meta.py - -# Unneeded files -prune */__pycache__ -global-exclude *.so diff --git a/python/_aot_build_utils/__init__.py b/python/_aot_build_utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/flashinfer-aot/generate_batch_paged_decode_inst.py b/python/_aot_build_utils/generate_batch_paged_decode_inst.py similarity index 98% rename from flashinfer-aot/generate_batch_paged_decode_inst.py rename to python/_aot_build_utils/generate_batch_paged_decode_inst.py index efd1945b..7808c33b 100644 --- a/flashinfer-aot/generate_batch_paged_decode_inst.py +++ b/python/_aot_build_utils/generate_batch_paged_decode_inst.py @@ -14,14 +14,15 @@ limitations under the License. """ -import sys import re -from literal_map import ( - pos_encoding_mode_literal, +import sys +from pathlib import Path + +from .literal_map import ( dtype_literal, idtype_literal, + pos_encoding_mode_literal, ) -from pathlib import Path def get_cu_file_str( diff --git a/flashinfer-aot/generate_batch_paged_prefill_inst.py b/python/_aot_build_utils/generate_batch_paged_prefill_inst.py similarity index 98% rename from flashinfer-aot/generate_batch_paged_prefill_inst.py rename to python/_aot_build_utils/generate_batch_paged_prefill_inst.py index 21328aae..97f1423a 100644 --- a/flashinfer-aot/generate_batch_paged_prefill_inst.py +++ b/python/_aot_build_utils/generate_batch_paged_prefill_inst.py @@ -14,16 +14,16 @@ limitations under the License. 
""" -import sys import re -import itertools -from literal_map import ( - mask_mode_literal, - pos_encoding_mode_literal, +import sys +from pathlib import Path + +from .literal_map import ( dtype_literal, idtype_literal, + mask_mode_literal, + pos_encoding_mode_literal, ) -from pathlib import Path def get_cu_file_str( diff --git a/flashinfer-aot/generate_batch_ragged_prefill_inst.py b/python/_aot_build_utils/generate_batch_ragged_prefill_inst.py similarity index 99% rename from flashinfer-aot/generate_batch_ragged_prefill_inst.py rename to python/_aot_build_utils/generate_batch_ragged_prefill_inst.py index 59acc67b..f5631303 100644 --- a/flashinfer-aot/generate_batch_ragged_prefill_inst.py +++ b/python/_aot_build_utils/generate_batch_ragged_prefill_inst.py @@ -14,15 +14,16 @@ limitations under the License. """ -import sys import re -from literal_map import ( - mask_mode_literal, - pos_encoding_mode_literal, +import sys +from pathlib import Path + +from .literal_map import ( dtype_literal, idtype_literal, + mask_mode_literal, + pos_encoding_mode_literal, ) -from pathlib import Path def get_cu_file_str( diff --git a/flashinfer-aot/generate_dispatch_inc.py b/python/_aot_build_utils/generate_dispatch_inc.py similarity index 99% rename from flashinfer-aot/generate_dispatch_inc.py rename to python/_aot_build_utils/generate_dispatch_inc.py index f3ad9db8..30552e6e 100644 --- a/flashinfer-aot/generate_dispatch_inc.py +++ b/python/_aot_build_utils/generate_dispatch_inc.py @@ -16,10 +16,11 @@ import argparse from pathlib import Path -from literal_map import ( - pos_encoding_mode_literal, + +from .literal_map import ( bool_literal, mask_mode_literal, + pos_encoding_mode_literal, ) diff --git a/flashinfer-aot/generate_single_decode_inst.py b/python/_aot_build_utils/generate_single_decode_inst.py similarity index 98% rename from flashinfer-aot/generate_single_decode_inst.py rename to python/_aot_build_utils/generate_single_decode_inst.py index 754e185f..ce24d7e7 100644 --- a/flashinfer-aot/generate_single_decode_inst.py +++ b/python/_aot_build_utils/generate_single_decode_inst.py @@ -14,13 +14,14 @@ limitations under the License. """ -import sys import re -from literal_map import ( - pos_encoding_mode_literal, +import sys +from pathlib import Path + +from .literal_map import ( dtype_literal, + pos_encoding_mode_literal, ) -from pathlib import Path def get_cu_file_str( diff --git a/flashinfer-aot/generate_single_prefill_inst.py b/python/_aot_build_utils/generate_single_prefill_inst.py similarity index 98% rename from flashinfer-aot/generate_single_prefill_inst.py rename to python/_aot_build_utils/generate_single_prefill_inst.py index eb54ed4e..49eefd17 100644 --- a/flashinfer-aot/generate_single_prefill_inst.py +++ b/python/_aot_build_utils/generate_single_prefill_inst.py @@ -14,14 +14,15 @@ limitations under the License. 
""" -import sys import re -from literal_map import ( - pos_encoding_mode_literal, +import sys +from pathlib import Path + +from .literal_map import ( dtype_literal, mask_mode_literal, + pos_encoding_mode_literal, ) -from pathlib import Path def get_cu_file_str( diff --git a/flashinfer-aot/literal_map.py b/python/_aot_build_utils/literal_map.py similarity index 100% rename from flashinfer-aot/literal_map.py rename to python/_aot_build_utils/literal_map.py diff --git a/python/aot_MANIFEST.in b/python/aot_MANIFEST.in new file mode 100644 index 00000000..5819e735 --- /dev/null +++ b/python/aot_MANIFEST.in @@ -0,0 +1,13 @@ +# MANIFEST.in for AOT + +prune */__pycache__ +prune csrc +prune csrc_aot +exclude aot_setup.py +exclude setup.py + +include flashinfer/data/version.txt +graft flashinfer/data/csrc +graft flashinfer/data/include +graft flashinfer/data/cutlass/include +graft flashinfer/data/cutlass/tools/util/include diff --git a/flashinfer-aot/setup.py b/python/aot_setup.py similarity index 80% rename from flashinfer-aot/setup.py rename to python/aot_setup.py index 80fd4ea9..6dd3a8f4 100644 --- a/flashinfer-aot/setup.py +++ b/python/aot_setup.py @@ -14,33 +14,34 @@ limitations under the License. """ -from typing import List, Tuple - -import copy -import pathlib +import argparse +import contextlib +import itertools import os +import pathlib +import platform import re -import itertools +import shutil import subprocess -import platform +import sys +from typing import Iterator, List, Tuple import setuptools -import argparse import torch import torch.utils.cpp_extension as torch_cpp_ext -from collections import namedtuple -import generate_single_decode_inst, generate_single_prefill_inst, generate_batch_paged_decode_inst, generate_batch_paged_prefill_inst, generate_batch_ragged_prefill_inst, generate_dispatch_inc +root = pathlib.Path(__file__).resolve().parents[1] +sys.path.append(str(root / "python")) -root = pathlib.Path(__name__).parent - - -# cuda arch check for fp8 at the moment. 
-for cuda_arch_flags in torch_cpp_ext._get_cuda_arch_flags(): - arch = int(re.search("compute_\d+", cuda_arch_flags).group()[-2:]) - if arch < 75: - raise RuntimeError("FlashInfer requires sm75+") +from _aot_build_utils import ( + generate_batch_paged_decode_inst, + generate_batch_paged_prefill_inst, + generate_batch_ragged_prefill_inst, + generate_dispatch_inc, + generate_single_decode_inst, + generate_single_prefill_inst, +) enable_bf16 = os.environ.get("FLASHINFER_ENABLE_BF16", "1") == "1" enable_fp8 = os.environ.get("FLASHINFER_ENABLE_FP8", "1") == "1" @@ -61,7 +62,7 @@ def write_if_different(path: pathlib.Path, content: str) -> None: def get_instantiation_cu() -> Tuple[List[str], List[str], List[str]]: - path = root / "csrc_aot" / "generated" + path = root / "python" / "csrc_aot" / "generated" path.mkdir(parents=True, exist_ok=True) head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",") @@ -104,13 +105,7 @@ def get_instantiation_cu() -> Tuple[List[str], List[str], List[str]]: files_prefill = [] single_decode_uris = [] # single decode files - for ( - head_dim, - pos_encoding_mode, - ) in itertools.product( - head_dims, - pos_encoding_modes, - ): + for head_dim, pos_encoding_mode in itertools.product(head_dims, pos_encoding_modes): for dtype_q, dtype_kv in list(zip(decode_dtypes, decode_dtypes)) + list( itertools.product(fp16_dtypes, fp8_dtypes) ): @@ -278,6 +273,11 @@ def get_instantiation_cu() -> Tuple[List[str], List[str], List[str]]: f"f16qk_{bool(allow_fp16_qk_reduction)}" ) + # Change to relative path + this_dir = pathlib.Path(__file__).parent.resolve() + files_prefill = [str(pathlib.Path(p).relative_to(this_dir)) for p in files_prefill] + files_decode = [str(pathlib.Path(p).relative_to(this_dir)) for p in files_decode] + return ( files_prefill, files_decode, @@ -313,14 +313,14 @@ def generate_build_meta() -> None: d["torch"] = torch.__version__ d["python"] = platform.python_version() d["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", None) - with open(root / "flashinfer" / "_build_meta.py", "w") as f: + with open(root / "python" / "flashinfer" / "_build_meta.py", "w") as f: f.write(f"__version__ = {version!r}\n") f.write(f"build_meta = {d!r}") def generate_aot_config(aot_kernel_uris: List[str]) -> None: aot_config_str = f"""prebuilt_ops_uri = set({aot_kernel_uris})""" - with open(root / "flashinfer" / "jit" / "aot_config.py", "w") as f: + with open(root / "python" / "flashinfer" / "jit" / "aot_config.py", "w") as f: f.write(aot_config_str) @@ -348,11 +348,42 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) +@contextlib.contextmanager +def link_data_files() -> Iterator[None]: + this_dir = pathlib.Path(__file__).parent + data_dir = root / "python" / "flashinfer" / "data" + if data_dir.exists(): + shutil.rmtree(data_dir) + data_dir.mkdir(parents=True) + + def ln(src: str, dst: str, is_dir: bool = False) -> None: + (data_dir / dst).symlink_to(root / src, target_is_directory=is_dir) + + ln("3rdparty/cutlass", "cutlass", True) + ln("include", "include", True) + ln("python/csrc", "csrc", True) + ln("version.txt", "version.txt") + (this_dir / "MANIFEST.in").unlink(True) + (this_dir / "MANIFEST.in").symlink_to("jit_MANIFEST.in") + + yield + + shutil.rmtree(data_dir) + (this_dir / "MANIFEST.in").unlink(True) + + if __name__ == "__main__": + # cuda arch check for fp8 at the moment. 
+ for cuda_arch_flags in torch_cpp_ext._get_cuda_arch_flags(): + arch = int(re.search(r"compute_(\d+)", cuda_arch_flags).group(1)) + if arch < 75: + raise RuntimeError("FlashInfer requires sm75+") + remove_unwanted_pytorch_nvcc_flags() generate_build_meta() files_prefill, files_decode, aot_kernel_uris = get_instantiation_cu() generate_aot_config(aot_kernel_uris) + include_dirs = [ str(root.resolve() / "include"), str(root.resolve() / "3rdparty" / "cutlass" / "include"), # for group gemm @@ -366,83 +397,54 @@ def __init__(self, *args, **kwargs) -> None: "nvcc": [ "-O3", "-std=c++17", - "--threads", - "1", + "--threads=1", "-Xfatbin", "-compress-all", "-use_fast_math", ], } - extra_compile_args_sm90 = copy.deepcopy(extra_compile_args) - extra_compile_args_sm90["nvcc"].extend( - "-gencode arch=compute_90a,code=sm_90a".split() - ) + sources = files_decode + files_prefill + sources += [ + "csrc/bmm_fp8.cu", + "csrc/cascade.cu", + "csrc/group_gemm.cu", + "csrc/group_gemm_sm90.cu", + "csrc/norm.cu", + "csrc/page.cu", + "csrc/quantization.cu", + "csrc/rope.cu", + "csrc/sampling.cu", + "csrc_aot/activation.cu", + "csrc_aot/batch_decode.cu", + "csrc_aot/batch_prefill.cu", + "csrc_aot/flashinfer_ops.cu", + "csrc_aot/single_decode.cu", + "csrc_aot/single_prefill.cu", + ] + ext_modules = [] ext_modules.append( torch_cpp_ext.CUDAExtension( name="flashinfer._kernels", - sources=[ - "csrc/cascade.cu", - "csrc/page.cu", - "csrc/sampling.cu", - "csrc/norm.cu", - "csrc_aot/activation.cu", - "csrc/rope.cu", - "csrc/quantization.cu", - "csrc/group_gemm.cu", - "csrc/bmm_fp8.cu", - "csrc_aot/flashinfer_ops.cu" - ], + sources=sources, include_dirs=include_dirs, extra_compile_args=extra_compile_args, ) ) - ext_modules.append( - torch_cpp_ext.CUDAExtension( - name="flashinfer._kernels_sm90", - sources=[ - "csrc/group_gemm_sm90.cu", - "csrc_aot/flashinfer_sm90_ops.cu", - ], - include_dirs=include_dirs, - extra_compile_args=extra_compile_args_sm90, + with link_data_files(): + setuptools.setup( + name="flashinfer", + version=get_version(), + packages=setuptools.find_packages( + include=["flashinfer.*"], + exclude=["flashinfer.data.*"], + ), + include_package_data=True, + author="FlashInfer team", + license="Apache License 2.0", + description="FlashInfer: Kernel Library for LLM Serving", + url="https://github.com/flashinfer-ai/flashinfer", + python_requires=">=3.8", + ext_modules=ext_modules, + cmdclass={"build_ext": NinjaBuildExtension}, ) - ) - ext_modules.append( - torch_cpp_ext.CUDAExtension( - name="flashinfer._decode_kernels", - sources=[ - "csrc_aot/single_decode.cu", - "csrc_aot/flashinfer_ops_decode.cu", - "csrc_aot/batch_decode.cu", - ] - + files_decode, - include_dirs=include_dirs, - extra_compile_args=extra_compile_args, - ) - ) - ext_modules.append( - torch_cpp_ext.CUDAExtension( - name="flashinfer._prefill_kernels", - sources=[ - "csrc_aot/single_prefill.cu", - "csrc_aot/flashinfer_ops_prefill.cu", - "csrc_aot/batch_prefill.cu", - ] - + files_prefill, - include_dirs=include_dirs, - extra_compile_args=extra_compile_args, - ) - ) - setuptools.setup( - name="flashinfer", - version=get_version(), - packages=setuptools.find_packages(), - author="FlashInfer team", - license="Apache License 2.0", - description="FlashInfer: Kernel Library for LLM Serving", - url="https://github.com/flashinfer-ai/flashinfer", - python_requires=">=3.8", - ext_modules=ext_modules, - cmdclass={"build_ext": NinjaBuildExtension}, - ) diff --git a/flashinfer-aot/csrc_aot/activation.cu b/python/csrc_aot/activation.cu similarity index 
100% rename from flashinfer-aot/csrc_aot/activation.cu rename to python/csrc_aot/activation.cu diff --git a/flashinfer-aot/csrc_aot/batch_decode.cu b/python/csrc_aot/batch_decode.cu similarity index 100% rename from flashinfer-aot/csrc_aot/batch_decode.cu rename to python/csrc_aot/batch_decode.cu diff --git a/flashinfer-aot/csrc_aot/batch_prefill.cu b/python/csrc_aot/batch_prefill.cu similarity index 100% rename from flashinfer-aot/csrc_aot/batch_prefill.cu rename to python/csrc_aot/batch_prefill.cu diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops.cu b/python/csrc_aot/flashinfer_ops.cu similarity index 62% rename from flashinfer-aot/csrc_aot/flashinfer_ops.cu rename to python/csrc_aot/flashinfer_ops.cu index 05b259f5..2681dfac 100644 --- a/flashinfer-aot/csrc_aot/flashinfer_ops.cu +++ b/python/csrc_aot/flashinfer_ops.cu @@ -15,11 +15,13 @@ */ #include -void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, - torch::Tensor append_indptr, torch::Tensor paged_k_cache, - torch::Tensor paged_v_cache, torch::Tensor kv_indices, - torch::Tensor kv_indptr, torch::Tensor kv_last_page_len, - unsigned int layout); +//========== activation ========== + +void silu_and_mul(torch::Tensor& out, torch::Tensor& input); +void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); +void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); + +//========== cascade ========== std::vector merge_state(torch::Tensor v_a, torch::Tensor s_a, torch::Tensor v_b, torch::Tensor s_b); @@ -29,42 +31,46 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe std::vector merge_states(torch::Tensor v, torch::Tensor s); -torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples, - bool deterministic); +//========== decode ========== -std::vector top_p_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - std::optional maybe_top_p_arr, - double top_p_val, bool deterministic); +torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, + torch::Tensor tmp, + std::optional alibi_slopes, + unsigned int layout, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta); -std::vector top_k_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - std::optional maybe_top_k_arr, - unsigned int top_k_val, bool deterministic); +std::vector BatchDecodeWithPagedKVCachePlan( + bool use_logits_soft_cap, unsigned int head_dim, torch::Tensor empty_q_data, + torch::Tensor empty_kv_data, torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, torch::Tensor page_locked_int_workspace_buffer, + torch::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads, + unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph); -std::vector min_p_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - std::optional maybe_min_p_arr, - double min_p_val, bool deterministic); +torch::Tensor BatchDecodeWithPagedKVCacheRun( + torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + std::vector plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache, + torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_last_page_len, std::optional alibi_slopes, + unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, std::optional maybe_lse); -std::vector 
top_k_top_p_sampling_from_probs( - torch::Tensor probs, torch::Tensor uniform_samples, - std::optional maybe_top_k_arr, double top_k_val, - std::optional maybe_top_p_arr, double top_p_val, bool deterministic); +//========== gemm ========== -torch::Tensor top_p_renorm_probs(torch::Tensor probs, std::optional maybe_top_p_arr, - double top_p_val); +void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D, + torch::Tensor& A_scale, torch::Tensor& B_scale); -torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional maybe_top_k_arr, - unsigned int top_k_val); +torch::Tensor CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor seg_indptr, + torch::Tensor weight_indices, torch::Tensor x, + torch::Tensor weight, unsigned int batch_size, + bool weight_column_major); -torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional maybe_top_k_arr, - unsigned int top_k_val); +torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr, + torch::Tensor weight_indices, torch::Tensor x, + torch::Tensor weight, unsigned int batch_size, + bool weight_column_major); -torch::Tensor chain_speculative_sampling( - torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples, - torch::Tensor target_probs, torch::Tensor output_accepted_token_num, - torch::Tensor output_emitted_token_num, bool deterministic); +//========== norm ========== void rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps); @@ -76,11 +82,56 @@ void gemma_rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weig void gemma_fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double eps); -void silu_and_mul(torch::Tensor& out, torch::Tensor& input); +//========== page ========== -void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); +void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, + torch::Tensor append_indptr, torch::Tensor paged_k_cache, + torch::Tensor paged_v_cache, torch::Tensor kv_indices, + torch::Tensor kv_indptr, torch::Tensor kv_last_page_len, + unsigned int layout); -void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); +//========== prefill ========== + +torch::Tensor single_prefill_with_kv_cache( + unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v, + std::optional maybe_packed_custom_mask, torch::Tensor tmp, + std::optional maybe_alibi_slopes, unsigned int layout, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + std::optional maybe_lse); + +std::vector BatchPrefillWithKVCachePlan( + unsigned int head_dim, torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + torch::Tensor page_locked_int_workspace_buffer, torch::Tensor qo_indptr, + torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads, + unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph); + +torch::Tensor BatchPrefillWithRaggedKVCacheRun( + unsigned int mask_mode_code, torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, + torch::Tensor k, torch::Tensor v, std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, torch::Tensor qo_indptr, + torch::Tensor kv_indptr, std::optional maybe_qk_indptr, unsigned int layout, + int32_t window_left, float logits_soft_cap, float sm_scale, float 
rope_scale, float rope_theta, + std::optional maybe_lse); + +torch::Tensor BatchPrefillWithPagedKVCacheRun( + unsigned int mask_mode_code, torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, + torch::Tensor paged_k_cache, torch::Tensor paged_v_cache, + std::optional maybe_custom_mask, std::optional maybe_alibi_slopes, + torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_last_page_len, std::optional maybe_qk_indptr, + unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, std::optional maybe_lse); + +//========== quantization ========== + +torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); + +torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, + torch::Tensor output_indptr, const std::string& bitorder); + +//========== rope ========== void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta); @@ -100,25 +151,99 @@ std::vector apply_llama31_rope(torch::Tensor q, torch::Tensor k, float low_freq_factor, float high_freq_factor, float old_context_length); -torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); +//========== sampling ========== -torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, - torch::Tensor output_indptr, const std::string& bitorder); +torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples, + bool deterministic); -void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D, - torch::Tensor& A_scale, torch::Tensor& B_scale); +std::vector top_p_sampling_from_probs(torch::Tensor probs, + torch::Tensor uniform_samples, + std::optional maybe_top_p_arr, + double top_p_val, bool deterministic); -torch::Tensor CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor seg_indptr, - torch::Tensor weight_indices, torch::Tensor x, - torch::Tensor weight, unsigned int batch_size, - bool weight_column_major); +std::vector top_k_sampling_from_probs(torch::Tensor probs, + torch::Tensor uniform_samples, + std::optional maybe_top_k_arr, + unsigned int top_k_val, bool deterministic); + +std::vector min_p_sampling_from_probs(torch::Tensor probs, + torch::Tensor uniform_samples, + std::optional maybe_min_p_arr, + double min_p_val, bool deterministic); + +std::vector top_k_top_p_sampling_from_probs( + torch::Tensor probs, torch::Tensor uniform_samples, + std::optional maybe_top_k_arr, double top_k_val, + std::optional maybe_top_p_arr, double top_p_val, bool deterministic); + +torch::Tensor top_p_renorm_probs(torch::Tensor probs, std::optional maybe_top_p_arr, + double top_p_val); + +torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional maybe_top_k_arr, + unsigned int top_k_val); + +torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional maybe_top_k_arr, + unsigned int top_k_val); + +torch::Tensor chain_speculative_sampling( + torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples, + torch::Tensor target_probs, torch::Tensor output_accepted_token_num, + torch::Tensor output_emitted_token_num, bool deterministic); + +//========== pybind11 ========== PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); + // activation + 
m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul"); + m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul"); + m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul"); + + // cascade m.def("merge_state", &merge_state, "Merge two self-attention states"); m.def("merge_state_in_place", &merge_state_in_place, "Merge another self-attention state in-place."); m.def("merge_states", &merge_states, "Merge multiple self-attention states"); + + // decode + m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache, + "Single-request decode with KV-Cache operator"); + m.def("batch_decode_with_paged_kv_cache_plan", &BatchDecodeWithPagedKVCachePlan); + m.def("batch_decode_with_paged_kv_cache_run", &BatchDecodeWithPagedKVCacheRun); + + // gemm + m.def("bmm_fp8", &bmm_fp8, "BMM FP8"); + m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM operator"); + m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90"); + + // norm + m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); + m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization"); + m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization"); + m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, + "Gemma Fused add root mean square normalization"); + + // page + m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); + + // prefill + m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache, + "Single-request prefill attention with KV-Cache operator"); + m.def("batch_prefill_with_kv_cache_plan", &BatchPrefillWithKVCachePlan); + m.def("batch_prefill_with_ragged_kv_cache_run", &BatchPrefillWithRaggedKVCacheRun); + m.def("batch_prefill_with_paged_kv_cache_run", &BatchPrefillWithPagedKVCacheRun); + + // quantization + m.def("packbits", &packbits, "GPU packbits operator"); + m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); + + // rope + m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place"); + m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace, + "Apply Llama 3.1 style RoPE in-place"); + m.def("apply_rope", &apply_rope, "Apply RoPE"); + m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); + + // sampling m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities"); m.def("top_k_sampling_from_probs", &top_k_sampling_from_probs, "Top-k sampling from probabilities"); @@ -133,21 +258,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("top_k_mask_logits", &top_k_mask_logits, "Mask logits by top-k mask"); m.def("chain_speculative_sampling", &chain_speculative_sampling, "Speculative sampling from sequence of probabilities"); - m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); - m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization"); - m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization"); - m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, - "Gemma Fused add root mean square normalization"); - m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul"); - m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul"); - m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul"); - m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place"); - m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace, - "Apply Llama 3.1 style RoPE 
in-place"); - m.def("apply_rope", &apply_rope, "Apply RoPE"); - m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); - m.def("packbits", &packbits, "GPU packbits operator"); - m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); - m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM operator"); - m.def("bmm_fp8", &bmm_fp8, "BMM FP8"); } diff --git a/flashinfer-aot/csrc_aot/pytorch_extension_utils.h b/python/csrc_aot/pytorch_extension_utils.h similarity index 100% rename from flashinfer-aot/csrc_aot/pytorch_extension_utils.h rename to python/csrc_aot/pytorch_extension_utils.h diff --git a/flashinfer-aot/csrc_aot/single_decode.cu b/python/csrc_aot/single_decode.cu similarity index 100% rename from flashinfer-aot/csrc_aot/single_decode.cu rename to python/csrc_aot/single_decode.cu diff --git a/flashinfer-aot/csrc_aot/single_prefill.cu b/python/csrc_aot/single_prefill.cu similarity index 100% rename from flashinfer-aot/csrc_aot/single_prefill.cu rename to python/csrc_aot/single_prefill.cu diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index 5215ae45..4e7cd548 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -81,9 +81,9 @@ def get_single_decode_module(*args): if args not in _single_decode_modules: uri = get_single_decode_uri(*args) if has_prebuilt_ops and uri in prebuilt_ops_uri: - from . import _decode_kernels + from . import _kernels - run_func = _decode_kernels.single_decode_with_kv_cache + run_func = _kernels.single_decode_with_kv_cache else: run_func = compile_single_decode_module(*args).run @@ -143,7 +143,7 @@ def get_batch_decode_module(*args): if args not in _batch_decode_modules: uri = get_batch_decode_uri(*args) if has_prebuilt_ops and uri in prebuilt_ops_uri: - from . import _decode_kernels + from . import _kernels # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later dtype_q = args[0] @@ -151,7 +151,7 @@ def get_batch_decode_module(*args): head_dim = args[4] use_logits_cap = args[7] plan_func = ( - lambda *plan_args: _decode_kernels.batch_decode_with_paged_kv_cache_plan( + lambda *plan_args: _kernels.batch_decode_with_paged_kv_cache_plan( use_logits_cap, head_dim, torch.empty(0, dtype=dtype_q), @@ -159,7 +159,7 @@ def get_batch_decode_module(*args): *plan_args, ) ) - run_func = _decode_kernels.batch_decode_with_paged_kv_cache_run + run_func = _kernels.batch_decode_with_paged_kv_cache_run else: mod = compile_batch_decode_module(*args) plan_func = mod.plan diff --git a/python/flashinfer/gemm.py b/python/flashinfer/gemm.py index 4c88ac2e..d9c5a0bb 100644 --- a/python/flashinfer/gemm.py +++ b/python/flashinfer/gemm.py @@ -118,9 +118,9 @@ def get_gemm_sm90_module(): global _gemm_module_sm90 if _gemm_module_sm90 is None: if has_prebuilt_ops: - from . import _kernels_sm90 + from . 
import _kernels - module = _kernels_sm90 + module = _kernels else: module = load_cuda_ops( "gemm_sm90", diff --git a/python/flashinfer/jit/env.py b/python/flashinfer/jit/env.py index e3fbec81..e9905cb8 100644 --- a/python/flashinfer/jit/env.py +++ b/python/flashinfer/jit/env.py @@ -20,10 +20,10 @@ FLASHINFER_WORKSPACE_DIR = pathlib.Path.home() / ".flashinfer" FLASHINFER_JIT_DIR = FLASHINFER_WORKSPACE_DIR / "cached_ops" FLASHINFER_GEN_SRC_DIR = FLASHINFER_WORKSPACE_DIR / "generated" -_project_root = pathlib.Path(__file__).resolve().parent.parent.parent -FLASHINFER_INCLUDE_DIR = _project_root / "include" -FLASHINFER_CSRC_DIR = _project_root / "csrc" +_package_root = pathlib.Path(__file__).resolve().parents[1] +FLASHINFER_INCLUDE_DIR = _package_root / "data" / "include" +FLASHINFER_CSRC_DIR = _package_root / "data" / "csrc" CUTLASS_INCLUDE_DIRS = [ - _project_root / "3rdparty" / "cutlass" / "include", - _project_root / "3rdparty" / "cutlass" / "tools" / "util" / "include", + _package_root / "data" / "cutlass" / "include", + _package_root / "data" / "cutlass" / "tools" / "util" / "include", ] diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index bfb8f48e..47f5d1a5 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -49,7 +49,7 @@ ) if has_prebuilt_ops: - from . import _prefill_kernels # type: ignore[attr-defined] + from . import _kernels # type: ignore[attr-defined] def compile_single_prefill_module( @@ -87,7 +87,7 @@ def get_single_prefill_module(*args): if has_prebuilt_ops and uri in prebuilt_ops_uri: # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later mask_mode = args[5] - run_func = lambda *run_args: _prefill_kernels.single_prefill_with_kv_cache( + run_func = lambda *run_args: _kernels.single_prefill_with_kv_cache( mask_mode, *run_args, ) @@ -159,21 +159,19 @@ def get_batch_prefill_module(*args): if has_prebuilt_ops and uri in prebuilt_ops_uri: # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later head_dim = args[4] - plan_func = ( - lambda *plan_args: _prefill_kernels.batch_prefill_with_kv_cache_plan( - head_dim, - *plan_args, - ) + plan_func = lambda *plan_args: _kernels.batch_prefill_with_kv_cache_plan( + head_dim, + *plan_args, ) mask_mode = args[6] ragged_run_func = ( - lambda *run_args: _prefill_kernels.batch_prefill_with_ragged_kv_cache_run( + lambda *run_args: _kernels.batch_prefill_with_ragged_kv_cache_run( mask_mode, *run_args, ) ) paged_run_func = ( - lambda *run_args: _prefill_kernels.batch_prefill_with_paged_kv_cache_run( + lambda *run_args: _kernels.batch_prefill_with_paged_kv_cache_run( mask_mode, *run_args, ) diff --git a/python/include b/python/include deleted file mode 120000 index 3a1af68f..00000000 --- a/python/include +++ /dev/null @@ -1 +0,0 @@ -../include/ \ No newline at end of file diff --git a/python/jit_MANIFEST.in b/python/jit_MANIFEST.in new file mode 100644 index 00000000..ea423d7d --- /dev/null +++ b/python/jit_MANIFEST.in @@ -0,0 +1,15 @@ +# MANIFEST.in for JIT + +global-exclude *.so + +prune */__pycache__ +prune csrc +prune csrc_aot +exclude aot_setup.py +exclude flashinfer/jit/aot_config.py + +include flashinfer/data/version.txt +graft flashinfer/data/csrc +graft flashinfer/data/include +graft flashinfer/data/cutlass/include +graft flashinfer/data/cutlass/tools/util/include diff --git a/python/setup.py b/python/setup.py index 52166d51..e057271f 100644 --- a/python/setup.py +++ b/python/setup.py @@ -14,43 +14,71 @@ limitations under the License. 
""" -from typing import List, Tuple - -import pathlib import os +import pathlib +import shutil +from typing import Iterator + import setuptools -root = pathlib.Path(__name__).parent +root = pathlib.Path(__file__).resolve().parents[1] +this_dir = pathlib.Path(__file__).parent def get_version(): version = os.getenv("FLASHINFER_BUILD_VERSION") if version is None: - with open(root / "version.txt") as f: + with open(this_dir / "flashinfer" / "data" / "version.txt") as f: version = f.read().strip() return version def generate_build_meta() -> None: version = get_version() - with open(root / "flashinfer/_build_meta.py", "w") as f: + with open(this_dir / "flashinfer" / "_build_meta.py", "w") as f: f.write(f"__version__ = {version!r}\n") def clear_aot_config(): # remove aot_config.py - aot_config_path = root / "flashinfer" / "jit" / "aot_config.py" + aot_config_path = this_dir / "flashinfer" / "jit" / "aot_config.py" if os.path.exists(aot_config_path): os.remove(aot_config_path) +def link_data_files() -> Iterator[None]: + this_dir = pathlib.Path(__file__).parent + data_dir = root / "python" / "flashinfer" / "data" + if data_dir.exists(): + shutil.rmtree(data_dir) + data_dir.mkdir(parents=True) + + def ln(src: str, dst: str, is_dir: bool = False) -> None: + (data_dir / dst).symlink_to(root / src, target_is_directory=is_dir) + + ln("3rdparty/cutlass", "cutlass", True) + ln("include", "include", True) + ln("python/csrc", "csrc", True) + ln("version.txt", "version.txt") + (this_dir / "MANIFEST.in").unlink(True) + (this_dir / "MANIFEST.in").symlink_to("jit_MANIFEST.in") + + # Unlike aot_setup.py, don't delete the symlinks after the build + # because editable installs rely on them. + + if __name__ == "__main__": + link_data_files() generate_build_meta() clear_aot_config() setuptools.setup( name="flashinfer", version=get_version(), - packages=setuptools.find_packages(), + packages=setuptools.find_packages( + include=["flashinfer.*"], + exclude=["flashinfer.data.*"], + ), + include_package_data=True, author="FlashInfer team", license="Apache License 2.0", description="FlashInfer: Kernel Library for LLM Serving", diff --git a/python/version.txt b/python/version.txt deleted file mode 120000 index aa4e5bec..00000000 --- a/python/version.txt +++ /dev/null @@ -1 +0,0 @@ -../version.txt \ No newline at end of file diff --git a/scripts/run-ci-build-wheel.sh b/scripts/run-ci-build-wheel.sh index c50118a1..5d445982 100644 --- a/scripts/run-ci-build-wheel.sh +++ b/scripts/run-ci-build-wheel.sh @@ -42,7 +42,8 @@ echo "::endgroup::" echo "::group::Build wheel for FlashInfer" cd "$PROJECT_ROOT/python" -FLASHINFER_BUILD_VERSION="${FLASHINFER_BUILD_VERSION}+cu${CUDA_MAJOR}${CUDA_MINOR}torch${FLASHINFER_CI_TORCH_VERSION}" python -m build --no-isolation +FLASHINFER_BUILD_VERSION="${FLASHINFER_BUILD_VERSION}+cu${CUDA_MAJOR}${CUDA_MINOR}torch${FLASHINFER_CI_TORCH_VERSION}" python aot_setup.py bdist_wheel rm -f dist/*.tar.gz python -m build --no-isolation --sdist +ls -la dist/ echo "::endgroup::" From 98c5681472d998d6882314c60797de10f8ea6516 Mon Sep 17 00:00:00 2001 From: Lequn Chen Date: Tue, 29 Oct 2024 20:38:12 +0000 Subject: [PATCH 2/7] separate sm90 kernels --- python/aot_setup.py | 55 ++++++++++++++++---------- python/csrc_aot/flashinfer_ops.cu | 1 - python/csrc_aot/flashinfer_sm90_ops.cu | 26 ++++++++++++ python/flashinfer/gemm.py | 4 +- 4 files changed, 63 insertions(+), 23 deletions(-) create mode 100644 python/csrc_aot/flashinfer_sm90_ops.cu diff --git a/python/aot_setup.py b/python/aot_setup.py index 
6dd3a8f4..e8b9154d 100644 --- a/python/aot_setup.py +++ b/python/aot_setup.py @@ -16,6 +16,7 @@ import argparse import contextlib +import copy import itertools import os import pathlib @@ -403,34 +404,48 @@ def ln(src: str, dst: str, is_dir: bool = False) -> None: "-use_fast_math", ], } - sources = files_decode + files_prefill - sources += [ - "csrc/bmm_fp8.cu", - "csrc/cascade.cu", - "csrc/group_gemm.cu", - "csrc/group_gemm_sm90.cu", - "csrc/norm.cu", - "csrc/page.cu", - "csrc/quantization.cu", - "csrc/rope.cu", - "csrc/sampling.cu", - "csrc_aot/activation.cu", - "csrc_aot/batch_decode.cu", - "csrc_aot/batch_prefill.cu", - "csrc_aot/flashinfer_ops.cu", - "csrc_aot/single_decode.cu", - "csrc_aot/single_prefill.cu", - ] - + extra_compile_args_sm90 = copy.deepcopy(extra_compile_args) + extra_compile_args_sm90["nvcc"].extend( + "-gencode arch=compute_90a,code=sm_90a".split() + ) ext_modules = [] ext_modules.append( torch_cpp_ext.CUDAExtension( name="flashinfer._kernels", - sources=sources, + sources=[ + "csrc/bmm_fp8.cu", + "csrc/cascade.cu", + "csrc/group_gemm.cu", + "csrc/norm.cu", + "csrc/page.cu", + "csrc/quantization.cu", + "csrc/rope.cu", + "csrc/sampling.cu", + "csrc_aot/activation.cu", + "csrc_aot/batch_decode.cu", + "csrc_aot/batch_prefill.cu", + "csrc_aot/flashinfer_ops.cu", + "csrc_aot/single_decode.cu", + "csrc_aot/single_prefill.cu", + ] + + files_decode + + files_prefill, include_dirs=include_dirs, extra_compile_args=extra_compile_args, ) ) + ext_modules.append( + torch_cpp_ext.CUDAExtension( + name="flashinfer._kernels_sm90", + sources=[ + "csrc/group_gemm_sm90.cu", + "csrc_aot/flashinfer_sm90_ops.cu", + ], + include_dirs=include_dirs, + extra_compile_args=extra_compile_args_sm90, + ) + ) + with link_data_files(): setuptools.setup( name="flashinfer", diff --git a/python/csrc_aot/flashinfer_ops.cu b/python/csrc_aot/flashinfer_ops.cu index 2681dfac..13783be7 100644 --- a/python/csrc_aot/flashinfer_ops.cu +++ b/python/csrc_aot/flashinfer_ops.cu @@ -213,7 +213,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // gemm m.def("bmm_fp8", &bmm_fp8, "BMM FP8"); m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM operator"); - m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90"); // norm m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); diff --git a/python/csrc_aot/flashinfer_sm90_ops.cu b/python/csrc_aot/flashinfer_sm90_ops.cu new file mode 100644 index 00000000..d4222473 --- /dev/null +++ b/python/csrc_aot/flashinfer_sm90_ops.cu @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <torch/extension.h>
+
+
+torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr,
+                                     torch::Tensor weight_indices, torch::Tensor x,
+                                     torch::Tensor weight, unsigned int batch_size,
+                                     bool weight_column_major);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90");
+}
diff --git a/python/flashinfer/gemm.py b/python/flashinfer/gemm.py
index d9c5a0bb..4c88ac2e 100644
--- a/python/flashinfer/gemm.py
+++ b/python/flashinfer/gemm.py
@@ -118,9 +118,9 @@ def get_gemm_sm90_module():
     global _gemm_module_sm90
     if _gemm_module_sm90 is None:
         if has_prebuilt_ops:
-            from . import _kernels
+            from . import _kernels_sm90
 
-            module = _kernels
+            module = _kernels_sm90
         else:
             module = load_cuda_ops(
                 "gemm_sm90",

From e496fc557c87e363a47b01af3ab02b3f8a771915 Mon Sep 17 00:00:00 2001
From: Lequn Chen
Date: Tue, 29 Oct 2024 22:42:00 +0000
Subject: [PATCH 3/7] torch custom_op fix for rope

---
 flashinfer-aot/csrc_aot/flashinfer_ops.cu |  40 +--
 python/csrc/flashinfer_rope_ops.cu        |  34 +-
 python/csrc/rope.cu                       |  42 +--
 python/flashinfer/rope.py                 | 373 +++++++++++++++++-----
 tests/conftest.py                         |   4 +
 5 files changed, 346 insertions(+), 147 deletions(-)

diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops.cu b/flashinfer-aot/csrc_aot/flashinfer_ops.cu
index 9ab9a86c..5cf365ca 100644
--- a/flashinfer-aot/csrc_aot/flashinfer_ops.cu
+++ b/flashinfer-aot/csrc_aot/flashinfer_ops.cu
@@ -83,29 +83,23 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
 void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
-std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
-                                      torch::Tensor k_rope, torch::Tensor indptr,
-                                      torch::Tensor offsets, bool interleave, float rope_scale,
-                                      float rope_theta);
-
-std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor indptr, torch::Tensor offsets,
-                                              bool interleave, float rope_scale, float rope_theta,
-                                              float low_freq_factor, float high_freq_factor,
-                                              float old_context_length);
-
-std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor pos_ids, bool interleave,
-                                              float rope_scale, float rope_theta);
-
-std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                                      torch::Tensor q_rope, torch::Tensor k_rope,
-                                                      torch::Tensor pos_ids, bool interleave,
-                                                      float rope_scale, float rope_theta,
-                                                      float low_freq_factor, float high_freq_factor,
-                                                      float old_context_length);
+void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope,
+                torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale,
+                float rope_theta);
+
+void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets,
+                        bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
+                        float high_freq_factor, float old_context_length);
+
+void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                        float rope_scale, float rope_theta);
+
+void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                                torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                                float rope_scale, float rope_theta, float low_freq_factor,
+                                float
high_freq_factor, float old_context_length); torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); diff --git a/python/csrc/flashinfer_rope_ops.cu b/python/csrc/flashinfer_rope_ops.cu index ef046ead..c6259968 100644 --- a/python/csrc/flashinfer_rope_ops.cu +++ b/python/csrc/flashinfer_rope_ops.cu @@ -17,29 +17,23 @@ #include -std::vector apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor indptr, - torch::Tensor offsets, bool interleave, float rope_scale, - float rope_theta); +void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope, + torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale, + float rope_theta); -std::vector apply_llama31_rope(torch::Tensor q, torch::Tensor k, - torch::Tensor q_rope, torch::Tensor k_rope, - torch::Tensor indptr, torch::Tensor offsets, - bool interleave, float rope_scale, float rope_theta, - float low_freq_factor, float high_freq_factor, - float old_context_length); +void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, + torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets, + bool interleave, float rope_scale, float rope_theta, float low_freq_factor, + float high_freq_factor, float old_context_length); -std::vector apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, - torch::Tensor q_rope, torch::Tensor k_rope, - torch::Tensor pos_ids, bool interleave, - float rope_scale, float rope_theta); +void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, + torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave, + float rope_scale, float rope_theta); -std::vector apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, - torch::Tensor q_rope, torch::Tensor k_rope, - torch::Tensor pos_ids, bool interleave, - float rope_scale, float rope_theta, - float low_freq_factor, float high_freq_factor, - float old_context_length); +void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, + torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave, + float rope_scale, float rope_theta, float low_freq_factor, + float high_freq_factor, float old_context_length); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("apply_rope", &apply_rope, "Apply RoPE"); diff --git a/python/csrc/rope.cu b/python/csrc/rope.cu index d2ca9155..8f661da0 100644 --- a/python/csrc/rope.cu +++ b/python/csrc/rope.cu @@ -19,10 +19,9 @@ using namespace flashinfer; -std::vector apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor indptr, - torch::Tensor offsets, bool interleave, float rope_scale, - float rope_theta) { +void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope, + torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale, + float rope_theta) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous CHECK_INPUT(indptr); @@ -65,14 +64,11 @@ std::vector apply_rope(torch::Tensor q, torch::Tensor k, torch::T std::string(cudaGetErrorString(status))); return true; }); - - return {q_rope, k_rope}; } -std::vector apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, - torch::Tensor q_rope, torch::Tensor k_rope, - torch::Tensor pos_ids, bool interleave, - float rope_scale, float rope_theta) { +void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, + torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave, + 
float rope_scale, float rope_theta) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous CHECK_INPUT(pos_ids); @@ -109,16 +105,12 @@ std::vector apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, std::string(cudaGetErrorString(status))); return true; }); - - return {q_rope, k_rope}; } -std::vector apply_llama31_rope(torch::Tensor q, torch::Tensor k, - torch::Tensor q_rope, torch::Tensor k_rope, - torch::Tensor indptr, torch::Tensor offsets, - bool interleave, float rope_scale, float rope_theta, - float low_freq_factor, float high_freq_factor, - float old_context_length) { +void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, + torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets, + bool interleave, float rope_scale, float rope_theta, float low_freq_factor, + float high_freq_factor, float old_context_length) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous CHECK_INPUT(indptr); @@ -162,16 +154,12 @@ std::vector apply_llama31_rope(torch::Tensor q, torch::Tensor k, std::string(cudaGetErrorString(status))); return true; }); - - return {q_rope, k_rope}; } -std::vector apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, - torch::Tensor q_rope, torch::Tensor k_rope, - torch::Tensor pos_ids, bool interleave, - float rope_scale, float rope_theta, - float low_freq_factor, float high_freq_factor, - float old_context_length) { +void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, + torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave, + float rope_scale, float rope_theta, float low_freq_factor, + float high_freq_factor, float old_context_length) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous CHECK_INPUT(pos_ids); @@ -209,6 +197,4 @@ std::vector apply_llama31_rope_pos_ids(torch::Tensor q, torch::Te std::string(cudaGetErrorString(status))); return true; }); - - return {q_rope, k_rope}; } diff --git a/python/flashinfer/rope.py b/python/flashinfer/rope.py index 29c2fcb7..60ca2c6e 100644 --- a/python/flashinfer/rope.py +++ b/python/flashinfer/rope.py @@ -42,7 +42,165 @@ def get_rope_module(): return _rope_module -@register_custom_op("flashinfer::apply_rope_inplace", mutates_args=("q", "k")) +@register_custom_op("flashinfer::apply_rope", mutates_args=("q_rope", "k_rope")) +def _apply_rope( + q: torch.Tensor, + k: torch.Tensor, + q_rope: torch.Tensor, + k_rope: torch.Tensor, + indptr: torch.Tensor, + offsets: torch.Tensor, + interleave: bool, + rope_scale: float, + rope_theta: float, +) -> None: + get_rope_module().apply_rope( + q, k, q_rope, k_rope, indptr, offsets, interleave, rope_scale, rope_theta + ) + + +@register_fake_op("flashinfer::apply_rope") +def _fake_apply_rope( + q: torch.Tensor, + k: torch.Tensor, + q_rope: torch.Tensor, + k_rope: torch.Tensor, + indptr: torch.Tensor, + offsets: torch.Tensor, + interleave: bool, + rope_scale: float, + rope_theta: float, +) -> None: + pass + + +@register_custom_op("flashinfer::apply_llama31_rope", mutates_args=("q_rope", "k_rope")) +def _apply_llama31_rope( + q: torch.Tensor, + k: torch.Tensor, + q_rope: torch.Tensor, + k_rope: torch.Tensor, + indptr: torch.Tensor, + offsets: torch.Tensor, + interleave: bool, + rope_scale: float, + rope_theta: float, + low_freq_factor: float, + high_freq_factor: float, + old_context_len: float, +) -> None: + get_rope_module().apply_llama31_rope( + q, + k, + q_rope, + k_rope, + indptr, + offsets, + interleave, + 
rope_scale, + rope_theta, + low_freq_factor, + high_freq_factor, + old_context_len, + ) + + +@register_fake_op("flashinfer::apply_llama31_rope") +def _fake_apply_llama31_rope( + q: torch.Tensor, + k: torch.Tensor, + q_rope: torch.Tensor, + k_rope: torch.Tensor, + indptr: torch.Tensor, + offsets: torch.Tensor, + interleave: bool, + rope_scale: float, + rope_theta: float, + low_freq_factor: float, + high_freq_factor: float, + old_context_len: float, +) -> None: + pass + + +@register_custom_op("flashinfer::apply_rope_pos_ids", mutates_args=("q_rope", "k_rope")) +def _apply_rope_pos_ids( + q: torch.Tensor, + k: torch.Tensor, + q_rope: torch.Tensor, + k_rope: torch.Tensor, + pos_ids: torch.Tensor, + interleave: bool, + rope_scale: float, + rope_theta: float, +) -> None: + get_rope_module().apply_rope_pos_ids( + q, k, q_rope, k_rope, pos_ids, interleave, rope_scale, rope_theta + ) + + +@register_fake_op("flashinfer::apply_rope_pos_ids") +def _fake_apply_rope_pos_ids( + q: torch.Tensor, + k: torch.Tensor, + q_rope: torch.Tensor, + k_rope: torch.Tensor, + pos_ids: torch.Tensor, + interleave: bool, + rope_scale: float, + rope_theta: float, +) -> None: + pass + + +@register_custom_op( + "flashinfer::apply_llama31_rope_pos_ids", mutates_args=("q_rope", "k_rope") +) +def _apply_llama31_rope_pos_ids( + q: torch.Tensor, + k: torch.Tensor, + q_rope: torch.Tensor, + k_rope: torch.Tensor, + pos_ids: torch.Tensor, + interleave: bool, + rope_scale: float, + rope_theta: float, + low_freq_factor: float, + high_freq_factor: float, + old_context_len: float, +) -> None: + get_rope_module().apply_llama31_rope_pos_ids( + q, + k, + q_rope, + k_rope, + pos_ids, + interleave, + rope_scale, + rope_theta, + low_freq_factor, + high_freq_factor, + old_context_len, + ) + + +@register_fake_op("flashinfer::apply_llama31_rope_pos_ids") +def _fake_apply_llama31_rope_pos_ids( + q: torch.Tensor, + k: torch.Tensor, + q_rope: torch.Tensor, + k_rope: torch.Tensor, + pos_ids: torch.Tensor, + interleave: bool, + rope_scale: float, + rope_theta: float, + low_freq_factor: float, + high_freq_factor: float, + old_context_len: float, +) -> None: + pass + + def apply_rope_inplace( q: torch.Tensor, k: torch.Tensor, @@ -118,25 +276,9 @@ def apply_rope_inplace( -------- apply_rope """ - get_rope_module().apply_rope( - q, k, q, k, indptr, offsets, interleave, rope_scale, rope_theta - ) - - -@register_fake_op("flashinfer::apply_rope_inplace") -def _fake_apply_rope_inplace( - q: torch.Tensor, - k: torch.Tensor, - indptr: torch.Tensor, - offsets: torch.Tensor, - interleave: bool = False, - rope_scale: float = 1, - rope_theta: float = 1e4, -) -> None: - pass + _apply_rope(q, k, q, k, indptr, offsets, interleave, rope_scale, rope_theta) -@register_custom_op("flashinfer::apply_rope_pos_ids_inplace", mutates_args=("q", "k")) def apply_rope_pos_ids_inplace( q: torch.Tensor, k: torch.Tensor, @@ -183,24 +325,9 @@ def apply_rope_pos_ids_inplace( -------- apply_rope_pos_ids """ - get_rope_module().apply_rope_pos_ids( - q, k, q, k, pos_ids, interleave, rope_scale, rope_theta - ) + _apply_rope_pos_ids(q, k, q, k, pos_ids, interleave, rope_scale, rope_theta) -@register_fake_op("flashinfer::apply_rope_pos_ids_inplace") -def _fake_apply_rope_pos_ids_inplace( - q: torch.Tensor, - k: torch.Tensor, - pos_ids: torch.Tensor, - interleave: bool = False, - rope_scale: float = 1, - rope_theta: float = 1e4, -) -> None: - pass - - -@register_custom_op("flashinfer::apply_llama31_rope_inplace", mutates_args=("q", "k")) def apply_llama31_rope_inplace( q: torch.Tensor, k: 
torch.Tensor, @@ -286,7 +413,7 @@ def apply_llama31_rope_inplace( -------- apply_llama31_rope """ - get_rope_module().apply_llama31_rope( + _apply_llama31_rope( q, k, q, @@ -302,12 +429,10 @@ def apply_llama31_rope_inplace( ) -@register_fake_op("flashinfer::apply_llama31_rope_inplace") -def _fake_apply_llama31_rope_inplace( +def apply_llama31_rope_pos_ids_inplace( q: torch.Tensor, k: torch.Tensor, - indptr: torch.Tensor, - offsets: torch.Tensor, + pos_ids: torch.Tensor, interleave: bool = True, rope_scale: float = 8, rope_theta: float = 5e5, @@ -315,10 +440,66 @@ def _fake_apply_llama31_rope_inplace( high_freq_factor: float = 4, old_context_len: int = 8192, ) -> None: - pass + r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as + RaggedTensor) inplace. + + We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th + segment the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the + i-th segment is ``k[indptr[i]:indptr[i+1]]``, the first element of :attr:`indptr` is always + 0 and the last element of :attr:`indptr` is the total number of queries/keys in the batch. + Please see :ref:`Ragged Tensor tutorial ` for more details about the + ragged tensor. + + Parameters + ---------- + q : torch.Tensor + Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last + element of ``indptr``. + k : torch.Tensor + Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last + element of ``indptr``. + pos_ids : torch.Tensor + Position indices, shape: ``(nnz)``. + interleave : bool + Whether to use interleaved layout in the last dimension, default: ``False``. + + * If ``True``, the last dimension of the query/key tensor is interleaved, i.e., + we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. + + * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e., + we rorate the first half dimensions ``([..., :head_dim//2])`` and the second half + dimensions ``([..., head_dim//2:])``. + + rope_scale : float + The scaling factor used in the rope embedding, default: ``8``. + rope_theta : float + The theta value used in the rope embedding, default: ``5e5``. + low_freq_factor : float + The low frequency factor used in Llama 3.1 RoPE, default: ``1``. + high_freq_factor : float + The high frequency factor used in Llama 3.1 RoPE, default: ``4``. + old_context_len : int + The old context length used in Llama 3.1 RoPE, default: ``8192``. 
+ + See Also + -------- + apply_llama31_rope_pos_ids + """ + _apply_llama31_rope_pos_ids( + q, + k, + q, + k, + pos_ids, + interleave, + rope_scale, + rope_theta, + low_freq_factor, + high_freq_factor, + float(old_context_len), + ) -@register_custom_op("flashinfer::apply_rope", mutates_args=()) def apply_rope( q: torch.Tensor, k: torch.Tensor, @@ -407,25 +588,12 @@ def apply_rope( """ q_rope = torch.empty_like(q) k_rope = torch.empty_like(k) - return get_rope_module().apply_rope( + _apply_rope( q, k, q_rope, k_rope, indptr, offsets, interleave, rope_scale, rope_theta ) + return q_rope, k_rope -@register_fake_op("flashinfer::apply_rope") -def _fake_apply_rope( - q: torch.Tensor, - k: torch.Tensor, - indptr: torch.Tensor, - offsets: torch.Tensor, - interleave: bool = False, - rope_scale: float = 1, - rope_theta: float = 1e4, -) -> Tuple[torch.Tensor, torch.Tensor]: - return torch.empty_like(q), torch.empty_like(k) - - -@register_custom_op("flashinfer::apply_rope_pos_ids", mutates_args=()) def apply_rope_pos_ids( q: torch.Tensor, k: torch.Tensor, @@ -481,24 +649,12 @@ def apply_rope_pos_ids( """ q_rope = torch.empty_like(q) k_rope = torch.empty_like(k) - return get_rope_module().apply_rope_pos_ids( + _apply_rope_pos_ids( q, k, q_rope, k_rope, pos_ids, interleave, rope_scale, rope_theta ) + return q_rope, k_rope -@register_fake_op("flashinfer::apply_rope_pos_ids") -def _fake_apply_rope_pos_ids( - q: torch.Tensor, - k: torch.Tensor, - pos_ids: torch.Tensor, - interleave: bool = False, - rope_scale: float = 1, - rope_theta: float = 1e4, -) -> Tuple[torch.Tensor, torch.Tensor]: - return torch.empty_like(q), torch.empty_like(k) - - -@register_custom_op("flashinfer::apply_llama31_rope", mutates_args=()) def apply_llama31_rope( q: torch.Tensor, k: torch.Tensor, @@ -597,7 +753,7 @@ def apply_llama31_rope( """ q_rope = torch.empty_like(q) k_rope = torch.empty_like(k) - return get_rope_module().apply_llama31_rope( + _apply_llama31_rope( q, k, q_rope, @@ -611,14 +767,13 @@ def apply_llama31_rope( high_freq_factor, float(old_context_len), ) + return q_rope, k_rope -@register_fake_op("flashinfer::apply_llama31_rope") -def _fake_apply_llama31_rope( +def apply_llama31_rope_pos_ids( q: torch.Tensor, k: torch.Tensor, - indptr: torch.Tensor, - offsets: torch.Tensor, + pos_ids: torch.Tensor, interleave: bool = True, rope_scale: float = 8, rope_theta: float = 5e5, @@ -626,4 +781,70 @@ def _fake_apply_llama31_rope( high_freq_factor: float = 4, old_context_len: int = 8192, ) -> Tuple[torch.Tensor, torch.Tensor]: - return torch.empty_like(q), torch.empty_like(k) + r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as + RaggedTensor). + + We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th + segment the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the + i-th segment is ``k[indptr[i]:indptr[i+1]]``, the first element of :attr:`indptr` is always + 0 and the last element of :attr:`indptr` is the total number of queries/keys in the batch. + Please see :ref:`Ragged Tensor tutorial ` for more details about the + ragged tensor. + + Parameters + ---------- + q : torch.Tensor + Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last + element of ``indptr``. + k : torch.Tensor + Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last + element of ``indptr``. + pos_ids : torch.Tensor + Position indices, shape: ``(nnz)``. 
+ interleave : bool + Whether to use interleaved layout in the last dimension, default: ``False``. + + * If ``True``, the last dimension of the query/key tensor is interleaved, i.e., + we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. + + * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e., + we rorate the first half dimensions ``([..., :head_dim//2])`` and the second half + dimensions ``([..., head_dim//2:])`` + rope_scale : float + The scaling factor used in the rope embedding, default: ``8``. + rope_theta : float + The theta value used in the rope embedding, default: ``5e5``. + low_freq_factor : float + The low frequency factor used in Llama 3.1 RoPE, default: ``1``. + high_freq_factor : float + The high frequency factor used in Llama 3.1 RoPE, default: ``4``. + old_context_len : int + The old context length used in Llama 3.1 RoPE, default: ``8192``. + + Returns + ------- + q_rope : torch.Tensor + The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``. + k_rope : torch.Tensor + The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``. + + See Also + -------- + apply_llama31_rope_pos_ids_inplace + """ + q_rope = torch.empty_like(q) + k_rope = torch.empty_like(k) + _apply_llama31_rope_pos_ids( + q, + k, + q_rope, + k_rope, + pos_ids, + interleave, + rope_scale, + rope_theta, + low_freq_factor, + high_freq_factor, + float(old_context_len), + ) + return q_rope, k_rope diff --git a/tests/conftest.py b/tests/conftest.py index 95738ddb..08697065 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,8 +32,12 @@ flashinfer.quantization.packbits, flashinfer.rope.apply_rope, flashinfer.rope.apply_rope_inplace, + flashinfer.rope.apply_rope_pos_ids, + flashinfer.rope.apply_rope_pos_ids_inplace, flashinfer.rope.apply_llama31_rope, flashinfer.rope.apply_llama31_rope_inplace, + flashinfer.rope.apply_llama31_rope_pos_ids, + flashinfer.rope.apply_llama31_rope_pos_ids_inplace, flashinfer.sampling.sampling_from_probs, flashinfer.sampling.top_p_sampling_from_probs, flashinfer.sampling.top_k_sampling_from_probs, From fdc45dcc2dba92c5f7552f44fd62e1a215642b5d Mon Sep 17 00:00:00 2001 From: Lequn Chen Date: Tue, 29 Oct 2024 23:04:25 +0000 Subject: [PATCH 4/7] fix setuptools.find_packages glob --- python/aot_setup.py | 4 ++-- python/setup.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/aot_setup.py b/python/aot_setup.py index e8b9154d..95459ec8 100644 --- a/python/aot_setup.py +++ b/python/aot_setup.py @@ -451,8 +451,8 @@ def ln(src: str, dst: str, is_dir: bool = False) -> None: name="flashinfer", version=get_version(), packages=setuptools.find_packages( - include=["flashinfer.*"], - exclude=["flashinfer.data.*"], + include=["flashinfer*"], + exclude=["flashinfer.data*"], ), include_package_data=True, author="FlashInfer team", diff --git a/python/setup.py b/python/setup.py index e057271f..4ae3e637 100644 --- a/python/setup.py +++ b/python/setup.py @@ -75,8 +75,8 @@ def ln(src: str, dst: str, is_dir: bool = False) -> None: name="flashinfer", version=get_version(), packages=setuptools.find_packages( - include=["flashinfer.*"], - exclude=["flashinfer.data.*"], + include=["flashinfer*"], + exclude=["flashinfer.data*"], ), include_package_data=True, author="FlashInfer team", From ca0285412023b5c922669232c5f8bc3911422e1a Mon Sep 17 00:00:00 2001 From: Lequn Chen Date: Tue, 29 Oct 2024 23:13:29 +0000 Subject: [PATCH 5/7] remove mypy.ini and pylintrc from JIT sdist --- python/aot_MANIFEST.in | 2 
+- python/jit_MANIFEST.in | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/aot_MANIFEST.in b/python/aot_MANIFEST.in index 5819e735..e1988769 100644 --- a/python/aot_MANIFEST.in +++ b/python/aot_MANIFEST.in @@ -1,4 +1,4 @@ -# MANIFEST.in for AOT +# MANIFEST.in for AOT wheel prune */__pycache__ prune csrc diff --git a/python/jit_MANIFEST.in b/python/jit_MANIFEST.in index ea423d7d..33022575 100644 --- a/python/jit_MANIFEST.in +++ b/python/jit_MANIFEST.in @@ -1,12 +1,14 @@ -# MANIFEST.in for JIT +# MANIFEST.in for JIT sdist global-exclude *.so prune */__pycache__ prune csrc prune csrc_aot -exclude aot_setup.py exclude flashinfer/jit/aot_config.py +exclude aot_setup.py +exclude mypy.ini +exclude pylintrc include flashinfer/data/version.txt graft flashinfer/data/csrc From a57c9a309fae5f3679a8cc8db30832c58c0bc2c9 Mon Sep 17 00:00:00 2001 From: Lequn Chen Date: Wed, 30 Oct 2024 03:24:09 +0000 Subject: [PATCH 6/7] Suppress setuptools false warnings --- python/aot_setup.py | 5 +++++ python/flashinfer/triton/kernels/__init__.py | 0 python/setup.py | 6 ++++++ 3 files changed, 11 insertions(+) create mode 100644 python/flashinfer/triton/kernels/__init__.py diff --git a/python/aot_setup.py b/python/aot_setup.py index 95459ec8..cca6488e 100644 --- a/python/aot_setup.py +++ b/python/aot_setup.py @@ -25,6 +25,7 @@ import shutil import subprocess import sys +import warnings from typing import Iterator, List, Tuple import setuptools @@ -446,6 +447,10 @@ def ln(src: str, dst: str, is_dir: bool = False) -> None: ) ) + # Suppress warnings complaining that: + # Package 'flashinfer.data*' is absent from the `packages` configuration. + warnings.filterwarnings("ignore", r".*flashinfer\.data.*", UserWarning) + with link_data_files(): setuptools.setup( name="flashinfer", diff --git a/python/flashinfer/triton/kernels/__init__.py b/python/flashinfer/triton/kernels/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/setup.py b/python/setup.py index 4ae3e637..e45324c2 100644 --- a/python/setup.py +++ b/python/setup.py @@ -18,6 +18,7 @@ import pathlib import shutil from typing import Iterator +import warnings import setuptools @@ -71,6 +72,11 @@ def ln(src: str, dst: str, is_dir: bool = False) -> None: link_data_files() generate_build_meta() clear_aot_config() + + # Suppress warnings complaining that: + # Package 'flashinfer.data*' is absent from the `packages` configuration. + warnings.filterwarnings("ignore", r".*flashinfer\.data.*", UserWarning) + setuptools.setup( name="flashinfer", version=get_version(), From c43864baccfc1a4266292ffcf11c7aecf27dbe08 Mon Sep 17 00:00:00 2001 From: Lequn Chen Date: Wed, 30 Oct 2024 07:15:20 +0000 Subject: [PATCH 7/7] skip rm symlinks when develop --- python/aot_setup.py | 5 +++-- python/setup.py | 39 +++++++++++++++++++++++---------------- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/python/aot_setup.py b/python/aot_setup.py index cca6488e..ee9cfa2a 100644 --- a/python/aot_setup.py +++ b/python/aot_setup.py @@ -370,8 +370,9 @@ def ln(src: str, dst: str, is_dir: bool = False) -> None: yield - shutil.rmtree(data_dir) - (this_dir / "MANIFEST.in").unlink(True) + if sys.argv[1] != "develop": + shutil.rmtree(data_dir) + (this_dir / "MANIFEST.in").unlink(True) if __name__ == "__main__": diff --git a/python/setup.py b/python/setup.py index e45324c2..fff0d5e2 100644 --- a/python/setup.py +++ b/python/setup.py @@ -14,9 +14,11 @@ limitations under the License. 
""" +import contextlib import os import pathlib import shutil +import sys from typing import Iterator import warnings @@ -47,6 +49,7 @@ def clear_aot_config(): os.remove(aot_config_path) +@contextlib.contextmanager def link_data_files() -> Iterator[None]: this_dir = pathlib.Path(__file__).parent data_dir = root / "python" / "flashinfer" / "data" @@ -64,8 +67,11 @@ def ln(src: str, dst: str, is_dir: bool = False) -> None: (this_dir / "MANIFEST.in").unlink(True) (this_dir / "MANIFEST.in").symlink_to("jit_MANIFEST.in") - # Unlike aot_setup.py, don't delete the symlinks after the build - # because editable installs rely on them. + yield + + if sys.argv[1] != "develop": + shutil.rmtree(data_dir) + (this_dir / "MANIFEST.in").unlink(True) if __name__ == "__main__": @@ -77,17 +83,18 @@ def ln(src: str, dst: str, is_dir: bool = False) -> None: # Package 'flashinfer.data*' is absent from the `packages` configuration. warnings.filterwarnings("ignore", r".*flashinfer\.data.*", UserWarning) - setuptools.setup( - name="flashinfer", - version=get_version(), - packages=setuptools.find_packages( - include=["flashinfer*"], - exclude=["flashinfer.data*"], - ), - include_package_data=True, - author="FlashInfer team", - license="Apache License 2.0", - description="FlashInfer: Kernel Library for LLM Serving", - url="https://github.com/flashinfer-ai/flashinfer", - python_requires=">=3.8", - ) + with link_data_files(): + setuptools.setup( + name="flashinfer", + version=get_version(), + packages=setuptools.find_packages( + include=["flashinfer*"], + exclude=["flashinfer.data*"], + ), + include_package_data=True, + author="FlashInfer team", + license="Apache License 2.0", + description="FlashInfer: Kernel Library for LLM Serving", + url="https://github.com/flashinfer-ai/flashinfer", + python_requires=">=3.8", + )