From 4812a5a7678fffef1751ba45ad5ae0b5e8854b70 Mon Sep 17 00:00:00 2001 From: erenup Date: Mon, 16 Sep 2019 11:50:18 +0800 Subject: [PATCH 1/4] add doc string --- .../run_multiple_choice.py | 13 ++-- .../utils_multiple_choice.py | 59 ++++++++------- pytorch_transformers/modeling_roberta.py | 71 ++++++++++++++++++- pytorch_transformers/modeling_xlnet.py | 51 ++++++++++++- 4 files changed, 158 insertions(+), 36 deletions(-) diff --git a/examples/single_model_scripts/run_multiple_choice.py b/examples/single_model_scripts/run_multiple_choice.py index 9784dfe94d5e15..c82d9b4c1fed5c 100644 --- a/examples/single_model_scripts/run_multiple_choice.py +++ b/examples/single_model_scripts/run_multiple_choice.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Finetuning the library models for multiple choice (Bert, XLM, XLNet).""" +""" Finetuning the library models for multiple choice (Bert, Roberta, XLNet).""" from __future__ import absolute_import, division, print_function @@ -44,7 +44,7 @@ logger = logging.getLogger(__name__) -ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig)), ()) +ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, RobertaConfig)), ()) MODEL_CLASSES = { 'bert': (BertConfig, BertForMultipleChoice, BertTokenizer), @@ -208,7 +208,6 @@ def train(args, train_dataset, model, tokenizer): def evaluate(args, model, tokenizer, prefix="", test=False): - # Loop to handle MNLI double evaluation (matched, mis-matched) eval_task_names = (args.task_name,) eval_outputs_dirs = (args.output_dir,) @@ -259,7 +258,7 @@ def evaluate(args, model, tokenizer, prefix="", test=False): result = {"eval_acc": acc, "eval_loss": eval_loss} results.update(result) - output_eval_file = os.path.join(eval_output_dir, "is_test_" + str(test) + "_eval_results.txt") + output_eval_file = os.path.join(eval_output_dir, "is_test_" + str(test).lower() + "_eval_results.txt") with open(output_eval_file, "w") as writer: logger.info("***** Eval results {} *****".format(str(prefix) + " is test:" + str(test))) @@ -522,9 +521,9 @@ def main(): if not args.do_train: args.output_dir = args.model_name_or_path checkpoints = [args.output_dir] - if args.eval_all_checkpoints: #can not use this to do test!! just for different paras - checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) - logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging + # if args.eval_all_checkpoints: # can not use this to do test!! + # checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) + # logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging logger.info("Evaluate the following checkpoints: %s", checkpoints) for checkpoint in checkpoints: global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" diff --git a/examples/single_model_scripts/utils_multiple_choice.py b/examples/single_model_scripts/utils_multiple_choice.py index 3159db94ea52e5..d8ce76f5501b71 100644 --- a/examples/single_model_scripts/utils_multiple_choice.py +++ b/examples/single_model_scripts/utils_multiple_choice.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" BERT classification fine-tuning: utilities to work with GLUE tasks """
+""" BERT multiple choice fine-tuning: utilities to work with multiple choice reading comprehension tasks """
 
 from __future__ import absolute_import, division, print_function
 
@@ -38,11 +38,10 @@ def __init__(self, example_id, question, contexts, endings, label=None):
         """Constructs a InputExample.
 
         Args:
-            guid: Unique id for the example.
-            text_a: string. The untokenized text of the first sequence. For single
-            sequence tasks, only this sequence must be specified.
-            text_b: (Optional) string. The untokenized text of the second sequence.
-            Only must be specified for sequence pair tasks.
+            example_id: Unique id for the example.
+            contexts: list of str. The untokenized text of the first sequence (context of the corresponding question).
+            question: string. The untokenized text of the second sequence (question).
+            endings: list of str. The untokenized options of the multiple choice question. Its length must equal the length of contexts.
             label: (Optional) string. The label of the example. This should be
             specified for train and dev examples, but not for test examples.
         """
@@ -73,7 +72,7 @@ def __init__(self,
 
 
 class DataProcessor(object):
-    """Base class for data converters for sequence classification data sets."""
+    """Base class for data converters for multiple choice data sets."""
 
     def get_train_examples(self, data_dir):
         """Gets a collection of `InputExample`s for the train set."""
@@ -84,7 +83,7 @@ def get_dev_examples(self, data_dir):
         raise NotImplementedError()
 
     def get_test_examples(self, data_dir):
-        """Gets a collection of `InputExample`s for the dev set."""
+        """Gets a collection of `InputExample`s for the test set."""
         raise NotImplementedError()
 
     def get_labels(self):
@@ -93,7 +92,7 @@
 
 
 class RaceProcessor(DataProcessor):
-    """Processor for the MRPC data set (GLUE version)."""
+    """Processor for the RACE data set."""
 
     def get_train_examples(self, data_dir):
         """See base class."""
@@ -152,13 +151,13 @@ def _create_examples(self, lines, set_type):
                     InputExample(
                         example_id=race_id,
                         question=question,
-                        contexts=[article, article, article, article],
+                        contexts=[article, article, article, article],  # this is not efficient but convenient
                         endings=[options[0], options[1], options[2], options[3]],
                         label=truth))
         return examples
 
 class SwagProcessor(DataProcessor):
-    """Processor for the MRPC data set (GLUE version)."""
+    """Processor for the SWAG data set."""
 
     def get_train_examples(self, data_dir):
         """See base class."""
@@ -172,9 +171,12 @@ def get_dev_examples(self, data_dir):
 
     def get_test_examples(self, data_dir):
         """See base class."""
-        logger.info("LOOKING AT {} test".format(data_dir))
+        logger.info("LOOKING AT {} dev".format(data_dir))
+        raise ValueError(
+            "For swag testing, the input file does not contain a label column. It cannot be tested in the current code "
+            "setting!"
+ ) return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test") - def get_labels(self): """See base class.""" return ["0", "1", "2", "3"] @@ -213,7 +215,7 @@ def _create_examples(self, lines, type): class ArcProcessor(DataProcessor): - """Processor for the MRPC data set (GLUE version).""" + """Processor for the ARC data set (request from allennlp).""" def get_train_examples(self, data_dir): """See base class.""" @@ -242,6 +244,7 @@ def _read_json(self, input_file): def _create_examples(self, lines, type): """Creates examples for the training and dev sets.""" + #There are two types of labels. They should be normalized def normalize(truth): if truth in "ABCD": return ord(truth) - ord("A") @@ -256,6 +259,7 @@ def normalize(truth): four_choice = 0 five_choice = 0 other_choices = 0 + # we deleted example which has more than or less than four choices for line in tqdm.tqdm(lines, desc="read arc data"): data_raw = json.loads(line.strip("\n")) if len(data_raw["question"]["choices"]) == 3: @@ -274,7 +278,6 @@ def normalize(truth): question = question_choices["stem"] id = data_raw["id"] options = question_choices["choices"] - if len(options) == 4: examples.append( InputExample( @@ -328,13 +331,16 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokens_a = tokenizer.tokenize(context) tokens_b = None if example.question.find("_") != -1: + #this is for cloze question tokens_b = tokenizer.tokenize(example.question.replace("_", ending)) else: - tokens_b = tokenizer.tokenize(example.question) - tokens_b += [sep_token] - if sep_token_extra: - tokens_b += [sep_token] - tokens_b += tokenizer.tokenize(ending) + tokens_b = tokenizer.tokenize(example.question + " " + ending) + # you can add seq token between quesiotn and ending. This does not make too much difference. + # tokens_b = tokenizer.tokenize(example.question) + # tokens_b += [sep_token] + # if sep_token_extra: + # tokens_b += [sep_token] + # tokens_b += tokenizer.tokenize(ending) special_tokens_count = 4 if sep_token_extra else 3 _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count) @@ -427,15 +433,18 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length): # one token at a time. This makes more sense than truncating an equal percent # of tokens from each, since if one sequence is very short then each token # that's truncated likely contains more information than a longer sequence. + + # However, since we'd better not to remove tokens of options and questions, you can choose to use a bigger + # length or only pop from context while True: total_length = len(tokens_a) + len(tokens_b) if total_length <= max_length: break - # if len(tokens_a) > len(tokens_b): - # tokens_a.pop() - # else: - # tokens_b.pop() - tokens_a.pop() + if len(tokens_a) > len(tokens_b): + tokens_a.pop() + else: + logger.info('Attention! you are removing from question + options. 
Try to use a bigger max seq length!') + tokens_b.pop() processors = { diff --git a/pytorch_transformers/modeling_roberta.py b/pytorch_transformers/modeling_roberta.py index c8cb055d553431..cc5650a5ce0ae3 100644 --- a/pytorch_transformers/modeling_roberta.py +++ b/pytorch_transformers/modeling_roberta.py @@ -294,7 +294,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel): Examples:: - tokenizer = RoertaTokenizer.from_pretrained('roberta-base') + tokenizer = RobertaTokenizer.from_pretrained('roberta-base') model = RobertaForSequenceClassification.from_pretrained('roberta-base') input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 @@ -333,8 +333,75 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=No return outputs # (loss), logits, (hidden_states), (attentions) - +@add_start_docstrings("""Roberta Model with a multiple choice classification head on top (a linear layer on top of + the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """, + ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING) class RobertaForMultipleChoice(BertPreTrainedModel): + r""" + Inputs: + **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``: + Indices of input sequence tokens in the vocabulary. + The second dimension of the input (`num_choices`) indicates the number of choices to score. + To match pre-training, RoBerta input sequence should be formatted with [CLS] and [SEP] tokens as follows: + + (a) For sequence pairs: + + ``tokens: [CLS] is this jack ##son ##ville ? [SEP] [SEP] no it is not . [SEP]`` + + ``token_type_ids: 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1`` + + (b) For single sequences: + + ``tokens: [CLS] the dog is hairy . [SEP]`` + + ``token_type_ids: 0 0 0 0 0 0 0`` + + Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`. + See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and + :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. + **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``: + Segment token indices to indicate first and second portions of the inputs. + The second dimension of the input (`num_choices`) indicates the number of choices to score. + Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` + **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``: + Mask to avoid performing attention on padding token indices. + The second dimension of the input (`num_choices`) indicates the number of choices to score. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: + Mask to nullify selected heads of the self-attention modules. + Mask values selected in ``[0, 1]``: + ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. + **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for computing the multiple choice classification loss. + Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension + of the input tensors. 
(see `input_ids` above) + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Classification loss. + **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension + of the input tensors. (see `input_ids` above). + Classification scores (before SoftMax). + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + tokenizer = RobertaTokenizer.from_pretrained('roberta-base') + model = RobertaForMultipleChoice.from_pretrained('roberta-base') + choices = ["Hello, my dog is cute", "Hello, my cat is amazing"] + input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices + labels = torch.tensor(1).unsqueeze(0) # Batch size 1 + outputs = model(input_ids, labels=labels) + loss, classification_scores = outputs[:2] + + """ config_class = RobertaConfig pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP base_model_prefix = "roberta" diff --git a/pytorch_transformers/modeling_xlnet.py b/pytorch_transformers/modeling_xlnet.py index b8045b92235446..5c94b186f46bf0 100644 --- a/pytorch_transformers/modeling_xlnet.py +++ b/pytorch_transformers/modeling_xlnet.py @@ -1152,9 +1152,56 @@ def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mas return outputs # return (loss), logits, mems, (hidden states), (attentions) - +@add_start_docstrings("""XLNet Model with a multiple choice classification head on top (a linear layer on top of + the pooled output and a softmax) e.g. for RACE/SWAG tasks. """, + XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING) class XLNetForMultipleChoice(XLNetPreTrainedModel): r""" + Inputs: + **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``: + Indices of input sequence tokens in the vocabulary. + The second dimension of the input (`num_choices`) indicates the number of choices to scores. + **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``: + Segment token indices to indicate first and second portions of the inputs. + The second dimension of the input (`num_choices`) indicates the number of choices to score. + Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` + **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``: + Mask to avoid performing attention on padding token indices. + The second dimension of the input (`num_choices`) indicates the number of choices to score. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. 
+ **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: + Mask to nullify selected heads of the self-attention modules. + Mask values selected in ``[0, 1]``: + ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. + **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for computing the multiple choice classification loss. + Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension + of the input tensors. (see `input_ids` above) + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Classification loss. + **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension + of the input tensors. (see `input_ids` above). + Classification scores (before SoftMax). + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased') + model = XLNetForMultipleChoice.from_pretrained('xlnet-base-cased') + choices = ["Hello, my dog is cute", "Hello, my cat is amazing"] + input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices + labels = torch.tensor(1).unsqueeze(0) # Batch size 1 + outputs = model(input_ids, labels=labels) + loss, classification_scores = outputs[:2] """ def __init__(self, config): @@ -1251,7 +1298,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): Examples:: - tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048') + tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased') model = XLMForQuestionAnswering.from_pretrained('xlnet-large-cased') input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 start_positions = torch.tensor([1]) From 603b470a3d855c5187564701c475cfef5826c224 Mon Sep 17 00:00:00 2001 From: erenup Date: Mon, 16 Sep 2019 18:53:37 +0800 Subject: [PATCH 2/4] add warnning info --- examples/single_model_scripts/utils_multiple_choice.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/single_model_scripts/utils_multiple_choice.py b/examples/single_model_scripts/utils_multiple_choice.py index d8ce76f5501b71..7abcc5e1e9ea20 100644 --- a/examples/single_model_scripts/utils_multiple_choice.py +++ b/examples/single_model_scripts/utils_multiple_choice.py @@ -443,7 +443,9 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length): if len(tokens_a) > len(tokens_b): tokens_a.pop() else: - logger.info('Attention! you are removing from question + options. Try to use a bigger max seq length!') + logger.info('Attention! you are removing from token_b (swag task is ok). 
'
+                        'If you are training ARC and RACE (you are popping question + options), '
+                        'you should use a bigger max seq length!')
             tokens_b.pop()
 
 

From a9debaca3dfd3d12c3acf896e3e2272d0a087257 Mon Sep 17 00:00:00 2001
From: erenup
Date: Mon, 16 Sep 2019 19:55:24 +0800
Subject: [PATCH 3/4] fixed init_weight

---
 pytorch_transformers/modeling_roberta.py | 2 +-
 pytorch_transformers/modeling_xlnet.py   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_transformers/modeling_roberta.py b/pytorch_transformers/modeling_roberta.py
index 98bfa4202a18b1..2b64893c2e53ec 100644
--- a/pytorch_transformers/modeling_roberta.py
+++ b/pytorch_transformers/modeling_roberta.py
@@ -418,7 +418,7 @@ def __init__(self, config):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, 1)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
                 position_ids=None, head_mask=None):
diff --git a/pytorch_transformers/modeling_xlnet.py b/pytorch_transformers/modeling_xlnet.py
index 91390b9d6bc424..fa65d83b0ef38d 100644
--- a/pytorch_transformers/modeling_xlnet.py
+++ b/pytorch_transformers/modeling_xlnet.py
@@ -1065,7 +1065,7 @@ def __init__(self, config):
         self.sequence_summary = SequenceSummary(config)
         self.logits_proj = nn.Linear(config.d_model, 1)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None,

From 5882c442e52921d6e8755efccd7e11a2ae405bbe Mon Sep 17 00:00:00 2001
From: erenup
Date: Mon, 16 Sep 2019 22:38:08 +0800
Subject: [PATCH 4/4] add example usage

---
 examples/README.md | 37 ++++++++++++++++++++++++++++++++++++-
 1 file changed, 36 insertions(+), 1 deletion(-)

diff --git a/examples/README.md b/examples/README.md
index c47dc41433ae3a..3253e5481c4243 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -8,7 +8,8 @@ similar API between the different models.
 | [Language Model fine-tuning](#language-model-fine-tuning) | Fine-tuning the library models for language modeling on a text dataset. Causal language modeling for GPT/GPT-2, masked language modeling for BERT/RoBERTa. |
 | [Language Generation](#language-generation) | Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL and XLNet. |
 | [GLUE](#glue) | Examples running BERT/XLM/XLNet/RoBERTa on the 9 GLUE tasks. Examples feature distributed training as well as half-precision. |
-| [SQuAD](#squad) | Using BERT for question answering, examples with distributed training. |
+| [SQuAD](#squad) | Using BERT for question answering, examples with distributed training. |
+| [Multiple Choice](#multiple-choice) | Examples running BERT/XLNet/RoBERTa on the SWAG/RACE/ARC tasks. |
 
 ## Language model fine-tuning
 
@@ -282,6 +283,40 @@ The results are the following:
   loss = 0.04755385363816904
 ```
 
+## Multiple Choice
+
+Based on the script [`run_multiple_choice.py`]().
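+
+The multiple choice heads added in this PR score a fixed set of candidate endings for a given context. As a rough, illustrative sketch of what the script does under the hood (the sentences, the padding handling and the checkpoint below are placeholders, not part of `run_multiple_choice.py`):
+
+```
+import torch
+from pytorch_transformers import RobertaTokenizer
+from pytorch_transformers.modeling_roberta import RobertaForMultipleChoice
+
+tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
+model = RobertaForMultipleChoice.from_pretrained('roberta-base')  # or the output_dir of a fine-tuned run
+model.eval()
+
+question = "A man is sitting on a roof. He"
+endings = ["is ripping level tiles off.", "starts pulling up roofing on a roof."]
+
+# one (question + ending) sequence per choice, padded to the same length and
+# stacked into a (batch_size=1, num_choices, seq_len) tensor
+encoded = [tokenizer.encode(question + " " + ending, add_special_tokens=True) for ending in endings]
+max_len = max(len(ids) for ids in encoded)
+pad_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
+input_ids = torch.tensor([[ids + [pad_id] * (max_len - len(ids)) for ids in encoded]])
+attention_mask = (input_ids != pad_id).long()
+
+with torch.no_grad():
+    classification_scores = model(input_ids, attention_mask=attention_mask)[0]  # (1, num_choices), before softmax
+print(classification_scores.argmax(dim=-1))  # index of the highest-scoring ending
+```
+
+The important detail is the extra `num_choices` dimension: the script builds exactly this kind of 3D `input_ids` tensor (one row per option) before calling the model, as described in the docstrings added by this PR.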
+
+#### Fine-tuning on SWAG
+Download the [SWAG](https://github.com/rowanz/swagaf/tree/master/data) data
+
+```
+# training on 4 Tesla V100 (16GB) GPUs
+export SWAG_DIR=/path/to/swag_data_dir
+python ./examples/single_model_scripts/run_multiple_choice.py \
+--model_type roberta \
+--task_name swag \
+--model_name_or_path roberta-base \
+--do_train \
+--do_eval \
+--do_lower_case \
+--data_dir $SWAG_DIR \
+--learning_rate 5e-5 \
+--num_train_epochs 3 \
+--max_seq_length 80 \
+--output_dir models_bert/swag_base \
+--per_gpu_eval_batch_size=16 \
+--per_gpu_train_batch_size=16 \
+--gradient_accumulation_steps 2 \
+--overwrite_output
+```
+Training with the defined hyper-parameters yields the following results:
+```
+***** Eval results *****
+eval_acc = 0.8338998300509847
+eval_loss = 0.44457291918821606
+```
+
 ## SQuAD
 
 Based on the script [`run_squad.py`](https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_squad.py).