Skip to content

#fix qwen2 abnormal loss caused by SoftmaxCrossEntropyWithLogits on 910A/B #2034

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: 0.4
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mindnlp/peft/peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(self, model, peft_config: PeftConfig, adapter_name="default"):
# if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"):
# self.base_model.config.pretraining_tp = 1

def save_pretrained(self, save_directory, safe_serialization=False, **kwargs):
def save_pretrained(self, save_directory, safe_serialization=True, **kwargs):
r"""
This function saves the adapter model and the adapter configuration files to a directory, so that it can be
reloaded using the [`LoraModel.from_pretrained`] class method, and also used by the [`LoraModel.push_to_hub`]
Expand Down
47 changes: 32 additions & 15 deletions mindnlp/transformers/models/qwen2/modeling_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,13 +826,22 @@ def forward(
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
# Flatten the tokens
loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = nn.functional.one_hot(shift_labels.view(-1), self.config.vocab_size)
# Enable model parallelism
loss, _ = loss_fct(shift_logits, shift_labels.to(shift_logits.dtype))
loss = loss.mean()
if ON_ORANGE_PI:
# Flatten the tokens
loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = nn.functional.one_hot(shift_labels.view(-1), self.config.vocab_size)
# Enable model parallelism
loss, _ = loss_fct(shift_logits, shift_labels.to(shift_logits.dtype))
loss = loss.mean()
else:
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
loss = loss_fct(shift_logits, shift_labels)


if not return_dict:
output = (logits,) + outputs[1:]
Expand Down Expand Up @@ -1004,10 +1013,14 @@ def forward(
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
labels = nn.functional.one_hot(labels.view(-1), self.num_labels)
loss, _ = loss_fct(pooled_logits.view(-1, self.num_labels), labels.to(pooled_logits.dtype))
loss = loss.mean()
if ON_ORANGE_PI:
loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
labels = nn.functional.one_hot(labels.view(-1), self.num_labels)
loss, _ = loss_fct(pooled_logits.view(-1, self.num_labels), labels.to(pooled_logits.dtype))
loss = loss.mean()
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
Expand Down Expand Up @@ -1086,10 +1099,14 @@ def forward(

loss = None
if labels is not None:
loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
labels = nn.functional.one_hot(labels.view(-1), self.num_labels)
loss, _= loss_fct(logits.view(-1, self.num_labels), labels.to(logits.dtype))
loss = loss.mean()
if ON_ORANGE_PI:
loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
labels = nn.functional.one_hot(labels.view(-1), self.num_labels)
loss, _= loss_fct(logits.view(-1, self.num_labels), labels.to(logits.dtype))
loss = loss.mean()
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

if not return_dict:
output = (logits,) + outputs[2:]
Expand Down