Skip to content

Commit

Permalink
Dev/zheng/quac (huggingface#14)
Browse files Browse the repository at this point in the history
* update kd qa in roberta modeling

* fix issues for kd-quac runner
  • Loading branch information
stevezheng23 authored Oct 29, 2019
1 parent 6645298 commit 6cdbcd7
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 13 deletions.
4 changes: 2 additions & 2 deletions examples/run_quac_kd.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,8 +418,8 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
else:
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
all_start_targets = torch.tensor([f.start_target for f in features], dtype=torch.float)
all_end_targets = torch.tensor([f.end_target for f in features], dtype=torch.float)
all_start_targets = torch.tensor([f.start_targets for f in features], dtype=torch.float)
all_end_targets = torch.tensor([f.end_targets for f in features], dtype=torch.float)
all_is_impossible = torch.tensor([1.0 if f.is_impossible else 0.0 for f in features], dtype=torch.float)
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
all_start_positions, all_end_positions, all_start_targets, all_end_targets,
Expand Down
10 changes: 4 additions & 6 deletions examples/utils_quac_kd.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def __init__(self,
start_position=None,
end_position=None,
is_impossible=None,
start_target=None,
end_target=None):
start_targets=None,
end_targets=None):
self.unique_id = unique_id
self.example_index = example_index
self.doc_span_index = doc_span_index
Expand All @@ -108,8 +108,8 @@ def __init__(self,
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
self.start_target = start_target
self.end_target = end_target
self.start_targets = start_targets
self.end_targets = end_targets


def read_quac_examples(input_file, is_training, version_2_with_negative):
Expand Down Expand Up @@ -428,8 +428,6 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
assert len(start_targets) == max_seq_length
assert len(end_targets) == max_seq_length

span_is_impossible = example.is_impossible
start_position = None
Expand Down
10 changes: 5 additions & 5 deletions transformers/modeling_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn import CrossEntropyLoss, MSELoss, KLDivLoss

from .modeling_utils import SQuADHead
from .modeling_bert import BertEmbeddings, BertLayerNorm, BertModel, BertPreTrainedModel, gelu
Expand Down Expand Up @@ -819,13 +819,13 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_
total_loss = span_loss if total_loss == None else total_loss + span_loss

if start_targets is not None and end_targets is not None:
kd_loss_fct = KLDivLoss()
kd_start_probs = nn.LogSoftmax(kd_start_logits)
kd_end_probs = nn.LogSoftmax(kd_end_logits)
kd_loss_fct = KLDivLoss(reduction='batchmean')
kd_start_probs = F.log_softmax(kd_start_logits, dim=-1)
kd_end_probs = F.log_softmax(kd_end_logits, dim=-1)
kd_start_loss = kd_loss_fct(kd_start_probs, start_targets)
kd_end_loss = kd_loss_fct(kd_end_probs, end_targets)
kd_span_loss = (self.kd_temperature ** 2) * (kd_start_loss + kd_end_loss) / 2
total_loss = kd_span_loss if total_loss == None else total_loss + kd_span_loss
total_loss = kd_span_loss if total_loss is None else total_loss + kd_span_loss

if total_loss is not None:
outputs = (total_loss,) + outputs
Expand Down

0 comments on commit 6cdbcd7

Please sign in to comment.