Skip to content

Commit

Permalink
bugfix: fix prefill kernel uris for aot compilation (#624)
Browse files Browse the repository at this point in the history
The mask mode is no longer part of the kernel URIs. This PR removes it from the prefill URIs generated
for AOT compilation; otherwise, our AOT wheels would still trigger JIT compilation for prefill kernels.
  • Loading branch information
yzh119 authored Nov 21, 2024
1 parent 3ed9a8b commit ba1d8c3
Showing 1 changed file with 27 additions and 23 deletions.
50 changes: 27 additions & 23 deletions python/aot_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,17 +201,19 @@ def get_instantiation_cu() -> Tuple[List[str], List[str], List[str]]:
)
for use_sliding_window in [True, False]:
for use_logits_soft_cap in [True, False]:
single_prefill_uris.append(
f"single_prefill_with_kv_cache_dtype_q_{dtype_q}_"
f"dtype_kv_{dtype_kv}_"
f"dtype_o_{dtype_q}_"
f"head_dim_{head_dim}_"
f"posenc_{pos_encoding_mode}_"
f"mask_{mask_mode}_"
f"use_swa_{use_sliding_window}_"
f"use_logits_cap_{use_logits_soft_cap}_"
f"f16qk_{bool(allow_fp16_qk_reduction)}"
)
if (
mask_mode == 0
): # NOTE(Zihao): uri do not contain mask, avoid duplicate uris
single_prefill_uris.append(
f"single_prefill_with_kv_cache_dtype_q_{dtype_q}_"
f"dtype_kv_{dtype_kv}_"
f"dtype_o_{dtype_q}_"
f"head_dim_{head_dim}_"
f"posenc_{pos_encoding_mode}_"
f"use_swa_{use_sliding_window}_"
f"use_logits_cap_{use_logits_soft_cap}_"
f"f16qk_{bool(allow_fp16_qk_reduction)}"
)
write_if_different(path / fname, content)

# batch prefill files
Expand Down Expand Up @@ -262,18 +264,20 @@ def get_instantiation_cu() -> Tuple[List[str], List[str], List[str]]:

for sliding_window in [True, False]:
for logits_soft_cap in [True, False]:
batch_prefill_uris.append(
f"batch_prefill_with_kv_cache_dtype_q_{dtype_q}_"
f"dtype_kv_{dtype_kv}_"
f"dtype_o_{dtype_q}_"
f"dtype_idx_{idtype}_"
f"head_dim_{head_dim}_"
f"posenc_{pos_encoding_mode}_"
f"mask_{mask_mode}_"
f"use_swa_{sliding_window}_"
f"use_logits_cap_{logits_soft_cap}_"
f"f16qk_{bool(allow_fp16_qk_reduction)}"
)
if (
mask_mode == 0
): # NOTE(Zihao): uri do not contain mask, avoid duplicate uris
batch_prefill_uris.append(
f"batch_prefill_with_kv_cache_dtype_q_{dtype_q}_"
f"dtype_kv_{dtype_kv}_"
f"dtype_o_{dtype_q}_"
f"dtype_idx_{idtype}_"
f"head_dim_{head_dim}_"
f"posenc_{pos_encoding_mode}_"
f"use_swa_{sliding_window}_"
f"use_logits_cap_{logits_soft_cap}_"
f"f16qk_{bool(allow_fp16_qk_reduction)}"
)

# Change to relative path
this_dir = pathlib.Path(__file__).parent.resolve()
Expand Down

0 comments on commit ba1d8c3

Please sign in to comment.