Commit abf7ec4

allocate DP allgather tensor in forward context (vllm-project#1)
Signed-off-by: Wuxun Zhang <wuxun.zhang@intel.com>
1 parent ee74be3 commit abf7ec4

File tree

1 file changed: +50 -5 lines changed

vllm/forward_context.py

Lines changed: 50 additions & 5 deletions
@@ -82,18 +82,22 @@ class DPMetadata:
     max_tokens_across_dp_cpu: torch.Tensor
     num_tokens_across_dp_cpu: torch.Tensor
 
+    hidden_states_across_dp: torch.Tensor
+    router_logits_across_dp: torch.Tensor
+    local_hidden_states: torch.Tensor
+
     # NOTE: local_sizes should only be set by the chunked_sizes context manager
     local_sizes: Optional[list[int]] = None
 
     @staticmethod
     def make(
-        parallel_config: ParallelConfig,
+        vllm_config: VllmConfig,
         num_tokens: int,
         num_tokens_across_dp_cpu: torch.Tensor,
     ) -> "DPMetadata":
         assert num_tokens_across_dp_cpu is not None
-        assert parallel_config.data_parallel_size > 1
-        dp_rank = parallel_config.data_parallel_rank
+        assert vllm_config.parallel_config.data_parallel_size > 1
+        dp_rank = vllm_config.parallel_config.data_parallel_rank
         batchsize = num_tokens
 
         # If num_tokens_across_dp is None, it will be computed by all_reduce
@@ -102,7 +106,48 @@ def make(
             f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
         )
         max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
-        return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)
+
+        hidden_size = vllm_config.model_config.get_hidden_size()
+        dp_size = vllm_config.parallel_config.data_parallel_size
+        tp_size = vllm_config.parallel_config.tensor_parallel_size
+
+        num_tokens_across_dp = num_tokens * dp_size
+
+        dtype = vllm_config.model_config.dtype
+        from vllm.platforms import current_platform
+        device = current_platform.device_type
+
+        if device == "hpu":
+            num_expert_names = [
+                "moe_num_experts",  # Dbrx
+                "num_experts",  # Jamba
+                "n_routed_experts",  # DeepSeek
+                "num_local_experts",  # Mixtral
+            ]
+            num_experts = 0
+            for name in num_expert_names:
+                num_experts = getattr(vllm_config.model_config.hf_text_config, name, 0)
+                if num_experts > 0:
+                    break
+            assert num_experts > 0, \
+                "No expert found in the model config. Please check the model config."
+
+        hidden_states_across_dp = torch.empty(
+            (num_tokens_across_dp, hidden_size),
+            dtype=dtype,
+            device=device,
+        )
+        router_logits_across_dp = torch.empty(
+            (num_tokens_across_dp, num_experts),
+            dtype=dtype,
+            device=device,
+        )
+        local_num_tokens = (num_tokens // tp_size) if vllm_config.parallel_config.use_sequence_parallel_moe else num_tokens
+        local_hidden_states = torch.empty(
+            (local_num_tokens, hidden_size), dtype=dtype, device=device
+        )
+
+        return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu, hidden_states_across_dp, router_logits_across_dp, local_hidden_states)
 
     @contextmanager
     def chunked_sizes(
@@ -269,7 +314,7 @@ def set_forward_context(
     ):
         assert num_tokens_across_dp is not None
         dp_metadata = DPMetadata.make(
-            vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
+            vllm_config, num_tokens or 0, num_tokens_across_dp
        )
 
    forward_context = create_forward_context(
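
For context, a minimal sketch of how preallocated DP all-gather buffers like these are typically filled before the MoE layers run. The function name gather_hidden_states_across_dp and the dp_group argument are illustrative assumptions, not part of this commit; torch.distributed.all_gather_into_tensor is the standard PyTorch collective used for this pattern.

import torch
import torch.distributed as dist

def gather_hidden_states_across_dp(
    local_hidden_states: torch.Tensor,
    hidden_states_across_dp: torch.Tensor,
    dp_group: dist.ProcessGroup,
) -> torch.Tensor:
    # Illustrative sketch: assumes every DP rank contributes the same number
    # of tokens, so the gather writes into the (num_tokens * dp_size,
    # hidden_size) buffer allocated once in DPMetadata.make() instead of
    # allocating a fresh output tensor on every forward pass.
    dist.all_gather_into_tensor(
        hidden_states_across_dp, local_hidden_states, group=dp_group
    )
    return hidden_states_across_dp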
