Point the CMake AOT code-generation rules at the generator scripts under
python/_aot_build_utils/ (previously flashinfer-aot/) and let those scripts
run both as package modules and as standalone scripts.

* CMakeLists.txt: every add_custom_command() now invokes, and DEPENDS on, the
  generator under python/_aot_build_utils/. The dispatch.inc rule reuses
  ${dispatch_inc_file} for --path and COMMENT; COMMENT previously expanded
  ${generated_dispatch_inc}, which this rule never sets.
* python/_aot_build_utils/generate_*.py: keep the relative import when
  __package__ is set; otherwise put the script's own directory at the front
  of sys.path and import literal_map as a top-level module.

NOTE(review): leading whitespace on context lines was restored from a
whitespace-collapsed copy of this patch; confirm it against the target tree,
or apply with `git apply --ignore-whitespace`. The post-image blob ids on the
`index` lines predate the review fixes described above.

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 54c11644..abae8201 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -110,8 +110,8 @@ file(MAKE_DIRECTORY ${PROJECT_SOURCE_DIR}/src/generated)
 set(dispatch_inc_file ${PROJECT_SOURCE_DIR}/src/dispatch.inc)
 add_custom_command(
   OUTPUT ${dispatch_inc_file}
-  COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES}
-  DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_dispatch_inc.py
-  COMMENT "Generating additional source file ${generated_dispatch_inc}"
+  COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_dispatch_inc.py --path ${dispatch_inc_file} --head_dims ${HEAD_DIMS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES}
+  DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_dispatch_inc.py
+  COMMENT "Generating additional source file ${dispatch_inc_file}"
   VERBATIM
 )
@@ -124,8 +124,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
       set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_head_${head_dim}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu)
       add_custom_command(
         OUTPUT ${generated_kernel_src}
-        COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_decode_inst.py ${generated_kernel_src}
-        DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_decode_inst.py
+        COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_single_decode_inst.py ${generated_kernel_src}
+        DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_single_decode_inst.py
         COMMENT "Generating additional source file ${generated_kernel_src}"
         VERBATIM
       )
@@ -137,8 +137,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
       set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_head_${head_dim}_posenc_${pos_encoding_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16.cu)
       add_custom_command(
         OUTPUT ${generated_kernel_src}
-        COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_decode_inst.py ${generated_kernel_src}
-        DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_decode_inst.py
+        COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_single_decode_inst.py ${generated_kernel_src}
+        DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_single_decode_inst.py
         COMMENT "Generating additional source file ${generated_kernel_src}"
         VERBATIM
       )
@@ -156,8 +156,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
         set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_head_${head_dim}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
         add_custom_command(
           OUTPUT ${generated_kernel_src}
-          COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_decode_inst.py ${generated_kernel_src}
-          DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_decode_inst.py
+          COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_paged_decode_inst.py ${generated_kernel_src}
+          DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_paged_decode_inst.py
           COMMENT "Generating additional source file ${generated_kernel_src}"
           VERBATIM
         )
@@ -169,8 +169,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
         set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_head_${head_dim}_posenc_${pos_encoding_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu)
         add_custom_command(
           OUTPUT ${generated_kernel_src}
-          COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_decode_inst.py ${generated_kernel_src}
-          DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_decode_inst.py
+          COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_paged_decode_inst.py ${generated_kernel_src}
+          DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_paged_decode_inst.py
           COMMENT "Generating additional source file ${generated_kernel_src}"
           VERBATIM
         )
@@ -193,8 +193,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
           set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu)
           add_custom_command(
             OUTPUT ${generated_kernel_src}
-            COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_prefill_inst.py ${generated_kernel_src}
-            DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_prefill_inst.py
+            COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_single_prefill_inst.py ${generated_kernel_src}
+            DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_single_prefill_inst.py
             COMMENT "Generating additional source file ${generated_kernel_src}"
             VERBATIM
           )
@@ -205,8 +205,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
           set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16.cu)
           add_custom_command(
             OUTPUT ${generated_kernel_src}
-            COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_prefill_inst.py ${generated_kernel_src}
-            DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_prefill_inst.py
+            COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_single_prefill_inst.py ${generated_kernel_src}
+            DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_single_prefill_inst.py
             COMMENT "Generating additional source file ${generated_kernel_src}"
             VERBATIM
           )
