
Commit af0f1a6

Static fused moe op (vllm-project#41)
* Fix mixtral hidden states layout to fit into habana model runner
* Add static moe op to mixtral
* Add mark_step to static_fused_moe
* Update __init__.py
* Fix code indentation
* Make code compatible with non HPU devices
* Move static_fused_moe to vllm.hpu.ops
* Update mixtral.py
* Move op import from forward to top of the file
* Remove circular import
1 parent 3c827b3 commit af0f1a6

2 files changed: +66, -15 lines changed


vllm/hpu/ops.py

Lines changed: 36 additions & 0 deletions
@@ -113,3 +113,39 @@ def apply_rope(
 
 def awq_gemm(*args):
     raise NotImplementedError
+
+
+def silu_and_mul_wrapper(x: torch.Tensor) -> torch.Tensor:
+    d = x.shape[-1] // 2
+    output_shape = (x.shape[:-1] + (d, ))
+    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+    silu_and_mul(out, x)
+    return out
+
+
+@hpu_utils.with_mark_steps
+def static_fused_moe(hidden_states, w1, w2, score, topk):
+    B, D = hidden_states.shape
+    num_experts = w1.shape[0]
+    routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
+    routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
+    routing_weights = routing_weights.to(hidden_states.dtype)
+    final_hidden_states = torch.zeros(
+        (1, B, D), dtype=hidden_states.dtype, device=hidden_states.device
+    )
+    padded_weights = torch.zeros(
+        (B, num_experts), dtype=hidden_states.dtype, device=hidden_states.device
+    )
+    padded_weights.scatter_(-1, selected_experts, routing_weights)
+    padded_weights = padded_weights.reshape(-1, B, w1.shape[0])
+    padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
+
+    for expert_idx in range(num_experts):
+        padded_weight = padded_weights[expert_idx]
+        current_state_static = hidden_states.reshape(-1, D)
+        w_output = silu_and_mul_wrapper(torch.matmul(current_state_static, w1[expert_idx].transpose(0, 1)))
+        w_output = torch.matmul(w_output, w2[expert_idx].transpose(0, 1))
+        current_hidden_states_static = w_output * padded_weight
+        final_hidden_states += current_hidden_states_static
+
+    return final_hidden_states.view(-1, D)
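
For context, a minimal plain-PyTorch sketch of the same routing math is shown below. It is illustrative only and not part of the commit: reference_moe is a hypothetical name, the sizes (B, D, I, E) are made up, and it skips the per-expert weight padding and the with_mark_steps decorator used by the op above. Up to floating-point differences it should match static_fused_moe's output shape and values.

# Illustrative reference only (not the HPU op): per-token top-k MoE in plain PyTorch.
import torch
import torch.nn.functional as F

def reference_moe(hidden_states, w1, w2, score, topk):
    # hidden_states: (B, D); w1: (E, 2*I, D); w2: (E, D, I); score: (B, E)
    B, D = hidden_states.shape
    routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
    routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
    routing_weights = routing_weights.to(hidden_states.dtype)

    out = torch.zeros_like(hidden_states)
    for token in range(B):
        for k in range(topk):
            e = selected_experts[token, k]
            gate_up = hidden_states[token] @ w1[e].t()      # (2*I,)
            d = gate_up.shape[-1] // 2
            act = F.silu(gate_up[:d]) * gate_up[d:]         # SiLU-and-mul
            out[token] += routing_weights[token, k] * (act @ w2[e].t())
    return out

B, D, I, E, topk = 4, 8, 16, 4, 2   # arbitrary example sizes
x = torch.randn(B, D)
w1 = torch.randn(E, 2 * I, D)
w2 = torch.randn(E, D, I)
score = torch.randn(B, E)
print(reference_moe(x, w1, w2, score, topk).shape)  # torch.Size([4, 8])

Unlike this sketch, the fused op avoids per-token Python control flow: every token is pushed through every expert and non-selected contributions are zeroed via the padded weights, which keeps shapes static across steps.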

vllm/model_executor/models/mixtral.py

Lines changed: 30 additions & 15 deletions
@@ -50,7 +50,10 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import SamplerOutput
-from vllm.utils import print_warning_once
+from vllm.utils import print_warning_once, is_hpu
+
+if is_hpu():
+    from vllm.hpu.ops import static_fused_moe
 
 
 class MixtralMoE(nn.Module):
@@ -220,28 +223,40 @@ def process_weights_after_loading(self):
                                           requires_grad=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        if is_hpu():
+            batch_size, sequence_length, hidden_size = hidden_states.shape
+        else:
+            num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(hidden_states,
-                                        self.w13_weight,
-                                        self.w2_weight,
-                                        router_logits,
-                                        self.top_k,
-                                        renormalize=True,
-                                        inplace=True,
-                                        use_fp8=self.use_fp8,
-                                        w1_scale=self.w13_scale,
-                                        w2_scale=self.w2_scale,
-                                        a1_scale=self.a13_scale,
-                                        a2_scale=self.a2_scale)
+
+        if is_hpu():
+            final_hidden_states = static_fused_moe(hidden_states,
+                                                   self.w13_weight,
+                                                   self.w2_weight,
+                                                   router_logits,
+                                                   self.top_k)
+        else:
+            final_hidden_states = fused_moe(hidden_states,
+                                            self.w13_weight,
+                                            self.w2_weight,
+                                            router_logits,
+                                            self.top_k,
+                                            renormalize=True,
+                                            inplace=True,
+                                            use_fp8=self.use_fp8,
+                                            w1_scale=self.w13_scale,
+                                            w2_scale=self.w2_scale,
+                                            a1_scale=self.a13_scale,
+                                            a2_scale=self.a2_scale)
 
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
-        return final_hidden_states.view(num_tokens, hidden_size)
+        return (final_hidden_states.view(batch_size, sequence_length, hidden_size) if is_hpu()
+                else final_hidden_states.view(num_tokens, hidden_size))
 
 
 class MixtralAttention(nn.Module):
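
Per the first commit bullet, the Habana model runner hands MixtralMoE a 3D (batch, seq_len, hidden) tensor, so the HPU branch flattens it for the gate and MoE and restores the 3D view on return. Below is a tiny shape-only sketch of that round trip; the sizes are arbitrary and it is not part of the commit.

# Shape handling only, mirroring the HPU branch above.
import torch

batch_size, sequence_length, hidden_size = 2, 3, 8
hidden_states = torch.randn(batch_size, sequence_length, hidden_size)

flat = hidden_states.view(-1, hidden_size)   # (batch * seq, hidden) fed to gate/MoE
restored = flat.view(batch_size, sequence_length, hidden_size)
assert torch.equal(hidden_states, restored)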
