From 66452984eb0818d8c9787986c00d43d16804859f Mon Sep 17 00:00:00 2001 From: Xiaoming <7437363+stevezheng23@users.noreply.github.com> Date: Tue, 29 Oct 2019 11:11:01 -0700 Subject: [PATCH] update kd qa in roberta modeling (#13) --- transformers/modeling_roberta.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformers/modeling_roberta.py b/transformers/modeling_roberta.py index 876aacadc26ee6..f8520208fe4ac2 100644 --- a/transformers/modeling_roberta.py +++ b/transformers/modeling_roberta.py @@ -824,7 +824,7 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ kd_end_probs = nn.LogSoftmax(kd_end_logits) kd_start_loss = kd_loss_fct(kd_start_probs, start_targets) kd_end_loss = kd_loss_fct(kd_end_probs, end_targets) - kd_span_loss = (kd_start_loss + kd_end_loss) / 2 + kd_span_loss = (self.kd_temperature ** 2) * (kd_start_loss + kd_end_loss) / 2 total_loss = kd_span_loss if total_loss == None else total_loss + kd_span_loss if total_loss is not None: