Dev/zheng/coqa (huggingface#4)
* upgrade roberta question answering based on xlnet question answering

* upgrade roberta question answering based on xlnet question answering (cont.)

* upgrade roberta question answering based on xlnet question answering (cont.)

* upgrade roberta question answering based on xlnet question answering (cont.)

* upgrade roberta question answering based on xlnet question answering (cont.)

* upgrade roberta question answering based on xlnet question answering (cont.)

* upgrade roberta question answering based on xlnet question answering (cont.)

* upgrade roberta question answering based on xlnet question answering (cont.)

* revert question answering changes in roberta/xlnet modeling

* revert question answering changes in roberta/xlnet modeling (cont.)

* revert to roberta qa simple

* revert to roberta qa simple (cont.)

* revert to roberta qa simple (cont.)

* revert to roberta qa simple (cont.)

* revert 'revert to roberta qa simple'

* update paragraph/query order for xlnet

* update paragraph/query order for xlnet (cont.)

* update answer cls layer & modularize squad output layer for roberta/xlnet

* update answer cls layer & modularize squad output layer for roberta/xlnet (cont.)

* use start/end index lookup for detokenization

* Revert "use start/end index lookup for detokenization"

This reverts commit b8807c478dee2aed01e531f188ea191e0ef4f37a.

* remove complex output layer for roberta-squad

* add back complex qa output layer for roberta-squad
stevezheng23 authored Oct 25, 2019
1 parent 7c3f639 commit c3e928e
Showing 7 changed files with 224 additions and 150 deletions.
37 changes: 20 additions & 17 deletions examples/run_squad.py
@@ -136,14 +136,16 @@ def train(args, train_dataset, model, tokenizer):
         for step, batch in enumerate(epoch_iterator):
             model.train()
             batch = tuple(t.to(args.device) for t in batch)
-            inputs = {'input_ids':      batch[0],
-                      'attention_mask': batch[1],
-                      'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
-                      'start_positions': batch[3],
-                      'end_positions':   batch[4]}
-            if args.model_type in ['xlnet', 'xlm']:
-                inputs.update({'cls_index': batch[5],
-                               'p_mask':    batch[6]})
+            inputs = {
+                'input_ids': batch[0],
+                'attention_mask': batch[1],
+                'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
+                'start_positions': batch[3],
+                'end_positions': batch[4],
+                'is_impossible': batch[5] if args.model_type in ['xlnet'] else None,
+                'cls_index': batch[6] if args.model_type in ['xlnet', 'xlm'] else None,
+                'p_mask': batch[7] if args.model_type in ['xlnet', 'xlm'] else None
+            }
             outputs = model(**inputs)
             loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

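The rewritten dict always contains every key and passes None for inputs a model does not use; that only works if each model's forward() declares these keyword arguments. A minimal sketch of the assumed signature (hypothetical, not the modeling code from this commit):

import torch
from torch import nn

class QAModelSketch(nn.Module):
    # Hypothetical signature: the extra QA inputs default to None, so the
    # train() loop above can pass its full inputs dict for every model type.
    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                start_positions=None, end_positions=None,
                is_impossible=None, cls_index=None, p_mask=None):
        loss = torch.zeros(())  # placeholder; a real head computes span losses
        return (loss,)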
@@ -220,14 +222,14 @@ def evaluate(args, model, tokenizer, prefix=""):
         model.eval()
         batch = tuple(t.to(args.device) for t in batch)
         with torch.no_grad():
-            inputs = {'input_ids':      batch[0],
-                      'attention_mask': batch[1],
-                      'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2]  # XLM don't use segment_ids
-                      }
+            inputs = {
+                'input_ids': batch[0],
+                'attention_mask': batch[1],
+                'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
+                'cls_index': batch[4] if args.model_type in ['xlnet', 'xlm'] else None,
+                'p_mask': batch[5] if args.model_type in ['xlnet', 'xlm'] else None
+            }
             example_indices = batch[3]
-            if args.model_type in ['xlnet', 'xlm']:
-                inputs.update({'cls_index': batch[4],
-                               'p_mask':    batch[5]})
             outputs = model(**inputs)

         for i, example_index in enumerate(example_indices):
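Stock transformers models of this era do not declare cls_index/p_mask, and Python raises TypeError on unexpected keyword arguments, so with such models the None-valued keys have to be dropped before the call. A small helper sketch (not from this commit):

def call_model(model, inputs):
    # Drop None-valued keys so model(**inputs) also works for forwards
    # that do not define is_impossible/cls_index/p_mask parameters.
    return model(**{k: v for k, v in inputs.items() if v is not None})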
@@ -261,7 +263,7 @@ def evaluate(args, model, tokenizer, prefix=""):
                                        args.max_answer_length, output_prediction_file,
                                        output_nbest_file, output_null_log_odds_file, args.predict_file,
                                        model.config.start_n_top, model.config.end_n_top,
-                                       args.version_2_with_negative, tokenizer, args.verbose_logging)
+                                       args.version_2_with_negative, tokenizer, args.do_lower_case, args.verbose_logging)
     elif args.model_type in ['roberta']:
         write_predictions(examples, features, all_results, args.n_best_size,
                           args.max_answer_length, args.do_lower_case, output_prediction_file,
@@ -334,8 +336,9 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
     else:
         all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
         all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
+        all_is_impossible = torch.tensor([1.0 if f.is_impossible else 0.0 for f in features], dtype=torch.float)
         dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
-                                all_start_positions, all_end_positions,
+                                all_start_positions, all_end_positions, all_is_impossible,
                                 all_cls_index, all_p_mask)

if output_examples:
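The batch indices used in train() above (batch[5] = is_impossible, batch[6] = cls_index, batch[7] = p_mask) follow from this column order. A runnable sketch with dummy tensors, confirming the 8-column layout:

import torch
from torch.utils.data import TensorDataset

# Dummy stand-ins for the per-feature tensors (2 features, seq len 8).
n, seq_len = 2, 8
all_input_ids = torch.zeros(n, seq_len, dtype=torch.long)
all_input_mask = torch.ones(n, seq_len, dtype=torch.long)
all_segment_ids = torch.zeros(n, seq_len, dtype=torch.long)
all_start_positions = torch.zeros(n, dtype=torch.long)
all_end_positions = torch.zeros(n, dtype=torch.long)
all_is_impossible = torch.zeros(n, dtype=torch.float)
all_cls_index = torch.zeros(n, dtype=torch.long)
all_p_mask = torch.zeros(n, seq_len, dtype=torch.float)

dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                        all_start_positions, all_end_positions, all_is_impossible,
                        all_cls_index, all_p_mask)
batch = dataset[0]
assert len(batch) == 8  # columns 5/6/7 are is_impossible/cls_index/p_mask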
120 changes: 74 additions & 46 deletions examples/utils_squad.py
@@ -273,55 +273,84 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
             # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
             # Original TF implem also keep the classification token (set to 0) (not sure why...)
             p_mask = []

-            # CLS token at the beginning
-            if not cls_token_at_end:
-                tokens.append(cls_token)
-                segment_ids.append(cls_token_segment_id)
-                p_mask.append(0)
-                cls_index = 0
-
-            # Query
-            for token in query_tokens:
-                tokens.append(token)
-                segment_ids.append(sequence_a_segment_id)
-                p_mask.append(1)
-
-            # SEP token
-            tokens.append(sep_token)
-            segment_ids.append(sequence_a_segment_id)
-            p_mask.append(1)
-
-            if sep_token_extra:
-                tokens.append(sep_token)
-                segment_ids.append(sequence_a_segment_id)
-                p_mask.append(1)
-
-            # Paragraph
-            for i in range(doc_span.length):
-                split_token_index = doc_span.start + i
-                token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
-
-                is_max_context = _check_is_max_context(doc_spans, doc_span_index,
-                                                       split_token_index)
-                token_is_max_context[len(tokens)] = is_max_context
-                tokens.append(all_doc_tokens[split_token_index])
-                segment_ids.append(sequence_b_segment_id)
-                p_mask.append(0)
-            paragraph_len = doc_span.length
-
-            # SEP token
-            tokens.append(sep_token)
-            segment_ids.append(sequence_b_segment_id)
-            p_mask.append(1)
-
-            # CLS token at the end
-            if cls_token_at_end:
-                tokens.append(cls_token)
-                segment_ids.append(cls_token_segment_id)
-                p_mask.append(0)
-                cls_index = len(tokens) - 1  # Index of classification token
+            if cls_token_at_end:
+                # Paragraph
+                for i in range(doc_span.length):
+                    split_token_index = doc_span.start + i
+                    token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
+                    is_max_context = _check_is_max_context(doc_spans, doc_span_index, split_token_index)
+                    token_is_max_context[len(tokens)] = is_max_context
+                    tokens.append(all_doc_tokens[split_token_index])
+                    segment_ids.append(sequence_a_segment_id)
+                    p_mask.append(0)
+                paragraph_len = doc_span.length
+
+                # SEP token
+                tokens.append(sep_token)
+                segment_ids.append(sequence_a_segment_id)
+                p_mask.append(1)
+
+                if sep_token_extra:
+                    tokens.append(sep_token)
+                    segment_ids.append(sequence_a_segment_id)
+                    p_mask.append(1)
+
+                # Query
+                for token in query_tokens:
+                    tokens.append(token)
+                    segment_ids.append(sequence_b_segment_id)
+                    p_mask.append(1)
+
+                # SEP token
+                tokens.append(sep_token)
+                segment_ids.append(sequence_b_segment_id)
+                p_mask.append(1)
+
+                # CLS token at the end
+                tokens.append(cls_token)
+                segment_ids.append(cls_token_segment_id)
+                p_mask.append(0)
+                cls_index = len(tokens) - 1  # Index of classification token
+            else:
+                # CLS token at the beginning
+                tokens.append(cls_token)
+                segment_ids.append(cls_token_segment_id)
+                p_mask.append(0)
+                cls_index = 0
+
+                # Query
+                for token in query_tokens:
+                    tokens.append(token)
+                    segment_ids.append(sequence_a_segment_id)
+                    p_mask.append(1)
+
+                # SEP token
+                tokens.append(sep_token)
+                segment_ids.append(sequence_a_segment_id)
+                p_mask.append(1)
+
+                if sep_token_extra:
+                    tokens.append(sep_token)
+                    segment_ids.append(sequence_a_segment_id)
+                    p_mask.append(1)
+
+                # Paragraph
+                for i in range(doc_span.length):
+                    split_token_index = doc_span.start + i
+                    token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
+                    is_max_context = _check_is_max_context(doc_spans, doc_span_index, split_token_index)
+                    token_is_max_context[len(tokens)] = is_max_context
+                    tokens.append(all_doc_tokens[split_token_index])
+                    segment_ids.append(sequence_b_segment_id)
+                    p_mask.append(0)
+                paragraph_len = doc_span.length
+
+                # SEP token
+                tokens.append(sep_token)
+                segment_ids.append(sequence_b_segment_id)
+                p_mask.append(1)

             input_ids = tokenizer.convert_tokens_to_ids(tokens)

             # The mask has 1 for real tokens and 0 for padding tokens. Only real
@@ -356,7 +385,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                     end_position = 0
                     span_is_impossible = True
                 else:
-                    doc_offset = len(query_tokens) + special_tokens_count - 1
+                    doc_offset = 0 if cls_token_at_end else len(query_tokens) + special_tokens_count - 1
                     start_position = tok_start_position - doc_start + doc_offset
                     end_position = tok_end_position - doc_start + doc_offset

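Taken together with the doc_offset change above, the two layouts built in convert_examples_to_features can be summarized in a short sketch (token order only; segment ids, p_mask, and sep_token_extra omitted; the marker strings are illustrative):

def layout(query_tokens, doc_tokens, cls_token_at_end):
    # Sketch of the two sequence orders built above (no sep_token_extra).
    if cls_token_at_end:   # XLNet-style: paragraph first, CLS last
        tokens = doc_tokens + ['<sep>'] + query_tokens + ['<sep>', '<cls>']
        doc_offset = 0     # paragraph starts at position 0
    else:                  # BERT/RoBERTa-style: CLS first, query first
        tokens = ['[CLS]'] + query_tokens + ['[SEP]'] + doc_tokens + ['[SEP]']
        doc_offset = len(query_tokens) + 2   # CLS + SEP precede the paragraph
    return tokens, doc_offset

print(layout(['who', '?'], ['some', 'context'], cls_token_at_end=True)[1])   # 0
print(layout(['who', '?'], ['some', 'context'], cls_token_at_end=False)[1])  # 4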
@@ -595,11 +624,11 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
                     tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
                 else:
                     tok_text = " ".join(tok_tokens)

                     # De-tokenize WordPieces that have been split off.
                     tok_text = tok_text.replace(" ##", "")
                     tok_text = tok_text.replace("##", "")

                 # Clean whitespace
                 tok_text = tok_text.strip()
                 tok_text = " ".join(tok_text.split())
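The WordPiece cleanup shown in this hunk can be exercised on its own; a tiny example:

def detokenize_wordpieces(tok_tokens):
    # Join WordPieces and strip the '##' continuation markers, as above.
    tok_text = " ".join(tok_tokens)
    tok_text = tok_text.replace(" ##", "")
    tok_text = tok_text.replace("##", "")
    return " ".join(tok_text.strip().split())

assert detokenize_wordpieces(["un", "##answer", "##able"]) == "unanswerable"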
@@ -700,7 +729,7 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
                                output_nbest_file,
                                output_null_log_odds_file, orig_data_file,
                                start_n_top, end_n_top, version_2_with_negative,
-                               tokenizer, verbose_logging):
+                               tokenizer, do_lower_case, verbose_logging):
     """ XLNet write prediction logic (more complex than Bert's).
         Write final predictions to the json file and log-odds of null if needed.
@@ -812,8 +841,7 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
                 tok_text = " ".join(tok_text.split())
                 orig_text = " ".join(orig_tokens)

-                final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case,
-                                            verbose_logging)
+                final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)

                 if final_text in seen_predictions:
                     continue
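Threading do_lower_case through as a parameter avoids reading tokenizer.do_lower_case, which BPE-based tokenizers such as RoBERTa's may not define. A sketch of the failure mode being avoided (hypothetical class, not from this commit):

class BPETokenizerSketch:
    pass  # hypothetical tokenizer without a do_lower_case attribute

tok = BPETokenizerSketch()
# tok.do_lower_case would raise AttributeError; an explicit flag (here the
# args.do_lower_case CLI value in run_squad.py) sidesteps the attribute.
do_lower_case = getattr(tok, 'do_lower_case', False)
assert do_lower_case is False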
2 changes: 1 addition & 1 deletion transformers/__init__.py
@@ -88,7 +88,7 @@
                            XLMForQuestionAnswering, XLMForQuestionAnsweringSimple,
                            XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_roberta import (RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
-                               RobertaForMultipleChoice, RobertaForQuestionAnswering,
+                               RobertaForMultipleChoice, RobertaForQuestionAnswering, RobertaForQuestionAnsweringComplex,
                                ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
                                   DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
7 changes: 7 additions & 0 deletions transformers/configuration_roberta.py
@@ -33,3 +33,10 @@

 class RobertaConfig(BertConfig):
     pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
+
+    def __init__(self,
+                 **kwargs):
+        super(RobertaConfig, self).__init__(**kwargs)
+
+        self.start_n_top = 5
+        self.end_n_top = 5
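start_n_top and end_n_top are the beam sizes of the XLNet-style answer head that write_predictions_extended consumes via model.config: the top start_n_top start positions are kept, end positions are scored for each, and the result is start_n_top * end_n_top candidate spans. A minimal sketch of that search with dummy logits:

import torch

start_n_top, end_n_top = 5, 5
seq_len = 20
start_logits = torch.randn(seq_len)

# Keep the best start positions, then (in the real head) re-score ends
# conditioned on each start; dummy end logits stand in for that step here.
start_top_log_probs, start_top_index = start_logits.topk(start_n_top)
end_logits = torch.randn(start_n_top, seq_len)  # one row per kept start
end_top_log_probs, end_top_index = end_logits.topk(end_n_top, dim=-1)

# start_n_top * end_n_top candidate (start, end) spans, as consumed by
# write_predictions_extended via model.config.start_n_top/end_n_top.
print(start_top_index.shape, end_top_index.shape)  # (5,), (5, 5)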
