
Commit

fix model llama for split function in line 852 (#1941)
zhuizhuzheming authored Feb 14, 2025
1 parent 9ea7a99 commit ab20e2e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion mindnlp/transformers/models/llama/modeling_llama.py
@@ -849,7 +849,7 @@ def forward(

     hidden_states = outputs[0]
     if self.config.pretraining_tp > 1:
-        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+        lm_head_slices = ops.split(self.lm_head.weight,self.vocab_size // self.config.pretraining_tp, dim=0)
         logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
         logits = ops.cat(logits, dim=-1)
     else:
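The one-line change replaces the Tensor method call self.lm_head.weight.split(...) with the functional ops.split(...) when slicing the lm_head weight along dim 0 on the pretraining_tp path. As a rough illustration of what that path computes, the following is a minimal NumPy sketch; the shapes, sizes, and the NumPy stand-ins for ops.split and F.linear are illustrative assumptions, not taken from the repository.

# Minimal NumPy sketch of the pretraining_tp logit computation (illustrative only).
import numpy as np

pretraining_tp = 4        # assumed number of tensor-parallel slices
vocab_size = 32000        # illustrative vocab size
hidden_size = 64          # illustrative hidden size

hidden_states = np.random.randn(2, 5, hidden_size)         # (batch, seq, hidden)
lm_head_weight = np.random.randn(vocab_size, hidden_size)  # lm_head.weight layout

# The call in the diff splits the weight into slices of vocab_size // pretraining_tp
# rows each; np.split instead takes the number of equal sections, which yields the
# same slices here.
lm_head_slices = np.split(lm_head_weight, pretraining_tp, axis=0)

# F.linear(x, W) computes x @ W.T; concatenate the per-slice logits on the last dim,
# mirroring the ops.cat(logits, dim=-1) step in the diff.
logits = np.concatenate([hidden_states @ w.T for w in lm_head_slices], axis=-1)

# The sliced computation matches a single full projection.
assert np.allclose(logits, hidden_states @ lm_head_weight.T)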
