Commit a85634c

committed
make sure modular matches!
1 parent caf6e77 commit a85634c

File tree: 1 file changed (+48, -25)

src/transformers/models/openai_moe/modular_openai_moe.py (48 additions & 25 deletions)

@@ -52,8 +52,8 @@ def forward(self, hidden_states):
 class OpenAIMoeExperts(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.num_experts = config.num_local_experts
         self.intermediate_size = config.intermediate_size
+        self.num_experts = config.num_local_experts
         self.hidden_size = config.hidden_size
         self.expert_dim = self.intermediate_size
         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
@@ -70,16 +70,19 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
         For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs.
 
         Args:
-            hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
+            hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
             selected_experts (torch.Tensor): (batch_size * token_num, top_k)
             routing_weights (torch.Tensor): (batch_size * token_num, top_k)
         Returns:
             torch.Tensor
         """
+        batch_size = hidden_states.shape[0]
+        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
+        num_experts = routing_weights.shape[0]
         if self.training:
             next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
             with torch.no_grad():
-                expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute(
+                expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts).permute(
                     2, 1, 0
                 )
                 expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
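
For reference (not part of this commit), a minimal standalone sketch of the one_hot/permute mask built in the training branch above; the sizes (4 tokens, top_k=2, 3 experts) are invented purely to make the shapes visible:

import torch

num_tokens, top_k, num_experts = 4, 2, 3  # hypothetical toy sizes
router_indices = torch.tensor([[0, 2], [1, 2], [0, 1], [2, 1]])  # (num_tokens, top_k)

# same construction as the diff: one-hot over experts, then move the expert axis to the front
expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts).permute(2, 1, 0)
print(expert_mask.shape)  # torch.Size([3, 2, 4]) -> (num_experts, top_k, num_tokens)

# experts that received at least one token, mirroring `expert_hitted`
expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
print(expert_hitted.squeeze(-1).tolist())  # [0, 1, 2] here, since every expert is hit

Reading num_experts off routing_weights.shape[0] rather than self.num_experts means the mask size always matches the score tensor the router actually produced.
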
@@ -100,42 +103,62 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
                 ) # (num_tokens, hidden_dim)
                 weighted_output = out * routing_weights[top_x, idx, None] # (num_tokens, hidden_dim)
                 next_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype)[0])
+            next_states = next_states.view(batch_size, -1, self.hidden_size)
         else:
-            hidden_states = hidden_states.repeat(self.num_experts, 1)
-            hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
+            hidden_states = hidden_states.repeat(num_experts, 1)
+            hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
             gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
             gate, up = gate_up.chunk(2, dim=-1)  # not supported for DTensors
             glu = gate * torch.sigmoid(gate * self.alpha)
-            next_states = torch.bmm(((up + 1) * glu), self.down_proj) + self.down_proj_bias[..., None, :]
-            next_states = next_states.view(-1, self.hidden_size)
-        return next_states
+            next_states = torch.bmm(((up + 1) * glu), self.down_proj)
+            next_states = next_states + self.down_proj_bias[..., None, :]
+            next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)  # (num_experts, batch_size, seq_len, hidden_size)
+        return next_states, None
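
The inference branch now returns one full output per expert instead of collapsing the expert dimension. A toy sketch of that dense path (not from the commit; all sizes are invented, alpha is a placeholder, and down_proj is assumed to be (num_experts, expert_dim, hidden_size) since that is what the bmm implies):

import torch

# hypothetical sizes and a placeholder alpha for the gate * sigmoid(gate * alpha) GLU
num_experts, hidden_size, expert_dim, num_tokens, alpha = 3, 8, 16, 5, 1.702

hidden_states = torch.randn(num_tokens, hidden_size)
gate_up_proj = torch.randn(num_experts, hidden_size, 2 * expert_dim)
gate_up_proj_bias = torch.randn(num_experts, 2 * expert_dim)
down_proj = torch.randn(num_experts, expert_dim, hidden_size)  # assumed shape, implied by the bmm
down_proj_bias = torch.randn(num_experts, hidden_size)

# dense path: every expert sees every token
x = hidden_states.repeat(num_experts, 1).view(num_experts, -1, hidden_size)  # (E, T, H)
gate_up = torch.bmm(x, gate_up_proj) + gate_up_proj_bias[..., None, :]       # (E, T, 2*D)
gate, up = gate_up.chunk(2, dim=-1)
glu = gate * torch.sigmoid(gate * alpha)
out = torch.bmm((up + 1) * glu, down_proj) + down_proj_bias[..., None, :]    # (E, T, H)
print(out.shape)  # torch.Size([3, 5, 8]) -> one output per expert per token

The weighting by router scores is no longer done here; it is deferred to the new TokenDispatcher further down.
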
 
+class TopKRouter(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.top_k = config.num_experts_per_tok
+        self.num_experts = config.num_local_experts
+        self.hidden_dim = config.hidden_size
+        self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
+        self.bias = nn.Parameter(torch.empty(self.num_experts))
+
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+        router_logits = F.linear(hidden_states, self.weight, self.bias)  # (seq_len, num_experts)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
+        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1)
+        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1)  # (num_experts, seq_len)
+        return router_scores, router_indices
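
A self-contained sketch (toy sizes, not from the commit) of the routing math in TopKRouter.forward, showing that router_scores comes out as (num_experts, seq_len) with zeros for unselected experts while router_indices stays (seq_len, top_k):

import torch
import torch.nn.functional as F

seq_len, hidden_dim, num_experts, top_k = 6, 8, 4, 2  # toy sizes
hidden_states = torch.randn(seq_len, hidden_dim)
weight = torch.randn(num_experts, hidden_dim)  # stand-ins for the router parameters
bias = torch.randn(num_experts)

router_logits = F.linear(hidden_states, weight, bias)                        # (seq_len, num_experts)
router_top_value, router_indices = torch.topk(router_logits, top_k, dim=-1)  # (seq_len, top_k)
router_top_value = F.softmax(router_top_value, dim=1)                        # normalize only the kept experts
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1)

print(router_scores.shape, router_indices.shape)  # torch.Size([4, 6]) torch.Size([6, 2])
print(router_scores.sum(dim=0))                   # each token's kept weights sum to 1
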
+
+class TokenDispatcher(nn.Module):
+    # this module is important to add EP hook
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+
+    def forward(self, routed_out, routing_weights):
+        # routed_out is (num_experts, batch_size, seq_len, hidden_size)
+        routed_out = routed_out * routing_weights[:, None, :, None]  # we're throwing away computed routed_out for rest of experts
+        routed_out = routed_out.sum(dim=0)  # (batch_size, seq_len, hidden_size)
+        return routed_out
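
And the combine that TokenDispatcher.forward performs, again with invented sizes (batch_size=1, so the (num_experts, seq_len) scores broadcast directly against the expert outputs):

import torch

num_experts, batch_size, seq_len, hidden_size = 4, 1, 6, 8  # toy sizes
routed_out = torch.randn(num_experts, batch_size, seq_len, hidden_size)  # every expert's output for every token
router_scores = torch.rand(num_experts, seq_len)                         # sparse weights from the router

# scale each expert's output by its routing weight, then sum over the expert axis
combined = (routed_out * router_scores[:, None, :, None]).sum(dim=0)
print(combined.shape)  # torch.Size([1, 6, 8]) -> (batch_size, seq_len, hidden_size)

Keeping this combine in its own module matches the inline note about it being the place to add an expert-parallelism hook.
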
 
 @use_kernel_forward_from_hub("MegaBlocksMoeMLP")
 class OpenAIMoeMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.top_k = config.num_experts_per_tok
-        self.hidden_dim = config.hidden_size
-        self.num_local_experts = config.num_local_experts
+        self.router = TopKRouter(config)
         self.experts = OpenAIMoeExperts(config)
-        self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=True)
+        self.token_dispatcher = TokenDispatcher(config)
 
     def forward(self, hidden_states):
         # we don't slice weight as its not compile compatible
-        batch_size = hidden_states.shape[0]
-        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = self.router(hidden_states)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
-        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1)
-        routed_out = self.experts(hidden_states, router_indices, router_top_value)
-        if self.training:
-            output_states = routed_out.view(batch_size, -1, self.hidden_dim)
-        else:
-            routed_out = routed_out.view(self.num_local_experts, -1, self.hidden_dim) * router_scores[..., None]
-            output_states = routed_out.view(self.num_local_experts, batch_size, -1, self.hidden_dim).sum(dim=0)
-        return output_states, router_scores
+        router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len)
+        routed_out, _ = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) #TODO: router_indices isn't used inside this func
+        hidden_states = self.token_dispatcher(routed_out, router_scores)
+        return hidden_states, router_scores
 
 
 class OpenAIMoeRotaryEmbedding(LlamaRotaryEmbedding):
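
Taken together, the refactored OpenAIMoeMLP.forward is now route -> experts -> dispatch. A compact end-to-end toy pass (random tensors, no biases, batch_size=1, all sizes hypothetical) mirrors that flow and the shapes it produces:

import torch
import torch.nn.functional as F

batch, seq_len, hidden, expert_dim, num_experts, top_k, alpha = 1, 5, 8, 16, 4, 2, 1.702  # toy sizes
hidden_states = torch.randn(batch, seq_len, hidden)

# 1) route: scores (num_experts, seq_len), indices (seq_len, top_k)
logits = F.linear(hidden_states.reshape(-1, hidden), torch.randn(num_experts, hidden), torch.randn(num_experts))
top_val, top_idx = torch.topk(logits, top_k, dim=-1)
scores = torch.zeros_like(logits).scatter_(1, top_idx, F.softmax(top_val, dim=1)).transpose(0, 1)

# 2) experts, dense path: every expert processes every token
x = hidden_states.reshape(-1, hidden).repeat(num_experts, 1).view(num_experts, -1, hidden)
gate, up = torch.bmm(x, torch.randn(num_experts, hidden, 2 * expert_dim)).chunk(2, dim=-1)
routed = torch.bmm((up + 1) * gate * torch.sigmoid(gate * alpha), torch.randn(num_experts, expert_dim, hidden))
routed = routed.view(num_experts, batch, -1, hidden)

# 3) dispatch: weight each expert's output by its score and sum over experts
out = (routed * scores[:, None, :, None]).sum(dim=0)
print(out.shape)  # torch.Size([1, 5, 8]) == (batch, seq_len, hidden)
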
