Skip to content

Commit 45eba51

Browse files
committed
fix format issue
Signed-off-by: Yu Gong <yu3.gong@gmail.com>
1 parent 6dfda81 commit 45eba51

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

vllm/lora/punica_wrapper/punica_gpu.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -382,19 +382,19 @@ def add_lora_fused_moe(
382382
num_tokens_post_padded,
383383
max_lora_rank,
384384
top_k_num,
385-
shrink_config.get("BLOCK_SIZE_M", 64),
386-
shrink_config.get("BLOCK_SIZE_N", 64),
387-
shrink_config.get("BLOCK_SIZE_K", 32),
388-
shrink_config.get("GROUP_SIZE_M", 8),
389-
shrink_config.get("NUM_WARPS", 4),
390-
shrink_config.get("NUM_STAGES", 3),
391-
shrink_config.get("SPLIT_K", 1),
392-
expand_config.get("BLOCK_SIZE_M", 64),
393-
expand_config.get("BLOCK_SIZE_N", 64),
394-
expand_config.get("BLOCK_SIZE_K", 64),
395-
expand_config.get("GROUP_SIZE_M", 64),
396-
expand_config.get("NUM_WARPS", 4),
397-
expand_config.get("NUM_STAGES", 3),
398-
expand_config.get("SPLIT_K", 1),
385+
shrink_config.get("BLOCK_SIZE_M") or shrink_config.get("block_m") or 64,
386+
shrink_config.get("BLOCK_SIZE_N") or shrink_config.get("block_n") or 64,
387+
shrink_config.get("BLOCK_SIZE_K") or shrink_config.get("block_k") or 32,
388+
shrink_config.get("GROUP_SIZE_M") or shrink_config.get("group_m") or 8,
389+
shrink_config.get("NUM_WARPS") or shrink_config.get("num_warps") or 4,
390+
shrink_config.get("NUM_STAGES") or shrink_config.get("num_stages") or 3,
391+
shrink_config.get("SPLIT_K") or shrink_config.get("split_k") or 1,
392+
expand_config.get("BLOCK_SIZE_M") or expand_config.get("block_m") or 64,
393+
expand_config.get("BLOCK_SIZE_N") or expand_config.get("block_n") or 64,
394+
expand_config.get("BLOCK_SIZE_K") or expand_config.get("block_k") or 64,
395+
expand_config.get("GROUP_SIZE_M") or expand_config.get("group_m") or 64,
396+
expand_config.get("NUM_WARPS") or expand_config.get("num_warps") or 4,
397+
expand_config.get("NUM_STAGES") or expand_config.get("num_stages") or 3,
398+
expand_config.get("SPLIT_K") or expand_config.get("split_k") or 1,
399399
mul_routed_weight,
400400
)

0 commit comments

Comments
 (0)