Skip to content

Commit

Permalink
fix broken cpp integration caused by flashinfer-ai#567
Browse files Browse the repository at this point in the history
  • Loading branch information
tsu-bin committed Oct 30, 2024
1 parent 7df90dd commit d5a585c
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 54 deletions.
44 changes: 22 additions & 22 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ file(MAKE_DIRECTORY ${PROJECT_SOURCE_DIR}/src/generated)
set(dispatch_inc_file ${PROJECT_SOURCE_DIR}/src/dispatch.inc)
add_custom_command(
OUTPUT ${dispatch_inc_file}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES}
DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_dispatch_inc.py
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES}
DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_dispatch_inc.py
COMMENT "Generating additional source file ${generated_dispatch_inc}"
VERBATIM
)
Expand All @@ -124,8 +124,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_head_${head_dim}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_decode_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_decode_inst.py
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_single_decode_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_single_decode_inst.py
COMMENT "Generating additional source file ${generated_kernel_src}"
VERBATIM
)
Expand All @@ -137,8 +137,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_head_${head_dim}_posenc_${pos_encoding_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_decode_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_decode_inst.py
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_single_decode_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_single_decode_inst.py
COMMENT "Generating additional source file ${generated_kernel_src}"
VERBATIM
)
Expand All @@ -156,8 +156,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_head_${head_dim}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_decode_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_decode_inst.py
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_paged_decode_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_paged_decode_inst.py
COMMENT "Generating additional source file ${generated_kernel_src}"
VERBATIM
)
Expand All @@ -169,8 +169,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_head_${head_dim}_posenc_${pos_encoding_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_decode_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_decode_inst.py
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_paged_decode_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_paged_decode_inst.py
COMMENT "Generating additional source file ${generated_kernel_src}"
VERBATIM
)
Expand All @@ -193,8 +193,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_prefill_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_prefill_inst.py
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_single_prefill_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_single_prefill_inst.py
COMMENT "Generating additional source file ${generated_kernel_src}"
VERBATIM
)
Expand All @@ -205,8 +205,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_prefill_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_prefill_inst.py
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_single_prefill_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_single_prefill_inst.py
COMMENT "Generating additional source file ${generated_kernel_src}"
VERBATIM
)
Expand All @@ -227,8 +227,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_prefill_inst.py
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_paged_prefill_inst.py
COMMENT "Generating additional source file ${generated_kernel_src}"
VERBATIM
)
Expand All @@ -239,8 +239,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_prefill_inst.py
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_paged_prefill_inst.py
COMMENT "Generating additional source file ${generated_kernel_src}"
VERBATIM
)
Expand All @@ -262,8 +262,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_ragged_prefill_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_ragged_prefill_inst.py
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_ragged_prefill_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_ragged_prefill_inst.py
COMMENT "Generating additional source file ${generated_kernel_src}"
VERBATIM
)
Expand All @@ -274,8 +274,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_ragged_prefill_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_ragged_prefill_inst.py
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_ragged_prefill_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_ragged_prefill_inst.py
COMMENT "Generating additional source file ${generated_kernel_src}"
VERBATIM
)
Expand Down
18 changes: 13 additions & 5 deletions python/_aot_build_utils/generate_batch_paged_decode_inst.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,19 @@
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
idtype_literal,
pos_encoding_mode_literal,
)
if __package__:
from .literal_map import (
dtype_literal,
idtype_literal,
pos_encoding_mode_literal,
)
else:
sys.path.append(str(Path(__file__).resolve().parents[1] / "_aot_build_utils"))
from literal_map import (
dtype_literal,
idtype_literal,
pos_encoding_mode_literal,
)


def get_cu_file_str(
Expand Down
21 changes: 15 additions & 6 deletions python/_aot_build_utils/generate_batch_paged_prefill_inst.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,21 @@
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
idtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)
if __package__:
from .literal_map import (
dtype_literal,
idtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)
else:
sys.path.append(str(Path(__file__).resolve().parents[1] / "_aot_build_utils"))
from literal_map import (
dtype_literal,
idtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)


def get_cu_file_str(
Expand Down
21 changes: 15 additions & 6 deletions python/_aot_build_utils/generate_batch_ragged_prefill_inst.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,21 @@
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
idtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)
if __package__:
from .literal_map import (
dtype_literal,
idtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)
else:
sys.path.append(str(Path(__file__).resolve().parents[1] / "_aot_build_utils"))
from literal_map import (
dtype_literal,
idtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)


def get_cu_file_str(
Expand Down
19 changes: 14 additions & 5 deletions python/_aot_build_utils/generate_dispatch_inc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,23 @@
limitations under the License.
"""

import sys
import argparse
from pathlib import Path

from .literal_map import (
bool_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)
if __package__:
from .literal_map import (
bool_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)
else:
sys.path.append(str(Path(__file__).resolve().parents[1] / "_aot_build_utils"))
from literal_map import (
bool_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)


def get_dispatch_inc_str(args: argparse.Namespace) -> str:
Expand Down
16 changes: 11 additions & 5 deletions python/_aot_build_utils/generate_single_decode_inst.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,17 @@
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
pos_encoding_mode_literal,
)

if __package__:
from .literal_map import (
dtype_literal,
pos_encoding_mode_literal,
)
else:
sys.path.append(str(Path(__file__).resolve().parents[1] / "_aot_build_utils"))
from literal_map import (
dtype_literal,
pos_encoding_mode_literal,
)

def get_cu_file_str(
head_dim,
Expand Down
18 changes: 13 additions & 5 deletions python/_aot_build_utils/generate_single_prefill_inst.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,19 @@
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)
if __package__:
from .literal_map import (
dtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)
else:
sys.path.append(str(Path(__file__).resolve().parents[1] / "_aot_build_utils"))
from literal_map import (
dtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)


def get_cu_file_str(
Expand Down

0 comments on commit d5a585c

Please sign in to comment.