@@ -227,8 +227,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
             set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
             add_custom_command(
               OUTPUT ${generated_kernel_src}
-              COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
-              DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_prefill_inst.py
+              COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
+              DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_paged_prefill_inst.py
               COMMENT "Generating additional source file ${generated_kernel_src}"
               VERBATIM
             )
@@ -239,8 +239,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
             set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu)
             add_custom_command(
               OUTPUT ${generated_kernel_src}
-              COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
-              DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_prefill_inst.py
+              COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
+              DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_paged_prefill_inst.py
               COMMENT "Generating additional source file ${generated_kernel_src}"
               VERBATIM
             )
@@ -262,8 +262,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
             set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
             add_custom_command(
               OUTPUT ${generated_kernel_src}
-              COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_ragged_prefill_inst.py ${generated_kernel_src}
-              DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_ragged_prefill_inst.py
+              COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_ragged_prefill_inst.py ${generated_kernel_src}
+              DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_ragged_prefill_inst.py
               COMMENT "Generating additional source file ${generated_kernel_src}"
               VERBATIM
             )
@@ -274,8 +274,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
             set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu)
             add_custom_command(
               OUTPUT ${generated_kernel_src}
-              COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_ragged_prefill_inst.py ${generated_kernel_src}
-              DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_ragged_prefill_inst.py
+              COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_ragged_prefill_inst.py ${generated_kernel_src}
+              DEPENDS ${PROJECT_SOURCE_DIR}/python/_aot_build_utils/generate_batch_ragged_prefill_inst.py
               COMMENT "Generating additional source file ${generated_kernel_src}"
               VERBATIM
             )
diff --git a/python/_aot_build_utils/generate_batch_paged_decode_inst.py b/python/_aot_build_utils/generate_batch_paged_decode_inst.py
index 7808c33b..5ed07e50 100644
--- a/python/_aot_build_utils/generate_batch_paged_decode_inst.py
+++ b/python/_aot_build_utils/generate_batch_paged_decode_inst.py
@@ -18,11 +18,21 @@
 import sys
 from pathlib import Path
 
