
Commit e9f7ec2

Merge pull request huggingface#2 from huggingface/fix-zero3-and-down_proj_bias-shape
Replace activation module with function and fix `down_proj_bias` shape
2 parents: e43411a + 079840a

File tree

1 file changed: 2 additions, 3 deletions


src/transformers/models/openai_moe/modeling_openai_moe.py

@@ -78,8 +78,7 @@ def __init__(self, config):
         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
         self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
         self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
-        self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim))
-        self.act_fn = torch.nn.Sigmoid()
+        self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
         self.alpha = 1.702

     def forward(self, hidden_states: torch.Tensor, router_indices = None, routing_weights=None) -> torch.Tensor:
@@ -110,7 +109,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices = None, routing_we
             current_state = hidden_states[top_x]  # (num_tokens, hidden_dim)
             gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]  # (num_tokens, 2 * interm_dim)
             gate, up = gate_up.chunk(2, dim=-1)  # (num_tokens, interm_dim)
-            glu = gate * self.act_fn(gate * self.alpha)  # (num_tokens, interm_dim)
+            glu = gate * torch.sigmoid(gate * self.alpha)  # (num_tokens, interm_dim)
             gated_output = (up + 1) * glu  # (num_tokens, interm_dim)
             out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]  # (num_tokens, hidden_dim)
             weighted_output = out * routing_weights[top_x, idx].unsqueeze(-1)  # (num_tokens, hidden_dim)
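
For reference, below is a minimal self-contained sketch of the fixed expert computation. The sizes (num_experts=4, hidden_size=16, expert_dim=32, 5 tokens) are illustrative assumptions, not values from the model config. It demonstrates both changes in this commit: the down projection maps expert_dim back to hidden_size, so its output is (num_tokens, hidden_size) and down_proj_bias must be shaped (num_experts, hidden_size) to broadcast over it; and the sigmoid is called as the function torch.sigmoid instead of being held as an nn.Sigmoid module.

import torch
import torch.nn as nn

num_experts, hidden_size, expert_dim = 4, 16, 32  # illustrative sizes (assumed)
alpha = 1.702  # x * sigmoid(alpha * x) is a standard GELU approximation

gate_up_proj = nn.Parameter(torch.randn(num_experts, hidden_size, 2 * expert_dim))
gate_up_proj_bias = nn.Parameter(torch.randn(num_experts, 2 * expert_dim))
down_proj = nn.Parameter(torch.randn(num_experts, expert_dim, hidden_size))
# The fix: the bias matches the down projection's output dim (hidden_size),
# not its input dim (expert_dim).
down_proj_bias = nn.Parameter(torch.randn(num_experts, hidden_size))

expert_idx = 0
current_state = torch.randn(5, hidden_size)  # (num_tokens, hidden_dim)
gate_up = current_state @ gate_up_proj[expert_idx] + gate_up_proj_bias[expert_idx]
gate, up = gate_up.chunk(2, dim=-1)          # each (num_tokens, expert_dim)
glu = gate * torch.sigmoid(gate * alpha)     # plain function, no module state
gated_output = (up + 1) * glu
out = gated_output @ down_proj[expert_idx] + down_proj_bias[expert_idx]
print(out.shape)  # torch.Size([5, 16]), i.e. (num_tokens, hidden_size)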
