
Commit 6b909fc

Move imports out of line to fix compilation
Signed-off-by: Bram Wasti <bwasti@meta.com>
1 parent 9ee0b21 commit 6b909fc
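
The same fix is applied in all three files below: the function-local import of vllm_kernel_override_batch_invariant is hoisted to module level. Schematically (a simplified sketch based on the first file's diff, not the verbatim source):

# Before: import inside the function body
def should_nccl_symm_mem_allreduce(world_size, input_tensor):
    from vllm.model_executor.layers.batch_invariant import (
        vllm_kernel_override_batch_invariant,
    )
    if vllm_kernel_override_batch_invariant():
        return False
    ...

# After: import once at module scope, as in this commit
from vllm.model_executor.layers.batch_invariant import (
    vllm_kernel_override_batch_invariant,
)

def should_nccl_symm_mem_allreduce(world_size, input_tensor):
    if vllm_kernel_override_batch_invariant():
        return False
    ...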

File tree

3 files changed (+8, -15 lines)

vllm/distributed/device_communicators/all_reduce_utils.py
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/v1/attention/backends/mla/triton_mla.py

vllm/distributed/device_communicators/all_reduce_utils.py

Lines changed: 3 additions & 3 deletions
@@ -19,6 +19,9 @@
 import vllm.envs as envs
 from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.utils import cuda_device_count_stateless, update_environment_variables
 
 logger = init_logger(__name__)
@@ -70,9 +73,6 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
     from vllm.distributed.device_communicators.pynccl_allocator import (
         is_symmetric_memory_enabled,
     )
-    from vllm.model_executor.layers.batch_invariant import (
-        vllm_kernel_override_batch_invariant,
-    )
 
     if vllm_kernel_override_batch_invariant():
         return False

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 2 additions & 8 deletions
@@ -1134,11 +1134,8 @@ def fused_topk_bias(
     scores_for_choice = scores.view(
         -1, n_routed_experts
     ) + e_score_correction_bias.unsqueeze(0)
-    # For batch invariance, use sorted=True to ensure deterministic expert selection
-    from vllm.model_executor.layers.batch_invariant import (
-        vllm_kernel_override_batch_invariant,
-    )
 
+    # For batch invariance, use sorted=True to ensure deterministic expert selection
     use_sorted = vllm_kernel_override_batch_invariant()
     topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
     topk_weights = scores.gather(1, topk_indices)
@@ -1201,11 +1198,8 @@ def grouped_topk(
     group_scores = (
         scores.view(num_token, num_expert_group, -1).max(dim=-1).values
     )  # [n, n_group]
-    # For batch invariance, use sorted=True to ensure deterministic expert selection
-    from vllm.model_executor.layers.batch_invariant import (
-        vllm_kernel_override_batch_invariant,
-    )
 
+    # For batch invariance, use sorted=True to ensure deterministic expert selection
     use_sorted = vllm_kernel_override_batch_invariant()
     group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
         1
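
For context on the use_sorted flag above: torch.topk's sorted argument controls whether the top-k results come back in descending-value order. A minimal standalone illustration in plain PyTorch, independent of the diff:

import torch

scores = torch.tensor([[0.2, 0.9, 0.7, 0.1]])

# sorted=True: results are returned in descending-score order, giving one
# fixed, reproducible ordering of the selected expert indices.
idx_sorted = torch.topk(scores, k=2, dim=-1, sorted=True)[1]   # tensor([[1, 2]])

# sorted=False: the same top-k set may come back in any order, which is
# the variability the batch-invariant path avoids.
idx_any = torch.topk(scores, k=2, dim=-1, sorted=False)[1]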

vllm/v1/attention/backends/mla/triton_mla.py

Lines changed: 3 additions & 4 deletions
@@ -13,6 +13,9 @@
 from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
 from vllm.attention.ops.triton_flash_attention import triton_attention
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.platforms import current_platform
 from vllm.triton_utils import HAS_TRITON
 from vllm.v1.attention.backends.mla.common import (
@@ -159,10 +162,6 @@ def _forward_decode(
         )
         lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
 
-        from vllm.model_executor.layers.batch_invariant import (
-            vllm_kernel_override_batch_invariant,
-        )
-
         # For batch invariance, use only 1 split to ensure deterministic reduction
         num_kv_splits = 1 if vllm_kernel_override_batch_invariant() else 4
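
The num_kv_splits comment above refers to split reductions: partial results computed over separate KV splits are combined afterwards, and changing the number of splits changes the floating-point accumulation order. A small standalone illustration of that effect (not vLLM code):

import torch

x = torch.randn(10_000, dtype=torch.float32)

# One pass over the data: a single accumulation order.
single = x.sum()

# Split reduction: sum each chunk, then sum the partial results. A different
# split count gives a different accumulation order, so the result can differ
# in the last few bits, which is why the batch-invariant path forces 1 split.
split = torch.stack([chunk.sum() for chunk in x.chunk(4)]).sum()

print(torch.equal(single, split))  # may be False at float32 precision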
