@@ -82,18 +82,22 @@ class DPMetadata:
     max_tokens_across_dp_cpu: torch.Tensor
     num_tokens_across_dp_cpu: torch.Tensor
 
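+    # Buffers for hidden states and router logits gathered from every DP
+    # rank; allocated in make() on HPU, left as None on other platforms.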
+    hidden_states_across_dp: Optional[torch.Tensor] = None
+    router_logits_across_dp: Optional[torch.Tensor] = None
+    local_hidden_states: Optional[torch.Tensor] = None
+
8589 # NOTE: local_sizes should only be set by the chunked_sizes context manager
8690 local_sizes : Optional [list [int ]] = None
8791
     @staticmethod
     def make(
-        parallel_config: ParallelConfig,
+        vllm_config: VllmConfig,
         num_tokens: int,
         num_tokens_across_dp_cpu: torch.Tensor,
     ) -> "DPMetadata":
         assert num_tokens_across_dp_cpu is not None
-        assert parallel_config.data_parallel_size > 1
-        dp_rank = parallel_config.data_parallel_rank
+        assert vllm_config.parallel_config.data_parallel_size > 1
+        dp_rank = vllm_config.parallel_config.data_parallel_rank
         batchsize = num_tokens
 
         # If num_tokens_across_dp is None, it will be computed by all_reduce
@@ -102,7 +106,48 @@ def make(
             f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
         )
         max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
-        return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)
+
+        hidden_size = vllm_config.model_config.get_hidden_size()
+        dp_size = vllm_config.parallel_config.data_parallel_size
+        tp_size = vllm_config.parallel_config.tensor_parallel_size
+
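+        # Capacity for num_tokens from each of the dp_size ranks
+        # (assumes per-rank batches are padded to a common size).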
+        num_tokens_across_dp = num_tokens * dp_size
+
+        dtype = vllm_config.model_config.dtype
+        from vllm.platforms import current_platform
+        device = current_platform.device_type
+
+        hidden_states_across_dp = None
+        router_logits_across_dp = None
+        local_hidden_states = None
+        if device == "hpu":
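+            # The expert count goes by different names across architectures;
+            # take the first positive match from the HF text config.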
+            num_expert_names = [
+                "moe_num_experts",    # Dbrx
+                "num_experts",        # Jamba
+                "n_routed_experts",   # DeepSeek
+                "num_local_experts",  # Mixtral
+            ]
+            num_experts = 0
+            for name in num_expert_names:
+                num_experts = getattr(
+                    vllm_config.model_config.hf_text_config, name, 0)
+                if num_experts > 0:
+                    break
+            assert num_experts > 0, \
+                "No expert found in the model config. Please check the model config."
+
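+            # Destination buffers sized for the worst-case token count
+            # gathered across all DP ranks.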
+            hidden_states_across_dp = torch.empty(
+                (num_tokens_across_dp, hidden_size),
+                dtype=dtype,
+                device=device,
+            )
+            router_logits_across_dp = torch.empty(
+                (num_tokens_across_dp, num_experts),
+                dtype=dtype,
+                device=device,
+            )
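+            # With sequence-parallel MoE, each TP rank keeps only its slice
+            # of the local batch.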
+            local_num_tokens = (
+                num_tokens // tp_size
+                if vllm_config.parallel_config.use_sequence_parallel_moe
+                else num_tokens
+            )
+            local_hidden_states = torch.empty(
+                (local_num_tokens, hidden_size), dtype=dtype, device=device
+            )
+
+        return DPMetadata(
+            max_tokens_across_dp_cpu,
+            num_tokens_across_dp_cpu,
+            hidden_states_across_dp,
+            router_logits_across_dp,
+            local_hidden_states,
+        )
 
     @contextmanager
     def chunked_sizes(
@@ -269,7 +314,7 @@ def set_forward_context(
     ):
         assert num_tokens_across_dp is not None
         dp_metadata = DPMetadata.make(
-            vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
+            vllm_config, num_tokens or 0, num_tokens_across_dp
         )
 
     forward_context = create_forward_context(