-from .literal_map import (
-    dtype_literal,
-    idtype_literal,
-    pos_encoding_mode_literal,
-)
+if __package__:
+    from .literal_map import (
+        dtype_literal,
+        idtype_literal,
+        pos_encoding_mode_literal,
+    )
+else:
+    # Run as a standalone script (e.g. by the CMake build): there is no parent
+    # package, so load the sibling module from this script's own directory.
+    sys.path.insert(0, str(Path(__file__).resolve().parent))
+    from literal_map import (
+        dtype_literal,
+        idtype_literal,
+        pos_encoding_mode_literal,
+    )
 
 
 def get_cu_file_str(
diff --git a/python/_aot_build_utils/generate_batch_paged_prefill_inst.py b/python/_aot_build_utils/generate_batch_paged_prefill_inst.py
index 97f1423a..21fd8722 100644
--- a/python/_aot_build_utils/generate_batch_paged_prefill_inst.py
+++ b/python/_aot_build_utils/generate_batch_paged_prefill_inst.py
@@ -18,12 +18,23 @@
 import sys
 from pathlib import Path
 
-from .literal_map import (
-    dtype_literal,
-    idtype_literal,
-    mask_mode_literal,
-    pos_encoding_mode_literal,
-)
+if __package__:
+    from .literal_map import (
+        dtype_literal,
+        idtype_literal,
+        mask_mode_literal,
+        pos_encoding_mode_literal,
+    )
+else:
+    # Run as a standalone script (e.g. by the CMake build): there is no parent
+    # package, so load the sibling module from this script's own directory.
+    sys.path.insert(0, str(Path(__file__).resolve().parent))
+    from literal_map import (
+        dtype_literal,
+        idtype_literal,
+        mask_mode_literal,
+        pos_encoding_mode_literal,
+    )
 
 
 def get_cu_file_str(
diff --git a/python/_aot_build_utils/generate_batch_ragged_prefill_inst.py b/python/_aot_build_utils/generate_batch_ragged_prefill_inst.py
index f5631303..3ee9ed54 100644
--- a/python/_aot_build_utils/generate_batch_ragged_prefill_inst.py
+++ b/python/_aot_build_utils/generate_batch_ragged_prefill_inst.py
@@ -18,12 +18,23 @@
 import sys
 from pathlib import Path
 
-from .literal_map import (
-    dtype_literal,
-    idtype_literal,
-    mask_mode_literal,
-    pos_encoding_mode_literal,
-)
+if __package__:
+    from .literal_map import (
+        dtype_literal,
+        idtype_literal,
+        mask_mode_literal,
+        pos_encoding_mode_literal,
+    )
+else:
+    # Run as a standalone script (e.g. by the CMake build): there is no parent
+    # package, so load the sibling module from this script's own directory.
+    sys.path.insert(0, str(Path(__file__).resolve().parent))
+    from literal_map import (
+        dtype_literal,
+        idtype_literal,
+        mask_mode_literal,
+        pos_encoding_mode_literal,
+    )
 
 
 def get_cu_file_str(
diff --git a/python/_aot_build_utils/generate_dispatch_inc.py b/python/_aot_build_utils/generate_dispatch_inc.py
index 30552e6e..57d87e97 100644
--- a/python/_aot_build_utils/generate_dispatch_inc.py
+++ b/python/_aot_build_utils/generate_dispatch_inc.py
@@ -14,14 +14,25 @@
 limitations under the License.
 """
 
 import argparse
+import sys
 from pathlib import Path
 
-from .literal_map import (
-    bool_literal,
-    mask_mode_literal,
-    pos_encoding_mode_literal,
-)
+if __package__:
+    from .literal_map import (
+        bool_literal,
+        mask_mode_literal,
+        pos_encoding_mode_literal,
+    )
+else:
+    # Run as a standalone script (e.g. by the CMake build): there is no parent
+    # package, so load the sibling module from this script's own directory.
+    sys.path.insert(0, str(Path(__file__).resolve().parent))
+    from literal_map import (
+        bool_literal,
+        mask_mode_literal,
+        pos_encoding_mode_literal,
+    )
 
 
 def get_dispatch_inc_str(args: argparse.Namespace) -> str:
diff --git a/python/_aot_build_utils/generate_single_decode_inst.py b/python/_aot_build_utils/generate_single_decode_inst.py
index ce24d7e7..db01e640 100644
--- a/python/_aot_build_utils/generate_single_decode_inst.py
+++ b/python/_aot_build_utils/generate_single_decode_inst.py
@@ -18,11 +18,20 @@
 import sys
 from pathlib import Path
 
-from .literal_map import (
-    dtype_literal,
-    pos_encoding_mode_literal,
-)
+if __package__:
+    from .literal_map import (
+        dtype_literal,
+        pos_encoding_mode_literal,
+    )
+else:
+    # Run as a standalone script (e.g. by the CMake build): there is no parent
+    # package, so load the sibling module from this script's own directory.
+    sys.path.insert(0, str(Path(__file__).resolve().parent))
+    from literal_map import (
+        dtype_literal,
+        pos_encoding_mode_literal,
+    )
 
 
 def get_cu_file_str(
     head_dim,
diff --git a/python/_aot_build_utils/generate_single_prefill_inst.py b/python/_aot_build_utils/generate_single_prefill_inst.py
index 49eefd17..b840b34e 100644
--- a/python/_aot_build_utils/generate_single_prefill_inst.py
+++ b/python/_aot_build_utils/generate_single_prefill_inst.py
@@ -18,11 +18,21 @@
 import sys
 from pathlib import Path
 
-from .literal_map import (
-    dtype_literal,
-    mask_mode_literal,
-    pos_encoding_mode_literal,
-)
+if __package__:
+    from .literal_map import (
+        dtype_literal,
+        mask_mode_literal,
+        pos_encoding_mode_literal,
+    )
+else:
+    # Run as a standalone script (e.g. by the CMake build): there is no parent
+    # package, so load the sibling module from this script's own directory.
+    sys.path.insert(0, str(Path(__file__).resolve().parent))
+    from literal_map import (
+        dtype_literal,
+        mask_mode_literal,
+        pos_encoding_mode_literal,
+    )
 
 
 def get_cu_file_str(