diff --git a/transformers/modeling_roberta.py b/transformers/modeling_roberta.py index 876aacadc26ee6..f8520208fe4ac2 100644 --- a/transformers/modeling_roberta.py +++ b/transformers/modeling_roberta.py @@ -824,7 +824,7 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ kd_end_probs = nn.LogSoftmax(kd_end_logits) kd_start_loss = kd_loss_fct(kd_start_probs, start_targets) kd_end_loss = kd_loss_fct(kd_end_probs, end_targets) - kd_span_loss = (kd_start_loss + kd_end_loss) / 2 + kd_span_loss = (self.kd_temperature ** 2) * (kd_start_loss + kd_end_loss) / 2 total_loss = kd_span_loss if total_loss == None else total_loss + kd_span_loss if total_loss is not None: