Skip to content

Commit

Permalink
update to newest mcore code
Browse files Browse the repository at this point in the history
Signed-off-by: Chen Cui <chcui@nvidia.com>
  • Loading branch information
cuichenx committed Jun 11, 2024
1 parent c07de5f commit 7ac5a5d
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def forward(self, permuted_local_hidden_states, tokens_per_expert):

cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
# Insert zero at the begining for offset index's convenience
zero_tensor = torch.zeros(1, dtype=torch.long)
zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
for expert_num, expert in enumerate(self.local_experts):
start = cumsum_num_tokens[expert_num]
Expand Down

0 comments on commit 7ac5a5d

Please sign in to comment.