@@ -78,8 +78,7 @@ def __init__(self, config):
         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
         self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
         self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
-        self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim))
-        self.act_fn = torch.nn.Sigmoid()
+        self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
         self.alpha = 1.702

     def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
@@ -110,7 +109,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
             current_state = hidden_states[top_x]  # (num_tokens, hidden_dim)
             gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]  # (num_tokens, 2 * interm_dim)
             gate, up = gate_up.chunk(2, dim=-1)  # (num_tokens, interm_dim)
-            glu = gate * self.act_fn(gate * self.alpha)  # (num_tokens, interm_dim)
+            glu = gate * torch.sigmoid(gate * self.alpha)  # (num_tokens, interm_dim)
            gated_output = (up + 1) * glu  # (num_tokens, interm_dim)
             out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]  # (num_tokens, hidden_dim)
             weighted_output = out * routing_weights[top_x, idx].unsqueeze(-1)  # (num_tokens, hidden_dim)
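
As a sanity check on both changes, here is a minimal standalone sketch of a single expert's forward pass. All dimension names and sizes are illustrative assumptions, not the model's real config. It shows why the down-projection bias must be hidden_size-wide (the matmul output is (num_tokens, hidden_size)), and that gate * sigmoid(alpha * gate) with alpha = 1.702 is the standard sigmoid approximation of GELU, so the separate nn.Sigmoid module was redundant.

import torch

# Illustrative sizes only (hypothetical, chosen so expert_dim != hidden_size).
num_tokens, hidden_size, expert_dim = 4, 16, 32
alpha = 1.702

gate_up_proj = torch.randn(hidden_size, 2 * expert_dim)
gate_up_proj_bias = torch.randn(2 * expert_dim)
down_proj = torch.randn(expert_dim, hidden_size)
down_proj_bias = torch.randn(hidden_size)  # the fix: hidden_size, not expert_dim

x = torch.randn(num_tokens, hidden_size)
gate_up = x @ gate_up_proj + gate_up_proj_bias  # (num_tokens, 2 * expert_dim)
gate, up = gate_up.chunk(2, dim=-1)             # (num_tokens, expert_dim) each
glu = gate * torch.sigmoid(gate * alpha)        # inlined sigmoid; ~GELU(gate)
gated = (up + 1) * glu                          # (num_tokens, expert_dim)
out = gated @ down_proj + down_proj_bias        # (num_tokens, hidden_size)
assert out.shape == (num_tokens, hidden_size)
# With the old (num_experts, expert_dim)-shaped bias, this last addition
# would fail to broadcast whenever expert_dim != hidden_size.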