Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BigBird] enable BigBirdForQuestionAnswering to return pooler output #11439

Merged
merged 1 commit into from
Apr 26, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 40 additions & 5 deletions src/transformers/models/big_bird/modeling_big_bird.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
CausalLMOutputWithCrossAttentions,
MaskedLMOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
Expand Down Expand Up @@ -1857,6 +1856,41 @@ class BigBirdForPreTrainingOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class BigBirdForQuestionAnsweringModelOutput(ModelOutput):
"""
Base class for outputs of question answering models.

Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`):
Span-start scores (before SoftMax).
end_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`):
Span-end scores (before SoftMax).
pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 1)`):
pooler output from BigBigModel
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.

Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.

Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""

loss: Optional[torch.FloatTensor] = None
start_logits: torch.FloatTensor = None
end_logits: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None


@add_start_docstrings(
"The bare BigBird Model transformer outputting raw hidden-states without any specific head on top.",
BIG_BIRD_START_DOCSTRING,
Expand Down Expand Up @@ -2852,14 +2886,14 @@ def forward(self, encoder_output):
BIG_BIRD_START_DOCSTRING,
)
class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
def __init__(self, config):
def __init__(self, config, add_pooling_layer=False):
super().__init__(config)

config.num_labels = 2
self.num_labels = config.num_labels
self.sep_token_id = config.sep_token_id

self.bert = BigBirdModel(config, add_pooling_layer=False)
self.bert = BigBirdModel(config, add_pooling_layer=add_pooling_layer)
self.qa_classifier = BigBirdForQuestionAnsweringHead(config)

self.init_weights()
Expand All @@ -2868,7 +2902,7 @@ def __init__(self, config):
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/bigbird-base-trivia-itc",
output_type=QuestionAnsweringModelOutput,
output_type=BigBirdForQuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
Expand Down Expand Up @@ -2958,10 +2992,11 @@ def forward(
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output

return QuestionAnsweringModelOutput(
return BigBirdForQuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
pooler_output=outputs.pooler_output,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
Expand Down