Using Gaudi Mixtral MoE impl
Signed-off-by: Daniel Huang <daniel1.huang@intel.com>
pi314ever committed Jan 3, 2025
1 parent 8dbc254 commit 9c390e7
Showing 1 changed file with 31 additions and 34 deletions.
optimum/habana/transformers/models/snowflake/modeling_arctic.py (31 additions & 34 deletions)
@@ -1057,59 +1057,56 @@ def __init__(self, config: ArcticConfig, layer_id: int, **kwargs):

     # Similar in behavior to transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward but more efficient.
     def _moe_foreward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Copied from ../mixtral/modeling_mixtral.py gaudi_mixtral_block_sparse_moe_forward
+        """
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)

         if is_deepspeed_available() and (not self.training):
             from deepspeed import comm as dist

             if dist.is_initialized():
                 output_tensors = [router_logits.clone() for _ in range(dist.get_world_size())]
                 dist.all_gather(output_tensors, router_logits)
                 router_logits = torch.cat(output_tensors, dim=1)

         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
         if self.top_k > 1:
             routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         # we cast back to the input dtype
         routing_weights = routing_weights.to(hidden_states.dtype)

         final_hidden_states = torch.zeros(
-            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+            (batch_size, sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
         )
-
-        # Matching between experts, tokens, and their top-k rank. For every i,
-        # expert_idx[i] is the rank topk_idx[i] expert for token_idx[i].
-        expert_idx, token_idx, topk_idx = torch.where(
-            selected_experts
-            == torch.arange(
-                self.num_experts,
-                device=selected_experts.device,
-            ).view((self.num_experts, 1, 1))
-        )
-
-        # Split into one chunk per expert.
-        bincount = torch.bincount(expert_idx, minlength=self.num_experts).tolist()
-        token_idx = token_idx.split(bincount)
-        topk_idx = topk_idx.split(bincount)
+        padded_weights = torch.zeros(
+            (batch_size * sequence_length, self.num_experts), dtype=hidden_states.dtype, device=hidden_states.device
+        )
+        padded_weights.scatter_(-1, selected_experts, routing_weights)
+        padded_weights = padded_weights.reshape(-1, sequence_length, self.num_experts)
+        padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)

         # Loop over all available experts in the model and perform the computation on each expert
-        for expert_layer, top_x, idx in zip(self.experts, token_idx, topk_idx):
-            if top_x.shape[0] == 0:
-                continue
-
-            # in torch it is faster to index using lists than torch tensors
-            top_x_list = top_x.tolist()
-            idx_list = idx.tolist()
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
-            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-        # torch.distributed.barrier()
-        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            padded_weight = padded_weights[expert_idx]
+            current_state_static = hidden_states.reshape(-1, hidden_dim)
+            current_hidden_states_static = (
+                expert_layer(current_state_static).reshape(-1, sequence_length, hidden_dim) * padded_weight
+            )
+            final_hidden_states += current_hidden_states_static
+            # support long sequences exceeding 8192
+            if not self.training and sequence_length > 8192:
+                htcore.mark_step()
+
         return final_hidden_states, load_balancing_loss_func(
             (router_logits,), self.num_experts, self.top_k
         )  # ZY: let's directly output the loss to align what we have in ds

     def forward(self, hidden_states: torch.Tensor):
         if self.is_moe_layer:
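
The commit swaps the dynamic-shape routing path (torch.where matching plus index_add_ scatter) for the dense padded-weights loop from gaudi_mixtral_block_sparse_moe_forward: routing weights are scattered into a dense (tokens, num_experts) matrix, every expert runs on every token, and the weighted outputs are summed, so tensor shapes no longer depend on which experts the router selects (which is presumably the point on Gaudi/HPU, where static shapes are preferred). Below is a minimal standalone sketch, not part of the commit, that checks the two formulations agree on a toy case; the nn.Linear experts and the small sizes are hypothetical stand-ins for the model's real expert MLPs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_experts, top_k, tokens, hidden_dim = 4, 2, 6, 8
hidden_states = torch.randn(tokens, hidden_dim)
# Hypothetical stand-ins for the model's expert MLPs.
experts = [torch.nn.Linear(hidden_dim, hidden_dim, bias=False) for _ in range(num_experts)]

router_logits = torch.randn(tokens, num_experts)
routing_weights, selected_experts = torch.topk(F.softmax(router_logits, dim=-1), top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

# Removed approach: gather only the routed tokens per expert, scatter results back.
sparse_out = torch.zeros(tokens, hidden_dim)
for e in range(num_experts):
    token_idx, k_idx = torch.where(selected_experts == e)
    if token_idx.numel() == 0:
        continue
    weighted = experts[e](hidden_states[token_idx]) * routing_weights[token_idx, k_idx, None]
    sparse_out.index_add_(0, token_idx, weighted)

# Added approach: scatter the weights into a dense (tokens, num_experts) matrix,
# run every expert on every token, and sum the weighted outputs.
padded_weights = torch.zeros(tokens, num_experts)
padded_weights.scatter_(-1, selected_experts, routing_weights)
dense_out = torch.zeros(tokens, hidden_dim)
for e in range(num_experts):
    dense_out += experts[e](hidden_states) * padded_weights[:, e, None]

print(torch.allclose(sparse_out, dense_out, atol=1e-6))  # expected: True
```

The dense form trades extra compute (each expert processes all tokens, with unselected tokens weighted by zero) for shapes that stay fixed across steps; the added htcore.mark_step() only fires for sequences longer than 8192 tokens, per the in-line comment in the commit.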