Merge pull request #2 from erenup/run_multiple_choice_add_doc
Run multiple choice add doc
erenup authored Sep 16, 2019
2 parents 84b9d1c + 5882c44 commit 5a81e79
Showing 5 changed files with 198 additions and 39 deletions.
37 changes: 36 additions & 1 deletion examples/README.md
@@ -8,7 +8,8 @@ similar API between the different models.
| [Language Model fine-tuning](#language-model-fine-tuning) | Fine-tuning the library models for language modeling on a text dataset. Causal language modeling for GPT/GPT-2, masked language modeling for BERT/RoBERTa. |
| [Language Generation](#language-generation) | Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL and XLNet. |
| [GLUE](#glue) | Examples running BERT/XLM/XLNet/RoBERTa on the 9 GLUE tasks. Examples feature distributed training as well as half-precision. |
| [SQuAD](#squad) | Using BERT for question answering, examples with distributed training. |
| [Multiple Choice](#multiple-choice) | Examples running BERT/XLNet/RoBERTa on the SWAG/RACE/ARC tasks. |

## Language model fine-tuning

@@ -282,6 +283,40 @@ The results are the following:
loss = 0.04755385363816904
```

## Multiple Choice

Based on the script [`run_multiple_choice.py`](https://github.com/huggingface/pytorch-transformers/blob/master/examples/single_model_scripts/run_multiple_choice.py).

#### Fine-tuning on SWAG
Download the [SWAG](https://github.com/rowanz/swagaf/tree/master/data) data.

```
# Training on 4 Tesla V100 (16GB) GPUs
export SWAG_DIR=/path/to/swag_data_dir
python ./examples/single_model_scripts/run_multiple_choice.py \
--model_type roberta \
--task_name swag \
--model_name_or_path roberta-base \
--do_train \
--do_eval \
--do_lower_case \
--data_dir $SWAG_DIR \
--learning_rate 5e-5 \
--num_train_epochs 3 \
--max_seq_length 80 \
--output_dir models_bert/swag_base \
--per_gpu_eval_batch_size=16 \
--per_gpu_train_batch_size=16 \
--gradient_accumulation_steps 2 \
--overwrite_output
```
Training with the defined hyper-parameters yields the following results:
```
***** Eval results *****
eval_acc = 0.8338998300509847
eval_loss = 0.44457291918821606
```
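
To evaluate a saved checkpoint without retraining, the script falls back to `--model_name_or_path` as the output directory when `--do_train` is omitted (see the change to `main()` further down in this diff). A minimal sketch reusing the flags above; the checkpoint path is illustrative:

```
# Hypothetical evaluation-only run on the checkpoint produced above
export SWAG_DIR=/path/to/swag_data_dir
python ./examples/single_model_scripts/run_multiple_choice.py \
--model_type roberta \
--task_name swag \
--model_name_or_path models_bert/swag_base \
--do_eval \
--do_lower_case \
--data_dir $SWAG_DIR \
--max_seq_length 80 \
--output_dir models_bert/swag_base \
--per_gpu_eval_batch_size=16
```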

## SQuAD

Based on the script [`run_squad.py`](https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_squad.py).
13 changes: 6 additions & 7 deletions examples/single_model_scripts/run_multiple_choice.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for multiple choice (Bert, XLM, XLNet)."""
""" Finetuning the library models for multiple choice (Bert, Roberta, XLNet)."""

from __future__ import absolute_import, division, print_function

@@ -44,7 +44,7 @@

logger = logging.getLogger(__name__)

ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig)), ())
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, RobertaConfig)), ())

MODEL_CLASSES = {
'bert': (BertConfig, BertForMultipleChoice, BertTokenizer),
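# (Assumption: the rest of MODEL_CLASSES is collapsed in this diff view. Given the RobertaConfig
# import above and the new RobertaForMultipleChoice head added below, the PR presumably also
# registers an entry along these lines, not shown in the hunk.)
# 'roberta': (RobertaConfig, RobertaForMultipleChoice, RobertaTokenizer),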
@@ -208,7 +208,6 @@ def train(args, train_dataset, model, tokenizer):


def evaluate(args, model, tokenizer, prefix="", test=False):
# Loop to handle MNLI double evaluation (matched, mis-matched)
eval_task_names = (args.task_name,)
eval_outputs_dirs = (args.output_dir,)

@@ -259,7 +258,7 @@ def evaluate(args, model, tokenizer, prefix="", test=False):
result = {"eval_acc": acc, "eval_loss": eval_loss}
results.update(result)

output_eval_file = os.path.join(eval_output_dir, "is_test_" + str(test) + "_eval_results.txt")
output_eval_file = os.path.join(eval_output_dir, "is_test_" + str(test).lower() + "_eval_results.txt")

with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(str(prefix) + " is test:" + str(test)))
@@ -522,9 +521,9 @@ def main():
if not args.do_train:
args.output_dir = args.model_name_or_path
checkpoints = [args.output_dir]
if args.eval_all_checkpoints: #can not use this to do test!! just for different paras
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
# if args.eval_all_checkpoints: # can not use this to do test!!
# checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
# logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints:
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
61 changes: 36 additions & 25 deletions examples/single_model_scripts/utils_multiple_choice.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BERT classification fine-tuning: utilities to work with GLUE tasks """
""" BERT multiple choice fine-tuning: utilities to work with multiple choice tasks of reading comprehension """

from __future__ import absolute_import, division, print_function

@@ -38,11 +38,10 @@ def __init__(self, example_id, question, contexts, endings, label=None):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
example_id: Unique id for the example.
contexts: list of str. The untokenized text of the first sequence (context of the corresponding question).
question: string. The untokenized text of the second sequence (question).
endings: list of str. The multiple choice options. Its length must equal the length of contexts.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
@@ -73,7 +72,7 @@ def __init__(self,


class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
"""Base class for data converters for multiple choice data sets."""

def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
@@ -84,7 +83,7 @@ def get_dev_examples(self, data_dir):
raise NotImplementedError()

def get_test_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
"""Gets a collection of `InputExample`s for the test set."""
raise NotImplementedError()

def get_labels(self):
@@ -93,7 +92,7 @@ def get_labels(self):


class RaceProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
"""Processor for the RACE data set."""

def get_train_examples(self, data_dir):
"""See base class."""
@@ -152,13 +151,13 @@ def _create_examples(self, lines, set_type):
InputExample(
example_id=race_id,
question=question,
contexts=[article, article, article, article],
contexts=[article, article, article, article], # this is not efficient but convenient
endings=[options[0], options[1], options[2], options[3]],
label=truth))
return examples

class SwagProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
"""Processor for the SWAG data set."""

def get_train_examples(self, data_dir):
"""See base class."""
@@ -172,9 +171,12 @@ def get_dev_examples(self, data_dir):

def get_test_examples(self, data_dir):
"""See base class."""
logger.info("LOOKING AT {} test".format(data_dir))
logger.info("LOOKING AT {} dev".format(data_dir))
raise ValueError(
"For swag testing, the input file does not contain a label column. It can not be tested in current code"
"setting!"
)
return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test")

def get_labels(self):
"""See base class."""
return ["0", "1", "2", "3"]
@@ -213,7 +215,7 @@ def _create_examples(self, lines, type):


class ArcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
"""Processor for the ARC data set (request from allennlp)."""

def get_train_examples(self, data_dir):
"""See base class."""
@@ -242,6 +244,7 @@ def _read_json(self, input_file):
def _create_examples(self, lines, type):
"""Creates examples for the training and dev sets."""

# There are two types of labels. They should be normalized.
def normalize(truth):
if truth in "ABCD":
return ord(truth) - ord("A")
@@ -256,6 +259,7 @@ def normalize(truth):
four_choice = 0
five_choice = 0
other_choices = 0
# we drop examples that have more or fewer than four choices
for line in tqdm.tqdm(lines, desc="read arc data"):
data_raw = json.loads(line.strip("\n"))
if len(data_raw["question"]["choices"]) == 3:
@@ -274,7 +278,6 @@ def normalize(truth):
question = question_choices["stem"]
id = data_raw["id"]
options = question_choices["choices"]

if len(options) == 4:
examples.append(
InputExample(
@@ -328,13 +331,16 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
tokens_a = tokenizer.tokenize(context)
tokens_b = None
if example.question.find("_") != -1:
# this is for cloze questions
tokens_b = tokenizer.tokenize(example.question.replace("_", ending))
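# e.g. (illustrative): a cloze question "The man _ the ball." with ending "kicked"
# is tokenized here as "The man kicked the ball."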
else:
tokens_b = tokenizer.tokenize(example.question)
tokens_b += [sep_token]
if sep_token_extra:
tokens_b += [sep_token]
tokens_b += tokenizer.tokenize(ending)
tokens_b = tokenizer.tokenize(example.question + " " + ending)
# you can add a sep token between question and ending; this does not make much difference.
# tokens_b = tokenizer.tokenize(example.question)
# tokens_b += [sep_token]
# if sep_token_extra:
# tokens_b += [sep_token]
# tokens_b += tokenizer.tokenize(ending)

special_tokens_count = 4 if sep_token_extra else 3
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count)
@@ -427,15 +433,20 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.

# However, since it is better not to remove tokens from the options and the question, you can choose to use a bigger
# max sequence length, or to pop only from the context.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
# if len(tokens_a) > len(tokens_b):
# tokens_a.pop()
# else:
# tokens_b.pop()
tokens_a.pop()
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
logger.info('Attention! You are removing tokens from token_b (this is fine for the SWAG task). '
'If you are training on ARC or RACE (you are popping question + options), '
'you should try a bigger max seq length!')
tokens_b.pop()
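
A quick sketch of the resulting behaviour (numbers are illustrative; 77 assumes `max_seq_length` 80 minus 3 special tokens, as in the SWAG command above):

```
# Illustrative only: the helper truncates in place, preferring to drop
# context tokens (tokens_a) before question + ending tokens (tokens_b).
tokens_a = ["ctx"] * 100   # stand-in for a long context
tokens_b = ["qa"] * 30     # stand-in for question + ending
_truncate_seq_pair(tokens_a, tokens_b, 77)
print(len(tokens_a), len(tokens_b))   # prints: 47 30
```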


processors = {
73 changes: 70 additions & 3 deletions pytorch_transformers/modeling_roberta.py
@@ -296,7 +296,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
Examples::
tokenizer = RoertaTokenizer.from_pretrained('roberta-base')
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForSequenceClassification.from_pretrained('roberta-base')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
@@ -338,8 +338,75 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_

return outputs # (loss), logits, (hidden_states), (attentions)


@add_start_docstrings("""Roberta Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
class RobertaForMultipleChoice(BertPreTrainedModel):
r"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
To match pre-training, the RoBERTa input sequence should be formatted with [CLS] and [SEP] tokens as follows:
(a) For sequence pairs:
``tokens: [CLS] is this jack ##son ##ville ? [SEP] [SEP] no it is not . [SEP]``
``token_type_ids: 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1``
(b) For single sequences:
``tokens: [CLS] the dog is hairy . [SEP]``
``token_type_ids: 0 0 0 0 0 0 0``
Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Segment token indices to indicate first and second portions of the inputs.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` corresponds to a `sentence B` token.
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Mask to avoid performing attention on padding token indices.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification loss.
**classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForMultipleChoice.from_pretrained('roberta-base')
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
labels = torch.tensor(1).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, classification_scores = outputs[:2]
"""
config_class = RobertaConfig
pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "roberta"
@@ -351,7 +418,7 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)

self.apply(self.init_weights)
self.init_weights()

def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None):
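
The body of `forward` is collapsed in this diff. As a rough sketch of the usual multiple-choice pattern (an assumption based on the analogous BERT head, not the exact merged code), the choices dimension is folded into the batch before the encoder runs:

```
# Sketch only: assumes the standard multiple-choice flattening, not the exact merged code.
num_choices = input_ids.shape[1]
flat_input_ids = input_ids.view(-1, input_ids.size(-1))        # (batch * num_choices, seq_len)
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
outputs = self.roberta(flat_input_ids, attention_mask=flat_attention_mask)
pooled_output = self.dropout(outputs[1])                        # one pooled vector per (example, choice)
logits = self.classifier(pooled_output)                         # (batch * num_choices, 1)
reshaped_logits = logits.view(-1, num_choices)                  # (batch, num_choices)
outputs = (reshaped_logits,) + outputs[2:]
if labels is not None:
    loss = CrossEntropyLoss()(reshaped_logits, labels)          # assumes CrossEntropyLoss is imported in this module
    outputs = (loss,) + outputs
```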