1 parent afe1767 commit e417997
vllm_ascend/quantization/w8a8_dynamic.py
@@ -285,8 +285,9 @@ def fused_experts(hidden_states: torch.Tensor,
         valid_token_mask = torch.arange(
             0, sorted_token_indices.shape[0],
             device=device).unsqueeze(1) < num_valid_tokens
-        down_out_list.mul_(valid_token_mask)
-        final_hidden_states.index_add_(0, sorted_token_indices, down_out_list)
+        valid_output = torch.where(valid_token_mask, down_out_list,
+                                   torch.zeros_like(down_out_list)).to(dtype)
+        final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
     else:
         # TODO: Reorder device memory 2 times here, replace the current
         # implementation here when suitable operators become available.
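The masking pattern introduced by this change can be reproduced in isolation as a minimal sketch, shown below. The shapes, dtype, device, and sample tensors are illustrative assumptions, not the actual fused_experts inputs; the point is that torch.where zeroes the padded rows and the .to(dtype) cast makes the source dtype match final_hidden_states before index_add_, instead of mutating down_out_list in place with mul_.

# Minimal standalone sketch of the masking pattern above (hypothetical
# shapes and values; not the real fused_experts code path).
import torch

dtype = torch.float16
device = "cpu"
hidden_size = 4
num_valid_tokens = 3  # only the first 3 sorted rows are real tokens

# Expert outputs for 5 sorted rows; the last 2 are padding and must not
# contribute to the final result.
down_out_list = torch.randn(5, hidden_size, dtype=torch.float32, device=device)
sorted_token_indices = torch.tensor([0, 2, 1, 0, 0], device=device)
final_hidden_states = torch.zeros(3, hidden_size, dtype=dtype, device=device)

# (5, 1) boolean column mask that is True only for the valid rows.
valid_token_mask = torch.arange(
    0, sorted_token_indices.shape[0],
    device=device).unsqueeze(1) < num_valid_tokens

# Zero the padded rows and cast to the accumulation dtype, so the
# scatter-add receives a tensor whose dtype matches final_hidden_states.
valid_output = torch.where(valid_token_mask, down_out_list,
                           torch.zeros_like(down_out_list)).to(dtype)
final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
print(final_hidden_states)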