diff --git a/python/aot_setup.py b/python/aot_setup.py index 3f24eec1..e9c73159 100644 --- a/python/aot_setup.py +++ b/python/aot_setup.py @@ -201,17 +201,19 @@ def get_instantiation_cu() -> Tuple[List[str], List[str], List[str]]: ) for use_sliding_window in [True, False]: for use_logits_soft_cap in [True, False]: - single_prefill_uris.append( - f"single_prefill_with_kv_cache_dtype_q_{dtype_q}_" - f"dtype_kv_{dtype_kv}_" - f"dtype_o_{dtype_q}_" - f"head_dim_{head_dim}_" - f"posenc_{pos_encoding_mode}_" - f"mask_{mask_mode}_" - f"use_swa_{use_sliding_window}_" - f"use_logits_cap_{use_logits_soft_cap}_" - f"f16qk_{bool(allow_fp16_qk_reduction)}" - ) + if ( + mask_mode == 0 + ): # NOTE(Zihao): uri do not contain mask, avoid duplicate uris + single_prefill_uris.append( + f"single_prefill_with_kv_cache_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_q}_" + f"head_dim_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"use_swa_{use_sliding_window}_" + f"use_logits_cap_{use_logits_soft_cap}_" + f"f16qk_{bool(allow_fp16_qk_reduction)}" + ) write_if_different(path / fname, content) # batch prefill files @@ -262,18 +264,20 @@ def get_instantiation_cu() -> Tuple[List[str], List[str], List[str]]: for sliding_window in [True, False]: for logits_soft_cap in [True, False]: - batch_prefill_uris.append( - f"batch_prefill_with_kv_cache_dtype_q_{dtype_q}_" - f"dtype_kv_{dtype_kv}_" - f"dtype_o_{dtype_q}_" - f"dtype_idx_{idtype}_" - f"head_dim_{head_dim}_" - f"posenc_{pos_encoding_mode}_" - f"mask_{mask_mode}_" - f"use_swa_{sliding_window}_" - f"use_logits_cap_{logits_soft_cap}_" - f"f16qk_{bool(allow_fp16_qk_reduction)}" - ) + if ( + mask_mode == 0 + ): # NOTE(Zihao): uri do not contain mask, avoid duplicate uris + batch_prefill_uris.append( + f"batch_prefill_with_kv_cache_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_q}_" + f"dtype_idx_{idtype}_" + f"head_dim_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"use_swa_{sliding_window}_" + f"use_logits_cap_{logits_soft_cap}_" + f"f16qk_{bool(allow_fp16_qk_reduction)}" + ) # Change to relative path this_dir = pathlib.Path(__file__).parent.resolve()