@@ -52,8 +52,8 @@ def forward(self, hidden_states):
 class OpenAIMoeExperts(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.num_experts = config.num_local_experts
         self.intermediate_size = config.intermediate_size
+        self.num_experts = config.num_local_experts
         self.hidden_size = config.hidden_size
         self.expert_dim = self.intermediate_size
         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
@@ -70,16 +70,19 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None):
         For inference we can sacrifice some memory and compute the output for all experts at once by repeating the inputs.
 
         Args:
-            hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
+            hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
             selected_experts (torch.Tensor): (batch_size * token_num, top_k)
             routing_weights (torch.Tensor): (batch_size * token_num, top_k)
         Returns:
             torch.Tensor
         """
+        batch_size = hidden_states.shape[0]
+        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
+        num_experts = routing_weights.shape[0]
         if self.training:
             next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
             with torch.no_grad():
-                expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute(
+                expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts).permute(
                     2, 1, 0
                 )
                 expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
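A quick aside on the mask construction above (my toy numbers, not from the PR): one_hot over the top-k indices, permuted to (num_experts, top_k, num_tokens), gives each expert a view of which tokens picked it, and the sum/nonzero pass keeps only the experts that actually received tokens.

    import torch

    # Toy check of the mask logic: 4 tokens, top_k = 2, 3 experts.
    router_indices = torch.tensor([[0, 2], [1, 2], [0, 1], [2, 1]])  # (num_tokens, top_k)
    expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=3).permute(2, 1, 0)
    # expert_mask[e, k, t] == 1 iff token t picked expert e in top-k slot k
    expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
    print(expert_hitted.squeeze(-1))  # tensor([0, 1, 2]): every expert was hit at least once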
@@ -100,42 +103,62 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None):
                 )  # (num_tokens, hidden_dim)
                 weighted_output = out * routing_weights[top_x, idx, None]  # (num_tokens, hidden_dim)
                 next_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype)[0])
+            next_states = next_states.view(batch_size, -1, self.hidden_size)
         else:
-            hidden_states = hidden_states.repeat(self.num_experts, 1)
-            hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
+            hidden_states = hidden_states.repeat(num_experts, 1)
+            hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
             gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
             gate, up = gate_up.chunk(2, dim=-1)  # not supported for DTensors
             glu = gate * torch.sigmoid(gate * self.alpha)
-            next_states = torch.bmm(((up + 1) * glu), self.down_proj) + self.down_proj_bias[..., None, :]
-            next_states = next_states.view(-1, self.hidden_size)
-        return next_states
+            next_states = torch.bmm(((up + 1) * glu), self.down_proj)
+            next_states = next_states + self.down_proj_bias[..., None, :]
+            next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)  # (num_experts, batch_size, seq_len, hidden_size)
+        return next_states, None
 
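The batched inference path is easier to follow with the shapes written out. A minimal sketch with assumed toy sizes and random weights (alpha here is a placeholder, not the model's configured value):

    import torch

    E, T, H, D, alpha = 3, 5, 8, 16, 1.702  # experts, tokens, hidden, expert_dim (assumed)
    x = torch.randn(T, H).repeat(E, 1).view(E, T, H)       # every expert sees all tokens
    gate_up = torch.bmm(x, torch.randn(E, H, 2 * D))       # (E, T, 2*D), one matmul per expert
    gate, up = gate_up.chunk(2, dim=-1)                    # (E, T, D) each
    glu = gate * torch.sigmoid(gate * alpha)               # the gated activation from the diff
    out = torch.bmm((up + 1) * glu, torch.randn(E, D, H))  # (E, T, H): per-expert outputs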
+class TopKRouter(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.top_k = config.num_experts_per_tok
+        self.num_experts = config.num_local_experts
+        self.hidden_dim = config.hidden_size
+        self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
+        self.bias = nn.Parameter(torch.empty(self.num_experts))
+
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+        router_logits = F.linear(hidden_states, self.weight, self.bias)  # (seq_len, num_experts)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
+        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1)
+        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1)  # (num_experts, seq_len)
+        return router_scores, router_indices
+
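One detail worth spelling out: the softmax runs over the top-k logits only, and scatter_ writes those probabilities back into a dense matrix, so non-selected experts get exactly zero weight. A toy run (values approximate):

    import torch

    logits = torch.tensor([[2.0, 0.5, 1.0]])           # 1 token, 3 experts
    top_v, top_i = torch.topk(logits, k=2, dim=-1)     # values [2.0, 1.0], experts [0, 2]
    top_v = torch.nn.functional.softmax(top_v, dim=1)  # ~[0.73, 0.27]
    scores = torch.zeros_like(logits).scatter_(1, top_i, top_v).transpose(0, 1)
    print(scores)  # (num_experts, seq_len): [[0.73], [0.00], [0.27]]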
+class TokenDispatcher(nn.Module):
+    # this module matters for expert parallelism: the EP hook attaches here
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+
+    def forward(self, routed_out, routing_weights):
+        # routed_out is (num_experts, batch_size, seq_len, hidden_size)
+        routed_out = routed_out * routing_weights[:, None, :, None]  # zero weights discard the routed_out computed for the non-selected experts
+        routed_out = routed_out.sum(dim=0)  # (batch_size, seq_len, hidden_size)
+        return routed_out
 
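The dispatcher's broadcast-multiply-and-sum is a weighted combine over the expert axis. A small equivalence check against an einsum spelling (toy sizes, my own sketch, not part of the PR):

    import torch

    E, B, S, H = 3, 2, 5, 8
    routed_out = torch.randn(E, B, S, H)
    weights = torch.rand(E, S)  # (num_experts, seq_len)
    combined = (routed_out * weights[:, None, :, None]).sum(dim=0)  # (B, S, H)
    assert torch.allclose(combined, torch.einsum("ebsh,es->bsh", routed_out, weights), atol=1e-6)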
 @use_kernel_forward_from_hub("MegaBlocksMoeMLP")
 class OpenAIMoeMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.top_k = config.num_experts_per_tok
-        self.hidden_dim = config.hidden_size
-        self.num_local_experts = config.num_local_experts
+        self.router = TopKRouter(config)
         self.experts = OpenAIMoeExperts(config)
-        self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=True)
+        self.token_dispatcher = TokenDispatcher(config)
 
     def forward(self, hidden_states):
         # we don't slice the weight as it's not compile compatible
-        batch_size = hidden_states.shape[0]
-        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = self.router(hidden_states)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
-        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1)
-        routed_out = self.experts(hidden_states, router_indices, router_top_value)
-        if self.training:
-            output_states = routed_out.view(batch_size, -1, self.hidden_dim)
-        else:
-            routed_out = routed_out.view(self.num_local_experts, -1, self.hidden_dim) * router_scores[..., None]
-            output_states = routed_out.view(self.num_local_experts, batch_size, -1, self.hidden_dim).sum(dim=0)
-        return output_states, router_scores
+        router_scores, router_indices = self.router(hidden_states)  # (num_experts, seq_len)
+        routed_out, _ = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)  # TODO: router_indices isn't used inside this function
+        hidden_states = self.token_dispatcher(routed_out, router_scores)
+        return hidden_states, router_scores
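To see how the three pieces compose, here is a hedged, self-contained re-sketch of the router -> experts -> dispatcher flow, with assumed sizes and a relu stand-in for the GLU (it deliberately does not instantiate the classes above, whose full config is not shown in this hunk):

    import torch

    E, K, B, S, H, D = 4, 2, 2, 5, 8, 16  # experts, top_k, batch, seq, hidden, expert_dim
    x = torch.randn(B, S, H)

    # Router: softmax over top-k logits, scattered into dense (E, B*S) scores.
    logits = torch.randn(B * S, E)
    top_v, top_i = torch.topk(logits, K, dim=-1)
    scores = torch.zeros_like(logits).scatter_(1, top_i, top_v.softmax(dim=1)).transpose(0, 1)

    # Experts (inference path): every token through every expert via bmm.
    flat = x.reshape(-1, H).repeat(E, 1).view(E, -1, H)
    h = torch.relu(torch.bmm(flat, torch.randn(E, H, D)))  # relu stands in for the GLU
    routed_out = torch.bmm(h, torch.randn(E, D, H)).view(E, B, S, H)

    # Dispatcher: weight each expert's output and reduce over the expert axis.
    out = (routed_out * scores.view(E, B, S)[..., None]).sum(dim=0)
    print(out.shape)  # torch.Size([2, 5, 8])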
 
 
 class OpenAIMoeRotaryEmbedding(LlamaRotaryEmbedding):