Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Refactor JIT and AOT build script #567

Merged
merged 9 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ src/generated/
python/csrc/generated/
python/flashinfer/_build_meta.py
python/flashinfer/jit/aot_config.py
flashinfer-aot/csrc_aot/generated/
python/csrc_aot/generated/

# Package files
python/flashinfer/data/
python/flashinfer/version.txt
python/MANIFEST.in

# Generated documentation files
docs/generated
Expand Down
15 changes: 12 additions & 3 deletions docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ You can follow the steps below to install FlashInfer from source code:

pip install ninja

4. Compile FlashInfer:
4. Install FlashInfer:

.. tabs::

Expand All @@ -153,8 +153,17 @@ You can follow the steps below to install FlashInfer from source code:

.. code-block:: bash

cd flashinfer/flashinfer-aot
pip install -e . -v
cd flashinfer/python
TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" python3 aot_setup.py bdist_wheel
pip install dist/flashinfer-*.whl

.. tab:: Create sdist for JIT mode

.. code-block:: bash

cd flashinfer/python
python -m build --sdist
ls -la dist/

C++ API
-------
Expand Down
1 change: 0 additions & 1 deletion flashinfer-aot/3rdparty

This file was deleted.

12 changes: 0 additions & 12 deletions flashinfer-aot/MANIFEST.in

This file was deleted.

1 change: 0 additions & 1 deletion flashinfer-aot/csrc

This file was deleted.

45 changes: 0 additions & 45 deletions flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu

This file was deleted.

56 changes: 0 additions & 56 deletions flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu

This file was deleted.

1 change: 0 additions & 1 deletion flashinfer-aot/flashinfer

This file was deleted.

1 change: 0 additions & 1 deletion flashinfer-aot/include

This file was deleted.

1 change: 0 additions & 1 deletion flashinfer-aot/version.txt

This file was deleted.

14 changes: 7 additions & 7 deletions include/flashinfer/attention/scheduler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
* the new batch size after the partition.
*/
template <typename IdType>
auto PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(
inline auto PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(
const uint32_t max_grid_size, const uint32_t num_kv_heads, const std::vector<IdType>& num_pages,
const uint32_t min_num_pages_per_batch = 1) {
uint32_t low = min_num_pages_per_batch, high = 0;
Expand All @@ -77,7 +77,7 @@ auto PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(
return std::make_tuple(low, new_batch_size);
}

auto PrefillBinarySearchKVChunkSize(const uint32_t max_batch_size_if_split,
inline auto PrefillBinarySearchKVChunkSize(const uint32_t max_batch_size_if_split,
const std::vector<int64_t>& packed_qo_len_arr,
const std::vector<int64_t>& kv_len_arr,
const uint32_t qo_chunk_size,
Expand Down Expand Up @@ -129,7 +129,7 @@ auto PrefillBinarySearchKVChunkSize(const uint32_t max_batch_size_if_split,
*/
template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE,
typename AttentionVariant>
cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
bool& split_kv, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch,
uint32_t& new_batch_size, uint32_t batch_size, typename AttentionVariant::IdType* kv_indptr_h,
const uint32_t num_qo_heads, const uint32_t page_size, bool enable_cuda_graph,
Expand Down Expand Up @@ -201,7 +201,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
* \return status Indicates whether CUDA calls are successful
*/
template <typename IdType>
auto DecodeSplitKVIndptr(IdType* indptr_h, uint32_t batch_size, uint32_t kv_chunk_size) {
inline auto DecodeSplitKVIndptr(IdType* indptr_h, uint32_t batch_size, uint32_t kv_chunk_size) {
std::vector<IdType> request_indices, kv_tile_indices, o_indptr;
o_indptr.push_back(0);

Expand Down Expand Up @@ -277,7 +277,7 @@ struct DecodePlanInfo {
};

template <uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE, typename AttentionVariant>
cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer,
inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer,
void* page_locked_int_buffer, size_t int_workspace_size_in_bytes,
DecodePlanInfo& plan_info, typename AttentionVariant::IdType* indptr_h,
uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
Expand Down Expand Up @@ -350,7 +350,7 @@ cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in_bytes,
}

template <typename IdType>
auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
uint32_t page_size, uint32_t max_batch_size_if_split,
bool enable_cuda_graph) {
Expand Down Expand Up @@ -520,7 +520,7 @@ struct PrefillPlanInfo {
};

template <typename IdType>
cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer,
inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer,
void* page_locked_int_buffer, size_t int_workspace_size_in_bytes,
PrefillPlanInfo& plan_info, IdType* qo_indptr_h, IdType* kv_indptr_h,
uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
Expand Down
1 change: 0 additions & 1 deletion python/3rdparty

This file was deleted.

12 changes: 0 additions & 12 deletions python/MANIFEST.in

This file was deleted.

Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
limitations under the License.
"""

import sys
import re
from literal_map import (
pos_encoding_mode_literal,
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
idtype_literal,
pos_encoding_mode_literal,
)
from pathlib import Path


def get_cu_file_str(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@
limitations under the License.
"""

import sys
import re
import itertools
from literal_map import (
mask_mode_literal,
pos_encoding_mode_literal,
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
idtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)
from pathlib import Path


def get_cu_file_str(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@
limitations under the License.
"""

import sys
import re
from literal_map import (
mask_mode_literal,
pos_encoding_mode_literal,
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
idtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)
from pathlib import Path


def get_cu_file_str(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

import argparse
from pathlib import Path
from literal_map import (
pos_encoding_mode_literal,

from .literal_map import (
bool_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
limitations under the License.
"""

import sys
import re
from literal_map import (
pos_encoding_mode_literal,
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
pos_encoding_mode_literal,
)
from pathlib import Path


def get_cu_file_str(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
limitations under the License.
"""

import sys
import re
from literal_map import (
pos_encoding_mode_literal,
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)
from pathlib import Path


def get_cu_file_str(
Expand Down
File renamed without changes.
13 changes: 13 additions & 0 deletions python/aot_MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# MANIFEST.in for AOT wheel

prune */__pycache__
prune csrc
prune csrc_aot
exclude aot_setup.py
exclude setup.py

include flashinfer/data/version.txt
graft flashinfer/data/csrc
graft flashinfer/data/include
graft flashinfer/data/cutlass/include
graft flashinfer/data/cutlass/tools/util/include
Loading