From 3125b094fead8f39b8d8b90cd059647221b03835 Mon Sep 17 00:00:00 2001
From: calpt <36051308+calpt@users.noreply.github.com>
Date: Sat, 17 Jul 2021 15:12:36 +0200
Subject: [PATCH 01/11] Dependency parsing head

---
 src/transformers/adapters/heads/__init__.py |   3 +
 .../adapters/{heads.py => heads/base.py}    |  10 +-
 .../adapters/heads/dependency_parsing.py    | 154 ++++++++++++++++++
 src/transformers/adapters/models/bert.py    |   6 +
 tests/test_adapter_heads.py                 |  37 ++++-
 5 files changed, 199 insertions(+), 11 deletions(-)
 create mode 100644 src/transformers/adapters/heads/__init__.py
 rename src/transformers/adapters/{heads.py => heads/base.py} (99%)
 create mode 100644 src/transformers/adapters/heads/dependency_parsing.py

diff --git a/src/transformers/adapters/heads/__init__.py b/src/transformers/adapters/heads/__init__.py
new file mode 100644
index 0000000000..e12be6a821
--- /dev/null
+++ b/src/transformers/adapters/heads/__init__.py
@@ -0,0 +1,3 @@
+# flake8: noqa
+from .base import *
+from .dependency_parsing import *
diff --git a/src/transformers/adapters/heads.py b/src/transformers/adapters/heads/base.py
similarity index 99%
rename from src/transformers/adapters/heads.py
rename to src/transformers/adapters/heads/base.py
index 497d185b8b..ab40f9669d 100644
--- a/src/transformers/adapters/heads.py
+++ b/src/transformers/adapters/heads/base.py
@@ -5,8 +5,8 @@
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
-from ..file_utils import ModelOutput
-from ..modeling_outputs import (
+from ...file_utils import ModelOutput
+from ...modeling_outputs import (
     MultipleChoiceModelOutput,
     QuestionAnsweringModelOutput,
     Seq2SeqModelOutput,
@@ -15,9 +15,9 @@
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from .composition import AdapterCompositionBlock, BatchSplit, Parallel, Stack
-from .model_mixin import ModelWithHeadsAdaptersMixin
-from .modeling import Activation_Function_Class
+from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, Stack
+from ..model_mixin import ModelWithHeadsAdaptersMixin
+from ..modeling import Activation_Function_Class
 
 logger = logging.getLogger(__name__)
 
diff --git a/src/transformers/adapters/heads/dependency_parsing.py b/src/transformers/adapters/heads/dependency_parsing.py
new file mode 100644
index 0000000000..adcf859eab
--- /dev/null
+++ b/src/transformers/adapters/heads/dependency_parsing.py
@@ -0,0 +1,154 @@
+"""
+Code taken and modified from: https://github.com/Adapter-Hub/hgiyt. Credits: "How Good is Your Tokenizer?
+On the Monolingual Performance of Multilingual Language Models" (Rust et al., 2021)
+https://arxiv.org/abs/2012.15613
+"""
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from .base import PredictionHead
+
+
+# Credit:
+# Class taken from https://github.com/yzhangcs/biaffine-parser
+class Biaffine(nn.Module):
+    def __init__(self, n_in, n_out=1, bias_x=True, bias_y=True):
+        super(Biaffine, self).__init__()
+
+        self.n_in = n_in
+        self.n_out = n_out
+        self.bias_x = bias_x
+        self.bias_y = bias_y
+        self.weight = nn.Parameter(torch.Tensor(n_out, n_in + bias_x, n_in + bias_y))
+        self.init_weights()
+
+    def extra_repr(self):
+        s = f"n_in={self.n_in}, n_out={self.n_out}"
+        if self.bias_x:
+            s += f", bias_x={self.bias_x}"
+        if self.bias_y:
+            s += f", bias_y={self.bias_y}"
+
+        return s
+
+    def init_weights(self):
+        nn.init.zeros_(self.weight)
+
+    def forward(self, x, y):
+        if self.bias_x:
+            x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
+        if self.bias_y:
+            y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
+
+        # [batch_size, n_out, seq_len, seq_len]
+        s = torch.einsum("bxi,oij,byj->boxy", x, self.weight, y)
+        return s
+
+
+class BiaffineParsingHead(PredictionHead):
+    """
+    Credit: G. Glavaš & I. Vulić
+    Based on paper "Is Supervised Syntactic Parsing Beneficial for Language Understanding? An Empirical Investigation"
+    (https://arxiv.org/pdf/2008.06788.pdf)
+    """
+
+    def __init__(self, model, head_name, num_labels=2, id2label=None):
+        super().__init__(head_name)
+        self.config = {
+            "head_type": "dependency_parsing",
+            "num_labels": num_labels,
+            "label2id": {label: id_ for id_, label in id2label.items()} if id2label is not None else None,
+        }
+        self.model_config = model.config
+        self.build(model)
+
+    def build(self, model):
+        self.biaffine_arcs = Biaffine(n_in=model.config.hidden_size, bias_x=True, bias_y=False)
+        self.biaffine_rels = Biaffine(
+            n_in=model.config.hidden_size, n_out=self.config["num_labels"], bias_x=True, bias_y=True
+        )
+
+        self.dropout = nn.Dropout(model.config.hidden_dropout_prob)
+
+        self.loss_fn = CrossEntropyLoss()
+
+        self.train(model.training)  # make sure training mode is consistent
+
+    def forward(
+        self, outputs, cls_output=None, attention_mask=None, return_dict=False, word_starts=None, labels_arcs=None, labels_rels=None, **kwargs
+    ):
+        outs = self.dropout(outputs[0])
+        word_outputs_deps = self._merge_subword_tokens(outs, word_starts)
+
+        # adding the CLS representation as the representation for the "root" parse token
+        # cls_output = self.pooler_activation(self.pooler_dense(outs[:, 0]))
+        cls_output = outs[:, 0]
+        word_outputs_heads = torch.cat([cls_output.unsqueeze(1), word_outputs_deps], dim=1)
+
+        arc_preds = self.biaffine_arcs(word_outputs_deps, word_outputs_heads)
+        arc_preds = arc_preds.squeeze()
+        outputs = (arc_preds,)
+
+        rel_preds = self.biaffine_rels(word_outputs_deps, word_outputs_heads)
+        rel_preds = rel_preds.permute(0, 2, 3, 1)
+        outputs = (rel_preds,) + outputs
+
+        loss = self._get_loss(arc_preds, rel_preds, labels_arcs, labels_rels, self.loss_fn)
+
+        # TODO-AH return_dict
+        outputs = (loss,) + outputs
+
+        if len(arc_preds.shape) == 2:
+            return loss, rel_preds, arc_preds.unsqueeze(0)
+        return outputs
+
+    def _merge_subword_tokens(self, subword_outputs, word_starts):
+        instances = []
+        max_seq_length = subword_outputs.shape[1]
+
+        # handling instance by instance
+        for i in range(len(subword_outputs)):
+            subword_vecs = subword_outputs[i]
+            word_vecs = []
+            starts = word_starts[i]
+            mask = starts.ne(self.model_config.pad_token_id)
+            starts = starts[mask]
+            for j in range(len(starts) - 1):
+                if starts[j + 1] <= 0:
+                    break
+
+                start = starts[j]
+                end = starts[j + 1]
+                vecs_range = subword_vecs[start:end]
+                word_vecs.append(torch.mean(vecs_range, 0).unsqueeze(0))
+
+            instances.append(word_vecs)
+
+        t_insts = []
+        zero_tens = torch.zeros(self.model_config.hidden_size).unsqueeze(0)
+        zero_tens = zero_tens.to("cuda" if torch.cuda.is_available() else "cpu")
+
+        for inst in instances:
+            if len(inst) < max_seq_length:
+                for i in range(max_seq_length - len(inst)):
+                    inst.append(zero_tens)
+            t_insts.append(torch.cat(inst, dim=0).unsqueeze(0))
+
+        w_tens = torch.cat(t_insts, dim=0)
+        return w_tens
+
+    def _get_loss(self, arc_preds, rel_preds, labels_arc, labels_rel, loss_fn):
+        if len(arc_preds.shape) == 2:
+            arc_preds = arc_preds.unsqueeze(0)
+
+        mask = labels_arc.ne(self.model_config.pad_token_id)
+        arc_scores, arcs = arc_preds[mask], labels_arc[mask]
+        loss = loss_fn(arc_scores, arcs)
+
+        rel_scores, rels = rel_preds[mask], labels_rel[mask]
+        rel_scores = rel_scores[torch.arange(len(arcs)), arcs]
+        rel_loss = loss_fn(rel_scores, rels)
+        loss += rel_loss
+
+        return loss
diff --git a/src/transformers/adapters/models/bert.py b/src/transformers/adapters/models/bert.py
index 8cc7a0107d..f95811cb76 100644
--- a/src/transformers/adapters/models/bert.py
+++ b/src/transformers/adapters/models/bert.py
@@ -5,6 +5,7 @@
 from ..composition import AdapterCompositionBlock, parse_composition
 from ..heads import (
+    BiaffineParsingHead,
     ClassificationHead,
     ModelWithFlexibleHeadsAdaptersMixin,
     MultiLabelClassificationHead,
@@ -179,6 +180,7 @@ class BertModelHeadsMixin(ModelWithFlexibleHeadsAdaptersMixin):
         "tagging": TaggingHead,
         "multiple_choice": MultipleChoiceHead,
         "question_answering": QuestionAnsweringHead,
+        "dependency_parsing": BiaffineParsingHead,
     }
 
     def add_classification_head(
@@ -256,3 +258,7 @@ def add_qa_head(
     ):
         head = QuestionAnsweringHead(self, head_name, num_labels, layers, activation_function, id2label)
         self.add_prediction_head(head, overwrite_ok)
+
+    def add_dependency_parsing_head(self, head_name, num_labels=2, overwrite_ok=False, id2label=None):
+        head = BiaffineParsingHead(self, head_name, num_labels, id2label)
+        self.add_prediction_head(head, overwrite_ok)
diff --git a/tests/test_adapter_heads.py b/tests/test_adapter_heads.py
index b76fa55968..0ff5e99422 100644
--- a/tests/test_adapter_heads.py
+++ b/tests/test_adapter_heads.py
@@ -47,7 +47,7 @@ def run_prediction_head_test(
 
     def test_classification_head(self):
         if not hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_classification_head"):
-            return
+            self.skipTest("No classification head")
 
         model1, model2 = create_twin_models(AutoModelWithHeads, self.config)
 
@@ -58,7 +58,7 @@ def test_multiple_choice_head(self):
         if not hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_multiple_choice_head"):
-            return
+            self.skipTest("No multiple choice head")
 
         model1, model2 = create_twin_models(AutoModelWithHeads, self.config)
 
@@ -71,18 +71,20 @@ def test_tagging_head(self):
         if not hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_tagging_head"):
-            return
+            self.skipTest("No tagging head")
 
         model1, model2 = create_twin_models(AutoModelWithHeads, self.config)
 
         model1.add_tagging_head("dummy")
         label_dict = {}
         label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
-        self.run_prediction_head_test(model1, model2, "dummy", output_shape=(1, 128, 2), label_dict=label_dict)
+        self.run_prediction_head_test(
+            model1, model2, "dummy", output_shape=(1, self.seq_length, 2), label_dict=label_dict
+        )
 
     def test_qa_head(self):
         if not hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_qa_head"):
-            return
+            self.skipTest("No QA head")
 
         model1, model2 = create_twin_models(AutoModelWithHeads, self.config)
 
@@ -90,7 +92,30 @@ def test_qa_head(self):
         model1.add_qa_head("dummy")
         label_dict = {}
         label_dict["start_positions"] = torch.zeros(self.batch_size, dtype=torch.long, device=torch_device)
         label_dict["end_positions"] = torch.zeros(self.batch_size, dtype=torch.long, device=torch_device)
-        self.run_prediction_head_test(model1, model2, "dummy", output_shape=(1, 128), label_dict=label_dict)
+        self.run_prediction_head_test(
+            model1, model2, "dummy", output_shape=(1, self.seq_length), label_dict=label_dict
+        )
+
+    def test_dependency_parsing_head(self):
+        if not hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_dependency_parsing_head"):
+            self.skipTest("No dependency parsing head")
+
+        model1, model2 = create_twin_models(AutoModelWithHeads, self.config)
+
+        model1.add_dependency_parsing_head("dummy")
+        label_dict = {}
+        label_dict["labels_arcs"] = torch.zeros(
+            (self.batch_size, self.seq_length), dtype=torch.long, device=torch_device
+        )
+        label_dict["labels_rels"] = torch.zeros(
+            (self.batch_size, self.seq_length), dtype=torch.long, device=torch_device
+        )
+        label_dict["word_starts"] = torch.zeros(
+            (self.batch_size, self.seq_length), dtype=torch.long, device=torch_device
+        )
+        self.run_prediction_head_test(
+            model1, model2, "dummy", output_shape=(1, self.seq_length, self.seq_length + 1, 2), label_dict=label_dict
+        )
 
     def test_delete_head(self):
         model = AutoModelWithHeads.from_config(self.config())

From 5c2e7e6ad6a643b510a64635ccc3ff38cd6af27a Mon Sep 17 00:00:00 2001
From: calpt <36051308+calpt@users.noreply.github.com>
Date: Sat, 17 Jul 2021 15:35:22 +0200
Subject: [PATCH 02/11] Copied example script from hgiyt

---
 examples/dependency-parsing/run_udp.py    | 328 ++++++++++++++++++++++
 examples/dependency-parsing/ud_dataset.py | 234 +++++++++++++++
 examples/dependency-parsing/utils_udp.py  | 310 ++++++++++++++++++++
 3 files changed, 872 insertions(+)
 create mode 100644 examples/dependency-parsing/run_udp.py
 create mode 100644 examples/dependency-parsing/ud_dataset.py
 create mode 100644 examples/dependency-parsing/utils_udp.py

diff --git a/examples/dependency-parsing/run_udp.py b/examples/dependency-parsing/run_udp.py
new file mode 100644
index 0000000000..7618f54576
--- /dev/null
+++ b/examples/dependency-parsing/run_udp.py
@@ -0,0 +1,328 @@
+import logging
+import os
+import sys
+from dataclasses import dataclass, field
+from typing import Dict, Optional
+
+from modeling_biaffine import BertForBiaffineParsing
+from transformers import (
+    AdapterArguments,
+    AdapterConfig,
+    AdapterType,
+    AutoConfig,
+    AutoTokenizer,
+    HfArgumentParser,
+    set_seed,
+    setup_task_adapter_training,
+)
+from ud_dataset import Split, UDDataset
+from utils_udp import UD_HEAD_LABELS, DependencyParsingTrainer, UDTrainingArguments
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+ """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}, + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}, + ) + use_fast: bool = field(default=False, metadata={"help": "Set this flag to use fast tokenization."}) + # If you want to tweak more attributes on your tokenizer, you should do it in a distinct script, + # or just modify its tokenizer_config.json. + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}, + ) + replace_embeddings: bool = field(default=False, metadata={"help": "Whether or not to replace embeddings."}) + leave_out_twelvth: bool = field( + default=False, metadata={"help": "Whether or not to leave out adapters in twelvth layer"} + ) + do_lower_case: bool = field(default=False, metadata={"help": "Set this flag when using uncased model/tokenizer"}) + is_japanese: bool = field(default=False, metadata={"help": "Set this to true when using Japanese model/tokenizer"}) + mecab_dir: Optional[str] = field( + default=None, metadata={"help": "Path to mecab installation. Required when using Japanese model/tokenizer"} + ) + mecab_dic_dir: Optional[str] = field( + default=None, metadata={"help": "Path to mecab dictionary. Required when using Japanese model/tokenizer"} + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + data_dir: str = field(metadata={"help": "Path to train, dev, and test data files."}) + max_seq_length: int = field( + default=128, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets."}, + ) + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, UDTrainingArguments, AdapterArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args, adapter_args = parser.parse_json_file( + json_file=os.path.abspath(sys.argv[1]) + ) + else: + (model_args, data_args, training_args, adapter_args,) = parser.parse_args_into_dataclasses() + + if ( + os.path.exists(training_args.output_dir) + and os.listdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 
+ ) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, + ) + logger.warning( + "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", + training_args.local_rank, + training_args.device, + training_args.n_gpu, + bool(training_args.local_rank != -1), + training_args.fp16, + ) + logger.info("Training/evaluation parameters %s", training_args) + + # Set seed + set_seed(training_args.seed) + + # Prepare for UD dependency parsing task + labels = UD_HEAD_LABELS + label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)} + num_labels = len(labels) + + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + num_labels=num_labels, + id2label=label_map, + label2id={label: i for i, label in enumerate(labels)}, + cache_dir=model_args.cache_dir, + pad_token_id=-1, + ) + + if model_args.is_japanese: + assert model_args.mecab_dir is not None + assert model_args.mecab_dic_dir is not None + + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast, + do_lower_case=model_args.do_lower_case, + mecab_kwargs={"mecab_option": f"-r {model_args.mecab_dir} -d {model_args.mecab_dic_dir}"} + if model_args.is_japanese + else None, + ) + + model = BertForBiaffineParsing.from_pretrained( + model_args.model_name_or_path, config=config, cache_dir=model_args.cache_dir, + ) + + # Setup adapters + task_name = "udp" + language = adapter_args.language + if model_args.replace_embeddings: + model.resize_token_embeddings(len(tokenizer)) + + if model_args.leave_out_twelvth: + logger.info("Leaving out 12") + leave_out = [11] + else: + leave_out = [] + + setup_task_adapter_training( + model, task_name, adapter_args, leave_out=leave_out, with_embeddings=model_args.replace_embeddings + ) + if model_args.leave_out_twelvth: + if language in model.base_model.encoder.layer._modules["11"].output.layer_text_lang_adapters: + del model.base_model.encoder.layer._modules["11"].output.layer_text_lang_adapters[language] + logger.info("Deleted language adapter " + language + " in layer 12") + if language in model.base_model.encoder.layer._modules["11"].attention.output.attention_text_lang_adapters: + del model.base_model.encoder.layer._modules["11"].attention.output.attention_text_lang_adapters[language] + logger.info("Deleted language adapter " + language + " in layer 12") + + if adapter_args.train_adapter: + if language: + adapter_names = [[language], [task_name]] + else: + adapter_names = [[task_name]] + else: + adapter_names = None + + train_dataset = ( + UDDataset( + data_dir=data_args.data_dir, + tokenizer=tokenizer, + labels=labels, + max_seq_length=data_args.max_seq_length, + overwrite_cache=data_args.overwrite_cache, + mode=Split.train, + ) + if training_args.do_train + else None + ) + + eval_dataset = ( + UDDataset( + data_dir=data_args.data_dir, + tokenizer=tokenizer, + labels=labels, + max_seq_length=data_args.max_seq_length, + overwrite_cache=data_args.overwrite_cache, + mode=Split.dev, + ) + if training_args.do_eval + else None + ) + + # Initialize our Trainer + trainer = DependencyParsingTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + 
do_save_full_model=not adapter_args.train_adapter, + do_save_adapters=adapter_args.train_adapter, + adapter_names=adapter_names, + ) + + # Training + if training_args.do_train: + trainer.train( + model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None + ) + trainer.save_model() + # For convenience, we also re-save the tokenizer to the same directory, + # so that you can share your model easily on huggingface.co/models =) + if trainer.is_world_master(): + tokenizer.save_pretrained(training_args.output_dir) + + # Evaluation + results = {} + if training_args.do_eval: + logger.info("*** Evaluate ***") + + result = trainer.evaluate() + + output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt") + if trainer.is_world_master(): + with open(output_eval_file, "w") as writer: + logger.info("***** Eval results *****") + for key, value in result.items(): + logger.info(" %s = %s", key, value) + writer.write("%s = %s\n" % (key, value)) + + results.update(result) + + # Predict + if training_args.do_predict: + test_dataset = UDDataset( + data_dir=data_args.data_dir, + tokenizer=tokenizer, + labels=labels, + max_seq_length=data_args.max_seq_length, + overwrite_cache=data_args.overwrite_cache, + mode=Split.test, + ) + + logging.info("*** Test ***") + + if training_args.store_best_model: + logger.info("Loading best model for predictions.") + + if adapter_args.train_adapter: + if language: + lang_adapter_config = AdapterConfig.load( + config="pfeiffer", non_linearity="gelu", reduction_factor=2, leave_out=leave_out + ) + model.load_adapter( + os.path.join(training_args.output_dir, "best_model", language) + if training_args.do_train + else adapter_args.load_lang_adapter, + AdapterType.text_lang, + config=lang_adapter_config, + load_as=language, + ) + task_adapter_config = AdapterConfig.load( + config="pfeiffer", non_linearity="gelu", reduction_factor=16, leave_out=leave_out + ) + model.load_adapter( + os.path.join(training_args.output_dir, "best_model", task_name) + if training_args.do_train + else adapter_args.load_task_adapter, + AdapterType.text_task, + config=task_adapter_config, + load_as=task_name, + ) + if model_args.leave_out_twelvth: + if language in model.base_model.encoder.layer._modules["11"].output.layer_text_lang_adapters: + del model.base_model.encoder.layer._modules["11"].output.layer_text_lang_adapters[language] + logger.info("Deleted language adapter " + language + " in layer 12") + if ( + language + in model.base_model.encoder.layer._modules["11"].attention.output.attention_text_lang_adapters + ): + del model.base_model.encoder.layer._modules[ + "11" + ].attention.output.attention_text_lang_adapters[language] + logger.info("Deleted language adapter " + language + " in layer 12") + + if language: + adapter_names = [[language], [task_name]] + else: + adapter_names = [[task_name]] + else: + trainer.model = BertForBiaffineParsing.from_pretrained( + os.path.join(training_args.output_dir, "best_model"), + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + ).to(training_args.device) + + predictions, _, metrics = trainer.predict(test_dataset) + + output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt") + if trainer.is_world_master(): + with open(output_test_results_file, "w") as writer: + for key, value in metrics.items(): + logger.info(" %s = %s", key, value) + writer.write("%s = %s\n" % (key, value)) + + return results + + +if __name__ == "__main__": 
+ main() diff --git a/examples/dependency-parsing/ud_dataset.py b/examples/dependency-parsing/ud_dataset.py new file mode 100644 index 0000000000..cd3589adba --- /dev/null +++ b/examples/dependency-parsing/ud_dataset.py @@ -0,0 +1,234 @@ +import glob +import logging +import os +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional, Union + +import numpy as np + +from filelock import FileLock +from transformers import PreTrainedTokenizer, is_torch_available + + +logger = logging.getLogger(__name__) + + +@dataclass +class InputExample: + """ + A single training/test example for universal dependency parsing. + + Args: + words: list. The words of the sequence. + head_labels: (Optional) list. The labels for each word's dependency head. This should be + specified for train and dev examples, but not for test examples. + rel_labels: (Optional) list. The labels for the relations between each word and its respective head. This should be + specified for train and dev examples, but not for test examples. + """ + + words: List[str] + head_labels: Optional[List[int]] + rel_labels: Optional[List[str]] + + +@dataclass +class InputFeatures: + """ + A single set of features of data. + Property names are the same names as the corresponding inputs to a BertForBiaffineParsing model. + """ + + input_ids: List[int] + attention_mask: List[int] + token_type_ids: List[int] + word_starts: List[int] + labels_arcs: List[int] + labels_rels: List[int] + + +class Split(Enum): + train = "train" + dev = "dev" + test = "test" + + +if is_torch_available(): + import torch + from torch.utils.data.dataset import Dataset + + class UDDataset(Dataset): + """ + Pytorch Dataset for universal dependency parsing. + """ + + features: List[InputFeatures] + + def __init__( + self, + data_dir: str, + tokenizer: PreTrainedTokenizer, + labels: List[str], + max_seq_length: Optional[int] = None, + overwrite_cache=False, + mode: Split = Split.train, + ): + # Load data features from cache or dataset file + cached_features_file = os.path.join( + data_dir, "cached_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(max_seq_length)), + ) + + # Make sure only the first process in distributed training processes the dataset, + # and the others will use the cache. 
+ lock_path = cached_features_file + ".lock" + with FileLock(lock_path): + + if os.path.exists(cached_features_file) and not overwrite_cache: + logger.info(f"Loading features from cached file {cached_features_file}") + self.features = torch.load(cached_features_file) + else: + logger.info(f"Creating features from dataset file at {data_dir}") + examples = read_examples_from_file(data_dir, mode) + self.features = convert_examples_to_features( + examples=examples, label_list=labels, max_seq_length=max_seq_length, tokenizer=tokenizer + ) + logger.info(f"Saving features into cached file {cached_features_file}") + torch.save(self.features, cached_features_file) + + def __len__(self): + return len(self.features) + + def __getitem__(self, i) -> InputFeatures: + return self.features[i] + + +def get_file(data_dir: str, mode: Union[Split, str]) -> Optional[str]: + if isinstance(mode, Split): + mode = mode.value + else: + raise ValueError(f"Unsupported mode: {mode}") + + fp = os.path.join(data_dir, f"*-ud-{mode}.conllu") + _fp = glob.glob(fp) + if len(_fp) == 1: + return _fp[0] + elif len(_fp) == 0: + return None + else: + raise ValueError(f"Unsupported mode: {mode}") + + +def read_examples_from_file(data_dir, mode: Union[Split, str]) -> List[InputExample]: + + file_path = get_file(data_dir, mode) + examples = [] + + with open(file_path, "r", encoding="utf-8") as f: + words: List[str] = [] + head_labels: List[int] = [] + rel_labels: List[str] = [] + for line in f.readlines(): + tok = line.strip().split("\t") + if len(tok) < 2 or line[0] == "#": + if words: + examples.append(InputExample(words=words, head_labels=head_labels, rel_labels=rel_labels)) + words = [] + head_labels = [] + rel_labels = [] + if tok[0].isdigit(): + word, head, label = tok[1], tok[6], tok[7] + words.append(word) + head_labels.append(int(head)) + rel_labels.append(label.split(":")[0]) + if words: + examples.append(InputExample(words=words, head_labels=head_labels, rel_labels=rel_labels)) + return examples + + +def convert_examples_to_features( + examples: List[InputExample], + label_list: List[str], + max_seq_length: int, + tokenizer: PreTrainedTokenizer, + pad_token=-1, +) -> List[InputFeatures]: + """ Loads a data file into a list of `InputFeatures` + """ + + label_map = {label: i for i, label in enumerate(label_list)} + + features = [] + for (ex_index, example) in enumerate(examples): + if ex_index % 10_000 == 0: + logger.info("Writing example %d of %d", ex_index, len(examples)) + + tokens = [tokenizer.tokenize(w) for w in example.words] + word_lengths = [len(w) for w in tokens] + tokens_merged = [] + list(map(tokens_merged.extend, tokens)) + + if 0 in word_lengths: + logger.info("Invalid sequence with word length 0 filtered: %s", example.words) + continue + # Filter out sequences that are too long + if len(tokens_merged) >= (max_seq_length - 2): + logger.info("Sequence of len %d filtered: %s", len(tokens_merged), tokens_merged) + continue + + encoding = tokenizer.encode_plus( + tokens_merged, + add_special_tokens=True, + pad_to_max_length=True, + max_length=max_seq_length, + is_split_into_words=True, + return_token_type_ids=True, + return_attention_mask=True, + ) + + input_ids = encoding["input_ids"] + token_type_ids = encoding["token_type_ids"] + attention_mask = encoding["attention_mask"] + + pad_item = [pad_token] + + # pad or truncate arc labels + labels_arcs = example.head_labels + labels_arcs = labels_arcs + (max_seq_length - len(labels_arcs)) * pad_item + + # convert rel labels from map, pad or truncate if necessary + 
labels_rels = [label_map[i] for i in example.rel_labels] + labels_rels = labels_rels + (max_seq_length - len(labels_rels)) * pad_item + + # determine start indices of words, pad or truncate if necessary + word_starts = np.cumsum([1] + word_lengths).tolist() + word_starts = word_starts + (max_seq_length + 1 - len(word_starts)) * pad_item + + # sanity check lengths + assert len(input_ids) == max_seq_length + assert len(attention_mask) == max_seq_length + assert len(token_type_ids) == max_seq_length + assert len(labels_arcs) == max_seq_length + assert len(labels_rels) == max_seq_length + assert len(word_starts) == max_seq_length + 1 + + if ex_index < 5: + logger.info("*** Example ***") + logger.info("tokens: %s", " ".join([str(x) for x in tokens_merged])) + logger.info("input_ids: %s", " ".join([str(x) for x in input_ids])) + logger.info("attention_mask: %s", " ".join([str(x) for x in attention_mask])) + logger.info("token_type_ids: %s", " ".join([str(x) for x in token_type_ids])) + logger.info("labels_arcs: %s", " ".join([str(x) for x in labels_arcs])) + logger.info("labels_rels: %s", " ".join([str(x) for x in labels_rels])) + logger.info("word_starts: %s", " ".join([str(x) for x in word_starts])) + + features.append( + InputFeatures( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + word_starts=word_starts, + labels_arcs=labels_arcs, + labels_rels=labels_rels, + ) + ) + return features diff --git a/examples/dependency-parsing/utils_udp.py b/examples/dependency-parsing/utils_udp.py new file mode 100644 index 0000000000..2606529194 --- /dev/null +++ b/examples/dependency-parsing/utils_udp.py @@ -0,0 +1,310 @@ +import logging +import os +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Optional, Tuple + +import numpy as np +import torch +from torch.utils.data import DataLoader +from torch.utils.data.dataset import Dataset +from torch.utils.tensorboard import SummaryWriter +from tqdm.auto import tqdm + +from transformers import DataCollator, EvalPrediction, PreTrainedModel, Trainer, TrainingArguments +from transformers.trainer_utils import PredictionOutput +from transformers.training_args import is_tpu_available + + +if is_tpu_available(): + import torch_xla.core.xla_model as xm + import torch_xla.debug.metrics as met + import torch_xla.distributed.parallel_loader as pl + +logger = logging.getLogger(__name__) + +UD_HEAD_LABELS = [ + "_", + "acl", + "advcl", + "advmod", + "amod", + "appos", + "aux", + "case", + "cc", + "ccomp", + "clf", + "compound", + "conj", + "cop", + "csubj", + "dep", + "det", + "discourse", + "dislocated", + "expl", + "fixed", + "flat", + "goeswith", + "iobj", + "list", + "mark", + "nmod", + "nsubj", + "nummod", + "obj", + "obl", + "orphan", + "parataxis", + "punct", + "reparandum", + "root", + "vocative", + "xcomp", +] + + +@dataclass +class UDTrainingArguments(TrainingArguments): + """ + Extends TrainingArguments for Universal Dependencies (UD) dependency parsing. + TrainingArguments is the subset of the arguments we use in our example scripts + **which relate to the training loop itself**. + + Using `HfArgumentParser` we can turn this class + into argparse arguments to be able to specify them on + the command line. 
+ """ + + decode_mode: str = field(default="greedy", metadata={"help": "Whether to use mst decoding or greedy decoding"}) + store_best_model: bool = field(default=False, metadata={"help": "Whether to store best model during training."}) + metric_score: Optional[str] = field( + default=None, metadata={"help": "Metric used to determine best model during training."} + ) + + +class Metric(object): + def add(self, gold, prediction): + raise NotImplementedError + + def get_metric(self) -> Dict[str, float]: + raise NotImplementedError + + def reset(self): + raise NotImplementedError + + @staticmethod + def unpack(*tensors: torch.Tensor): + return (x.detach().cpu() if isinstance(x, torch.Tensor) else x for x in tensors) + + +class ParsingMetric(Metric): + """ + based on allennlp.training.metrics.AttachmentScores + Computes labeled and unlabeled attachment scores for a dependency parse. Note that the input + to this metric is the sampled predictions, not the distribution itself. + """ + + def __init__(self): + self._labeled_correct = 0.0 + self._unlabeled_correct = 0.0 + self._total_words = 0.0 + + def add( + self, + gold_indices: torch.Tensor, + gold_labels: torch.Tensor, + predicted_indices: torch.Tensor, + predicted_labels: torch.Tensor, + ): + """ + Parameters + ---------- + predicted_indices : ``torch.Tensor``, required. + A tensor of head index predictions of shape (batch_size, timesteps). + predicted_labels : ``torch.Tensor``, required. + A tensor of arc label predictions of shape (batch_size, timesteps). + gold_indices : ``torch.Tensor``, required. + A tensor of the same shape as ``predicted_indices``. + gold_labels : ``torch.Tensor``, required. + A tensor of the same shape as ``predicted_labels``. + """ + unwrapped = self.unpack(predicted_indices, predicted_labels, gold_indices, gold_labels) + predicted_indices, predicted_labels, gold_indices, gold_labels = unwrapped + + predicted_indices = predicted_indices.long() + predicted_labels = predicted_labels.long() + gold_indices = gold_indices.long() + gold_labels = gold_labels.long() + + correct_indices = predicted_indices.eq(gold_indices).long() + correct_labels = predicted_labels.eq(gold_labels).long() + correct_labels_and_indices = correct_indices * correct_labels + + self._unlabeled_correct += correct_indices.sum().item() + self._labeled_correct += correct_labels_and_indices.sum().item() + self._total_words += correct_indices.numel() + + def get_metric(self): + unlabeled_attachment_score = 0.0 + labeled_attachment_score = 0.0 + if self._total_words > 0.0: + unlabeled_attachment_score = self._unlabeled_correct / self._total_words + labeled_attachment_score = self._labeled_correct / self._total_words + return { + "uas": unlabeled_attachment_score * 100, + "las": labeled_attachment_score * 100, + } + + def reset(self): + self._labeled_correct = 0.0 + self._unlabeled_correct = 0.0 + self._total_words = 0.0 + + +class DependencyParsingTrainer(Trainer): + args: UDTrainingArguments + + def __init__( + self, + model: PreTrainedModel, + args: UDTrainingArguments, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Dataset] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + prediction_loss_only=False, + do_save_full_model: bool = True, + do_save_adapters: bool = False, + do_save_adapter_fusion: bool = False, + adapter_names: Optional[List[List[str]]] = None, + tb_writer: Optional["SummaryWriter"] = None, + optimizers: Tuple[torch.optim.Optimizer, 
torch.optim.lr_scheduler.LambdaLR] = None, + ): + super().__init__( + model, + args, + data_collator, + train_dataset, + eval_dataset, + compute_metrics, + prediction_loss_only, + do_save_full_model, + do_save_adapters, + do_save_adapter_fusion, + adapter_names, + tb_writer, + optimizers, + ) + # for finding the best model. + # assumes higher is better + self.best_score = 0.0 + # torch.autograd.set_detect_anomaly(True) + + def evaluate( + self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None, + ) -> Dict[str, float]: + """ + Run evaluation and return metrics. + + The calling script will be responsible for providing a method to compute metrics, as they are + task-dependent. + + Args: + eval_dataset: (Optional) Pass a dataset if you wish to override + the one on the instance. + Returns: + A dict containing: + - the eval loss + - the potential metrics computed from the predictions + """ + eval_dataloader = self.get_eval_dataloader(eval_dataset) + + output = self._prediction_loop(eval_dataloader, description="Evaluation") + + if self.args.store_best_model: + self.store_best_model(output) + + self._log(output.metrics) + + if self.args.tpu_metrics_debug: + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + + return output.metrics + + def store_best_model(self, output): + + if self.args.metric_score not in output.metrics: + raise Exception( + "Metric %s not in output.\nThe following output was generated: %s", + str(self.args.metric_score), + str(output), + ) + + if output.metrics[self.args.metric_score] > self.best_score: + self.best_score = output.metrics[self.args.metric_score] + # Save model checkpoint + self.save_model(os.path.join(self.args.output_dir, "best_model")) + with open(os.path.join(self.args.output_dir, "best_model", "output.txt"), "w") as f: + f.write(str(output.metrics)) + + def _prediction_loop( + self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None + ) -> PredictionOutput: + """ + Prediction/evaluation loop, shared by `evaluate()` and `predict()`. + + Works both with or without labels. + """ + + prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only + + model = self.model + # multi-gpu eval + if self.args.n_gpu > 1: + model = torch.nn.DataParallel(model) + else: + model = self.model + # Note: in torch.distributed mode, there's no point in wrapping the model + # inside a DistributedDataParallel as we'll be under `no_grad` anyways. 
+ + batch_size = dataloader.batch_size + logger.info("***** Running %s *****", description) + logger.info(" Num examples = %d", self.num_examples(dataloader)) + logger.info(" Batch size = %d", batch_size) + logger.info(" Decode mode = %s", self.args.decode_mode) + eval_losses: List[float] = [] + model.eval() + + metric = ParsingMetric() + + if is_tpu_available(): + dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device) + + for inputs in tqdm(dataloader, desc=description): + + for k, v in inputs.items(): + inputs[k] = v.to(self.args.device) + + with torch.no_grad(): + step_eval_loss, rel_preds, arc_preds = model(**inputs, adapter_names=self.adapter_names) + + eval_losses += [step_eval_loss.mean().item()] + + mask = inputs["labels_arcs"].ne(self.model.config.pad_token_id) + predictions_arcs = torch.argmax(arc_preds, dim=-1)[mask] + + labels_arcs = inputs["labels_arcs"][mask] + + predictions_rels, labels_rels = rel_preds[mask], inputs["labels_rels"][mask] + predictions_rels = predictions_rels[torch.arange(len(labels_arcs)), labels_arcs] + predictions_rels = torch.argmax(predictions_rels, dim=-1) + + metric.add(labels_arcs, labels_rels, predictions_arcs, predictions_rels) + + results = metric.get_metric() + results[f"{description}_loss"] = np.mean(eval_losses) + + # Add predictions_rels to output, even though we are only interested in the metrics + return PredictionOutput(predictions=predictions_rels, label_ids=None, metrics=results) From 4bd1909dfed3a7f9cf34b5ce174237583aac2c21 Mon Sep 17 00:00:00 2001 From: calpt <36051308+calpt@users.noreply.github.com> Date: Sat, 17 Jul 2021 20:34:53 +0200 Subject: [PATCH 03/11] Updated run_udp --- examples/dependency-parsing/preprocessing.py | 95 ++++++++ examples/dependency-parsing/requirements.txt | 3 + examples/dependency-parsing/run_udp.py | 174 +++++++------- examples/dependency-parsing/ud_dataset.py | 234 ------------------- examples/dependency-parsing/utils_udp.py | 107 ++++++--- 5 files changed, 262 insertions(+), 351 deletions(-) create mode 100644 examples/dependency-parsing/preprocessing.py create mode 100644 examples/dependency-parsing/requirements.txt delete mode 100644 examples/dependency-parsing/ud_dataset.py diff --git a/examples/dependency-parsing/preprocessing.py b/examples/dependency-parsing/preprocessing.py new file mode 100644 index 0000000000..ee96600a18 --- /dev/null +++ b/examples/dependency-parsing/preprocessing.py @@ -0,0 +1,95 @@ +from collections import defaultdict +from typing import List + +import datasets +import numpy as np + +from transformers import PreTrainedTokenizer + + +def preprocess_dataset( + dataset: datasets.DatasetDict, + tokenizer: PreTrainedTokenizer, + label_list: List[str], + data_args, + pad_token_id=-1, +): + label_map = {label: i for i, label in enumerate(label_list)} + + def encode_batch(examples): + features = defaultdict(list) + for words, heads, deprels in zip(examples["tokens"], examples["head"], examples["deprel"]): + # clean up + i = 0 + while i < len(heads): + if heads[i] == "None": + del words[i] + del heads[i] + del deprels[i] + i += 1 + tokens = [tokenizer.tokenize(w) for w in words] + word_lengths = [len(w) for w in tokens] + tokens_merged = [] + list(map(tokens_merged.extend, tokens)) + + if 0 in word_lengths: + continue + # Filter out sequences that are too long + if len(tokens_merged) >= (data_args.max_seq_length - 2): + continue + + encoding = tokenizer( + words, + add_special_tokens=True, + padding="max_length", + truncation=True, + 
max_length=data_args.max_seq_length, + is_split_into_words=True, + return_token_type_ids=True, + return_attention_mask=True, + ) + + input_ids = encoding["input_ids"] + token_type_ids = encoding["token_type_ids"] + attention_mask = encoding["attention_mask"] + + pad_item = [pad_token_id] + + # pad or truncate arc labels + labels_arcs = [int(h) for h in heads] + labels_arcs = labels_arcs + (data_args.max_seq_length - len(labels_arcs)) * pad_item + + # convert rel labels from map, pad or truncate if necessary + labels_rels = [label_map[i.split(":")[0]] for i in deprels] + labels_rels = labels_rels + (data_args.max_seq_length - len(labels_rels)) * pad_item + + # determine start indices of words, pad or truncate if necessary + word_starts = np.cumsum([1] + word_lengths).tolist() + word_starts = word_starts + (data_args.max_seq_length + 1 - len(word_starts)) * pad_item + + # sanity check lengths + assert len(input_ids) == data_args.max_seq_length + assert len(attention_mask) == data_args.max_seq_length + assert len(token_type_ids) == data_args.max_seq_length + assert len(labels_arcs) == data_args.max_seq_length + assert len(labels_rels) == data_args.max_seq_length + assert len(word_starts) == data_args.max_seq_length + 1 + + features["input_ids"].append(input_ids) + features["attention_mask"].append(attention_mask) + features["token_type_ids"].append(token_type_ids) + features["word_starts"].append(word_starts) + features["labels_arcs"].append(labels_arcs) + features["labels_rels"].append(labels_rels) + + return dict(features) + + # Expects columns in all splits to be identical + remove_columns = dataset.column_names["train"] + dataset = dataset.map( + encode_batch, + batched=True, + load_from_cache_file=not data_args.overwrite_cache, + remove_columns=remove_columns, + ) + return dataset diff --git a/examples/dependency-parsing/requirements.txt b/examples/dependency-parsing/requirements.txt new file mode 100644 index 0000000000..b316ccf49b --- /dev/null +++ b/examples/dependency-parsing/requirements.txt @@ -0,0 +1,3 @@ +datasets >= 1.8.0 +torch >= 1.3 +conllu diff --git a/examples/dependency-parsing/run_udp.py b/examples/dependency-parsing/run_udp.py index 7618f54576..1d2529ae4e 100644 --- a/examples/dependency-parsing/run_udp.py +++ b/examples/dependency-parsing/run_udp.py @@ -4,18 +4,19 @@ from dataclasses import dataclass, field from typing import Dict, Optional -from modeling_biaffine import BertForBiaffineParsing +from datasets import load_dataset + +import transformers.adapters.composition as AC +from preprocessing import preprocess_dataset from transformers import ( - AdapterArguments, AdapterConfig, - AdapterType, AutoConfig, + AutoModelWithHeads, AutoTokenizer, HfArgumentParser, + MultiLingAdapterArguments, set_seed, - setup_task_adapter_training, ) -from ud_dataset import Split, UDDataset from utils_udp import UD_HEAD_LABELS, DependencyParsingTrainer, UDTrainingArguments @@ -63,7 +64,7 @@ class DataTrainingArguments: Arguments pertaining to what data we are going to input our model for training and eval. """ - data_dir: str = field(metadata={"help": "Path to train, dev, and test data files."}) + task_name: str = field(metadata={"help": "The identifier of the Universal Dependencies dataset to train on."}) max_seq_length: int = field( default=128, metadata={ @@ -80,7 +81,7 @@ def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. # We now keep distinct sets of args, for a cleaner separation of concerns. 
- parser = HfArgumentParser((ModelArguments, DataTrainingArguments, UDTrainingArguments, AdapterArguments)) + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, UDTrainingArguments, MultiLingAdapterArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # If we pass only one argument to the script and it's the path to a json file, # let's parse it to get our arguments. @@ -147,15 +148,18 @@ def main(): else None, ) - model = BertForBiaffineParsing.from_pretrained( + # The task name (with prefix) + task_name = "ud_" + data_args.task_name + language = adapter_args.language + + model = AutoModelWithHeads.from_pretrained( model_args.model_name_or_path, config=config, cache_dir=model_args.cache_dir, ) - - # Setup adapters - task_name = "udp" - language = adapter_args.language - if model_args.replace_embeddings: - model.resize_token_embeddings(len(tokenizer)) + model.add_dependency_parsing_head( + task_name, + num_labels=num_labels, + id2label=label_map, + ) if model_args.leave_out_twelvth: logger.info("Leaving out 12") @@ -163,60 +167,73 @@ def main(): else: leave_out = [] - setup_task_adapter_training( - model, task_name, adapter_args, leave_out=leave_out, with_embeddings=model_args.replace_embeddings - ) - if model_args.leave_out_twelvth: - if language in model.base_model.encoder.layer._modules["11"].output.layer_text_lang_adapters: - del model.base_model.encoder.layer._modules["11"].output.layer_text_lang_adapters[language] - logger.info("Deleted language adapter " + language + " in layer 12") - if language in model.base_model.encoder.layer._modules["11"].attention.output.attention_text_lang_adapters: - del model.base_model.encoder.layer._modules["11"].attention.output.attention_text_lang_adapters[language] - logger.info("Deleted language adapter " + language + " in layer 12") - + # Setup adapters if adapter_args.train_adapter: - if language: - adapter_names = [[language], [task_name]] + # check if adapter already exists, otherwise add it + if task_name not in model.config.adapters: + # resolve the adapter config + adapter_config = AdapterConfig.load( + adapter_args.adapter_config, + non_linearity=adapter_args.adapter_non_linearity, + reduction_factor=adapter_args.adapter_reduction_factor, + leave_out=leave_out, + ) + # load a pre-trained from Hub if specified + if adapter_args.load_adapter: + model.load_adapter( + adapter_args.load_adapter, + config=adapter_config, + load_as=task_name, + leave_out=leave_out, + ) + # otherwise, add a fresh adapter + else: + model.add_adapter(task_name, config=adapter_config) + # optionally load a pre-trained language adapter + if adapter_args.load_lang_adapter: + # resolve the language adapter config + lang_adapter_config = AdapterConfig.load( + adapter_args.lang_adapter_config, + non_linearity=adapter_args.lang_adapter_non_linearity, + reduction_factor=adapter_args.lang_adapter_reduction_factor, + leave_out=leave_out, + ) + # load the language adapter from Hub + lang_adapter_name = model.load_adapter( + adapter_args.load_lang_adapter, + config=lang_adapter_config, + load_as=adapter_args.language, + leave_out=leave_out, + ) + else: + lang_adapter_name = None + # Freeze all model weights except of those of this adapter + model.train_adapter([task_name]) + # Set the adapters to be used in every forward pass + if lang_adapter_name: + model.set_active_adapters(AC.Stack(lang_adapter_name, task_name)) else: - adapter_names = [[task_name]] + model.set_active_adapters(task_name) else: - adapter_names = None - - train_dataset = ( - 
UDDataset( - data_dir=data_args.data_dir, - tokenizer=tokenizer, - labels=labels, - max_seq_length=data_args.max_seq_length, - overwrite_cache=data_args.overwrite_cache, - mode=Split.train, - ) - if training_args.do_train - else None - ) + if adapter_args.load_adapter or adapter_args.load_lang_adapter: + raise ValueError( + "Adapters can only be loaded in adapters training mode." + "Use --train_adapter to enable adapter training" + ) - eval_dataset = ( - UDDataset( - data_dir=data_args.data_dir, - tokenizer=tokenizer, - labels=labels, - max_seq_length=data_args.max_seq_length, - overwrite_cache=data_args.overwrite_cache, - mode=Split.dev, - ) - if training_args.do_eval - else None - ) + # Load and preprocess dataset + dataset = load_dataset("universal_dependencies", data_args.task_name) + dataset = preprocess_dataset(dataset, tokenizer, labels, data_args, pad_token_id=-1) # Initialize our Trainer + training_args.remove_unused_columns = False trainer = DependencyParsingTrainer( model=model, args=training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, + train_dataset=dataset["train"], + eval_dataset=dataset["validation"], do_save_full_model=not adapter_args.train_adapter, do_save_adapters=adapter_args.train_adapter, - adapter_names=adapter_names, ) # Training @@ -227,7 +244,7 @@ def main(): trainer.save_model() # For convenience, we also re-save the tokenizer to the same directory, # so that you can share your model easily on huggingface.co/models =) - if trainer.is_world_master(): + if trainer.is_world_process_zero(): tokenizer.save_pretrained(training_args.output_dir) # Evaluation @@ -238,7 +255,7 @@ def main(): result = trainer.evaluate() output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt") - if trainer.is_world_master(): + if trainer.is_world_process_zero(): with open(output_eval_file, "w") as writer: logger.info("***** Eval results *****") for key, value in result.items(): @@ -249,15 +266,6 @@ def main(): # Predict if training_args.do_predict: - test_dataset = UDDataset( - data_dir=data_args.data_dir, - tokenizer=tokenizer, - labels=labels, - max_seq_length=data_args.max_seq_length, - overwrite_cache=data_args.overwrite_cache, - mode=Split.test, - ) - logging.info("*** Test ***") if training_args.store_best_model: @@ -272,9 +280,9 @@ def main(): os.path.join(training_args.output_dir, "best_model", language) if training_args.do_train else adapter_args.load_lang_adapter, - AdapterType.text_lang, config=lang_adapter_config, load_as=language, + leave_out=leave_out, ) task_adapter_config = AdapterConfig.load( config="pfeiffer", non_linearity="gelu", reduction_factor=16, leave_out=leave_out @@ -282,40 +290,28 @@ def main(): model.load_adapter( os.path.join(training_args.output_dir, "best_model", task_name) if training_args.do_train - else adapter_args.load_task_adapter, - AdapterType.text_task, + else adapter_args.load_adapter, config=task_adapter_config, load_as=task_name, + leave_out=leave_out, ) - if model_args.leave_out_twelvth: - if language in model.base_model.encoder.layer._modules["11"].output.layer_text_lang_adapters: - del model.base_model.encoder.layer._modules["11"].output.layer_text_lang_adapters[language] - logger.info("Deleted language adapter " + language + " in layer 12") - if ( - language - in model.base_model.encoder.layer._modules["11"].attention.output.attention_text_lang_adapters - ): - del model.base_model.encoder.layer._modules[ - "11" - ].attention.output.attention_text_lang_adapters[language] - logger.info("Deleted language 
adapter " + language + " in layer 12") - if language: - adapter_names = [[language], [task_name]] + model.set_active_adapters(AC.Stack(lang_adapter_name, task_name)) else: - adapter_names = [[task_name]] + model.set_active_adapters(task_name) + model.to(training_args.device) else: - trainer.model = BertForBiaffineParsing.from_pretrained( + trainer.model = AutoModelWithHeads.from_pretrained( os.path.join(training_args.output_dir, "best_model"), from_tf=bool(".ckpt" in model_args.model_name_or_path), config=config, cache_dir=model_args.cache_dir, ).to(training_args.device) - predictions, _, metrics = trainer.predict(test_dataset) + predictions, _, metrics = trainer.predict(dataset["test"]) output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt") - if trainer.is_world_master(): + if trainer.is_world_process_zero(): with open(output_test_results_file, "w") as writer: for key, value in metrics.items(): logger.info(" %s = %s", key, value) diff --git a/examples/dependency-parsing/ud_dataset.py b/examples/dependency-parsing/ud_dataset.py deleted file mode 100644 index cd3589adba..0000000000 --- a/examples/dependency-parsing/ud_dataset.py +++ /dev/null @@ -1,234 +0,0 @@ -import glob -import logging -import os -from dataclasses import dataclass -from enum import Enum -from typing import List, Optional, Union - -import numpy as np - -from filelock import FileLock -from transformers import PreTrainedTokenizer, is_torch_available - - -logger = logging.getLogger(__name__) - - -@dataclass -class InputExample: - """ - A single training/test example for universal dependency parsing. - - Args: - words: list. The words of the sequence. - head_labels: (Optional) list. The labels for each word's dependency head. This should be - specified for train and dev examples, but not for test examples. - rel_labels: (Optional) list. The labels for the relations between each word and its respective head. This should be - specified for train and dev examples, but not for test examples. - """ - - words: List[str] - head_labels: Optional[List[int]] - rel_labels: Optional[List[str]] - - -@dataclass -class InputFeatures: - """ - A single set of features of data. - Property names are the same names as the corresponding inputs to a BertForBiaffineParsing model. - """ - - input_ids: List[int] - attention_mask: List[int] - token_type_ids: List[int] - word_starts: List[int] - labels_arcs: List[int] - labels_rels: List[int] - - -class Split(Enum): - train = "train" - dev = "dev" - test = "test" - - -if is_torch_available(): - import torch - from torch.utils.data.dataset import Dataset - - class UDDataset(Dataset): - """ - Pytorch Dataset for universal dependency parsing. - """ - - features: List[InputFeatures] - - def __init__( - self, - data_dir: str, - tokenizer: PreTrainedTokenizer, - labels: List[str], - max_seq_length: Optional[int] = None, - overwrite_cache=False, - mode: Split = Split.train, - ): - # Load data features from cache or dataset file - cached_features_file = os.path.join( - data_dir, "cached_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(max_seq_length)), - ) - - # Make sure only the first process in distributed training processes the dataset, - # and the others will use the cache. 
- lock_path = cached_features_file + ".lock" - with FileLock(lock_path): - - if os.path.exists(cached_features_file) and not overwrite_cache: - logger.info(f"Loading features from cached file {cached_features_file}") - self.features = torch.load(cached_features_file) - else: - logger.info(f"Creating features from dataset file at {data_dir}") - examples = read_examples_from_file(data_dir, mode) - self.features = convert_examples_to_features( - examples=examples, label_list=labels, max_seq_length=max_seq_length, tokenizer=tokenizer - ) - logger.info(f"Saving features into cached file {cached_features_file}") - torch.save(self.features, cached_features_file) - - def __len__(self): - return len(self.features) - - def __getitem__(self, i) -> InputFeatures: - return self.features[i] - - -def get_file(data_dir: str, mode: Union[Split, str]) -> Optional[str]: - if isinstance(mode, Split): - mode = mode.value - else: - raise ValueError(f"Unsupported mode: {mode}") - - fp = os.path.join(data_dir, f"*-ud-{mode}.conllu") - _fp = glob.glob(fp) - if len(_fp) == 1: - return _fp[0] - elif len(_fp) == 0: - return None - else: - raise ValueError(f"Unsupported mode: {mode}") - - -def read_examples_from_file(data_dir, mode: Union[Split, str]) -> List[InputExample]: - - file_path = get_file(data_dir, mode) - examples = [] - - with open(file_path, "r", encoding="utf-8") as f: - words: List[str] = [] - head_labels: List[int] = [] - rel_labels: List[str] = [] - for line in f.readlines(): - tok = line.strip().split("\t") - if len(tok) < 2 or line[0] == "#": - if words: - examples.append(InputExample(words=words, head_labels=head_labels, rel_labels=rel_labels)) - words = [] - head_labels = [] - rel_labels = [] - if tok[0].isdigit(): - word, head, label = tok[1], tok[6], tok[7] - words.append(word) - head_labels.append(int(head)) - rel_labels.append(label.split(":")[0]) - if words: - examples.append(InputExample(words=words, head_labels=head_labels, rel_labels=rel_labels)) - return examples - - -def convert_examples_to_features( - examples: List[InputExample], - label_list: List[str], - max_seq_length: int, - tokenizer: PreTrainedTokenizer, - pad_token=-1, -) -> List[InputFeatures]: - """ Loads a data file into a list of `InputFeatures` - """ - - label_map = {label: i for i, label in enumerate(label_list)} - - features = [] - for (ex_index, example) in enumerate(examples): - if ex_index % 10_000 == 0: - logger.info("Writing example %d of %d", ex_index, len(examples)) - - tokens = [tokenizer.tokenize(w) for w in example.words] - word_lengths = [len(w) for w in tokens] - tokens_merged = [] - list(map(tokens_merged.extend, tokens)) - - if 0 in word_lengths: - logger.info("Invalid sequence with word length 0 filtered: %s", example.words) - continue - # Filter out sequences that are too long - if len(tokens_merged) >= (max_seq_length - 2): - logger.info("Sequence of len %d filtered: %s", len(tokens_merged), tokens_merged) - continue - - encoding = tokenizer.encode_plus( - tokens_merged, - add_special_tokens=True, - pad_to_max_length=True, - max_length=max_seq_length, - is_split_into_words=True, - return_token_type_ids=True, - return_attention_mask=True, - ) - - input_ids = encoding["input_ids"] - token_type_ids = encoding["token_type_ids"] - attention_mask = encoding["attention_mask"] - - pad_item = [pad_token] - - # pad or truncate arc labels - labels_arcs = example.head_labels - labels_arcs = labels_arcs + (max_seq_length - len(labels_arcs)) * pad_item - - # convert rel labels from map, pad or truncate if necessary - 
labels_rels = [label_map[i] for i in example.rel_labels] - labels_rels = labels_rels + (max_seq_length - len(labels_rels)) * pad_item - - # determine start indices of words, pad or truncate if necessary - word_starts = np.cumsum([1] + word_lengths).tolist() - word_starts = word_starts + (max_seq_length + 1 - len(word_starts)) * pad_item - - # sanity check lengths - assert len(input_ids) == max_seq_length - assert len(attention_mask) == max_seq_length - assert len(token_type_ids) == max_seq_length - assert len(labels_arcs) == max_seq_length - assert len(labels_rels) == max_seq_length - assert len(word_starts) == max_seq_length + 1 - - if ex_index < 5: - logger.info("*** Example ***") - logger.info("tokens: %s", " ".join([str(x) for x in tokens_merged])) - logger.info("input_ids: %s", " ".join([str(x) for x in input_ids])) - logger.info("attention_mask: %s", " ".join([str(x) for x in attention_mask])) - logger.info("token_type_ids: %s", " ".join([str(x) for x in token_type_ids])) - logger.info("labels_arcs: %s", " ".join([str(x) for x in labels_arcs])) - logger.info("labels_rels: %s", " ".join([str(x) for x in labels_rels])) - logger.info("word_starts: %s", " ".join([str(x) for x in word_starts])) - - features.append( - InputFeatures( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - word_starts=word_starts, - labels_arcs=labels_arcs, - labels_rels=labels_rels, - ) - ) - return features diff --git a/examples/dependency-parsing/utils_udp.py b/examples/dependency-parsing/utils_udp.py index 2606529194..e7d2406e00 100644 --- a/examples/dependency-parsing/utils_udp.py +++ b/examples/dependency-parsing/utils_udp.py @@ -1,24 +1,31 @@ +import collections import logging import os from dataclasses import dataclass, field -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch from torch.utils.data import DataLoader from torch.utils.data.dataset import Dataset -from torch.utils.tensorboard import SummaryWriter -from tqdm.auto import tqdm - -from transformers import DataCollator, EvalPrediction, PreTrainedModel, Trainer, TrainingArguments +from tqdm import tqdm + +from transformers import ( + DataCollator, + EvalPrediction, + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainerCallback, + TrainingArguments, + is_torch_tpu_available, +) from transformers.trainer_utils import PredictionOutput -from transformers.training_args import is_tpu_available -if is_tpu_available(): +if is_torch_tpu_available(): import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met - import torch_xla.distributed.parallel_loader as pl logger = logging.getLogger(__name__) @@ -163,23 +170,23 @@ def reset(self): class DependencyParsingTrainer(Trainer): - args: UDTrainingArguments - def __init__( self, - model: PreTrainedModel, - args: UDTrainingArguments, + model: Union[PreTrainedModel, torch.nn.Module] = None, + args: UDTrainingArguments = None, data_collator: Optional[DataCollator] = None, train_dataset: Optional[Dataset] = None, eval_dataset: Optional[Dataset] = None, + tokenizer: Optional["PreTrainedTokenizerBase"] = None, + model_init: Callable[[], PreTrainedModel] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, - prediction_loss_only=False, + callbacks: Optional[List[TrainerCallback]] = None, do_save_full_model: bool = True, do_save_adapters: bool = False, do_save_adapter_fusion: bool = False, adapter_names: 
Optional[List[List[str]]] = None, - tb_writer: Optional["SummaryWriter"] = None, - optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + **kwargs, ): super().__init__( model, @@ -187,16 +194,17 @@ def __init__( data_collator, train_dataset, eval_dataset, + tokenizer, + model_init, compute_metrics, - prediction_loss_only, + callbacks, do_save_full_model, do_save_adapters, do_save_adapter_fusion, adapter_names, - tb_writer, optimizers, + **kwargs, ) - # for finding the best model. # assumes higher is better self.best_score = 0.0 # torch.autograd.set_detect_anomaly(True) @@ -225,7 +233,7 @@ def evaluate( if self.args.store_best_model: self.store_best_model(output) - self._log(output.metrics) + self.log(output.metrics) if self.args.tpu_metrics_debug: # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) @@ -233,6 +241,49 @@ def evaluate( return output.metrics + def predict( + self, test_dataset: Dataset + ) -> PredictionOutput: + """ + Run prediction and returns predictions and potential metrics. + + Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method + will also return metrics, like in :obj:`evaluate()`. + + Args: + test_dataset (:obj:`Dataset`): + Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the + ``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__` + ignore_keys (:obj:`Lst[str]`, `optional`): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"test"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "test_bleu" if the prefix is "test" (default) + + .. note:: + + If your predictions or labels have different sequence length (for instance because you're doing dynamic + padding in a token classification task) the predictions will be padded (on the right) to allow for + concatenation into one array. The padding index is -100. + + Returns: `NamedTuple` A namedtuple with the following keys: + + - predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`. + - label_ids (:obj:`np.ndarray`, `optional`): The labels (if the dataset contained some). + - metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset + contained labels). + """ + test_dataloader = self.get_test_dataloader(test_dataset) + + output = self._prediction_loop( + test_dataloader, description="Prediction" + ) + + self.log(output.metrics) + + return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics) + def store_best_model(self, output): if self.args.metric_score not in output.metrics: @@ -253,12 +304,15 @@ def _prediction_loop( self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None ) -> PredictionOutput: """ - Prediction/evaluation loop, shared by `evaluate()` and `predict()`. - - Works both with or without labels. - """ + Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`. + Works both with or without labels. 
+ """ - prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only + if not isinstance(dataloader.dataset, collections.abc.Sized): + raise ValueError("dataset must implement __len__") + prediction_loss_only = ( + prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only + ) model = self.model # multi-gpu eval @@ -279,16 +333,13 @@ def _prediction_loop( metric = ParsingMetric() - if is_tpu_available(): - dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device) - for inputs in tqdm(dataloader, desc=description): for k, v in inputs.items(): inputs[k] = v.to(self.args.device) with torch.no_grad(): - step_eval_loss, rel_preds, arc_preds = model(**inputs, adapter_names=self.adapter_names) + step_eval_loss, rel_preds, arc_preds = model(**inputs) eval_losses += [step_eval_loss.mean().item()] From 3409dd0d1a6b067e31cf5fd4194e295d8d205f98 Mon Sep 17 00:00:00 2001 From: calpt <36051308+calpt@users.noreply.github.com> Date: Mon, 19 Jul 2021 10:58:02 +0200 Subject: [PATCH 04/11] Documentation & minor fixes --- examples/README.md | 1 + examples/dependency-parsing/README.md | 38 +++++++++++++ examples/dependency-parsing/preprocessing.py | 2 +- examples/dependency-parsing/run_udp.py | 28 ++++++++-- examples/dependency-parsing/utils_udp.py | 19 ++++--- .../adapters/heads/dependency_parsing.py | 56 ++++++++++++++----- src/transformers/adapters/models/bert.py | 10 ++++ 7 files changed, 124 insertions(+), 30 deletions(-) create mode 100644 examples/dependency-parsing/README.md diff --git a/examples/README.md b/examples/README.md index e3400e1936..f5c2207879 100644 --- a/examples/README.md +++ b/examples/README.md @@ -44,6 +44,7 @@ Currently, scripts for these tasks support adapters: | [**`text-generation`**](https://github.com/adapter-hub/adapter-transformers/tree/master/examples/text-generation) | Text generation, e.g. using GPT-2 | [**`token-classification`**](https://github.com/adapter-hub/adapter-transformers/tree/master/examples/token-classification) | NER, e.g. on CoNLL2003 | [**`translation`**](https://github.com/adapter-hub/adapter-transformers/tree/master/examples/seq2seq) | Machine translation, e.g. on WMT tasks +| [**`dependency-parsing`**](https://github.com/adapter-hub/adapter-transformers/tree/master/examples/dependency-parsing) | Dependency parsing on Universal Dependencies All scripts listed above which can be used for training provide a new `--train_adapter` option that switches between full fine-tuning and adapter training. Loading pre-trained adapters can be done via `--load_adapter`. diff --git a/examples/dependency-parsing/README.md b/examples/dependency-parsing/README.md new file mode 100644 index 0000000000..50051413a8 --- /dev/null +++ b/examples/dependency-parsing/README.md @@ -0,0 +1,38 @@ +# Dependency parsing on Universal Dependencies + +These example scripts are based on the fine-tuning code from the repository of ["How Good is Your Tokenizer? On the Monolingual Performance of Multilingual Language Models"](https://github.com/Adapter-Hub/hgiyt). +The scripts were upgraded to `adapter-transformers` v2.x and modified to use [flex heads](https://docs.adapterhub.ml/prediction_heads.html#models-with-flexible-heads) and HuggingFace Datasets. + +The used biaffine dependency parsing prediction head is described in ["Is Supervised Syntactic Parsing Beneficial for Language Understanding Tasks? 
An Empirical Investigation" (Glavaš & Vulić, 2021)](https://arxiv.org/pdf/2008.06788.pdf).
+
+## Training on Universal Dependencies
+
+Script: [`run_udp.py`](https://github.com/Adapter-Hub/adapter-transformers/blob/master/examples/dependency-parsing/run_udp.py).
+
+Fine-tuning on the treebanks of [Universal Dependencies](https://universaldependencies.org/).
+The datasets are loaded from [HuggingFace Datasets](https://huggingface.co/datasets/universal_dependencies); the treebank to use can be specified via the `--task_name` option.
+
+Training an adapter on the English Web Treebank (`en_ewt`) could be done as follows:
+
+```bash
+export TASK_NAME="en_ewt"
+
+python run_udp.py \
+    --model_name_or_path bert-base-cased \
+    --do_train \
+    --do_eval \
+    --do_predict \
+    --task_name $TASK_NAME \
+    --per_device_train_batch_size 12 \
+    --learning_rate 5e-4 \
+    --num_train_epochs 10 \
+    --max_seq_length 256 \
+    --output_dir experiments/$TASK_NAME \
+    --overwrite_output_dir \
+    --store_best_model \
+    --evaluation_strategy epoch \
+    --metric_score las \
+    --train_adapter
+```
+
+For more information, also visit the original code at https://github.com/Adapter-Hub/hgiyt/tree/master/finetuning.
diff --git a/examples/dependency-parsing/preprocessing.py b/examples/dependency-parsing/preprocessing.py
index ee96600a18..9fc631bb14 100644
--- a/examples/dependency-parsing/preprocessing.py
+++ b/examples/dependency-parsing/preprocessing.py
@@ -39,7 +39,7 @@ def encode_batch(examples):
             continue
 
         encoding = tokenizer(
-            words,
+            tokens_merged,
             add_special_tokens=True,
             padding="max_length",
             truncation=True,
diff --git a/examples/dependency-parsing/run_udp.py b/examples/dependency-parsing/run_udp.py
index 1d2529ae4e..9f771a84ea 100644
--- a/examples/dependency-parsing/run_udp.py
+++ b/examples/dependency-parsing/run_udp.py
@@ -1,3 +1,8 @@
+"""
+Code taken and modified from: https://github.com/Adapter-Hub/hgiyt.
+Credits: "How Good is Your Tokenizer? On the Monolingual Performance of Multilingual Language Models" (Rust et al., 2021)
+https://arxiv.org/abs/2012.15613
+"""
 import logging
 import os
 import sys
@@ -33,16 +38,19 @@ class ModelArguments:
         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
     )
     config_name: Optional[str] = field(
-        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"},
+        default=None,
+        metadata={"help": "Pretrained config name or path if not the same as model_name"},
     )
     tokenizer_name: Optional[str] = field(
-        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
+        default=None,
+        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
     )
     use_fast: bool = field(default=False, metadata={"help": "Set this flag to use fast tokenization."})
     # If you want to tweak more attributes on your tokenizer, you should do it in a distinct script,
     # or just modify its tokenizer_config.json.
cache_dir: Optional[str] = field( - default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}, + default=None, + metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}, ) replace_embeddings: bool = field(default=False, metadata={"help": "Whether or not to replace embeddings."}) leave_out_twelvth: bool = field( @@ -73,7 +81,8 @@ class DataTrainingArguments: }, ) overwrite_cache: bool = field( - default=False, metadata={"help": "Overwrite the cached training and evaluation sets."}, + default=False, + metadata={"help": "Overwrite the cached training and evaluation sets."}, ) @@ -89,7 +98,12 @@ def main(): json_file=os.path.abspath(sys.argv[1]) ) else: - (model_args, data_args, training_args, adapter_args,) = parser.parse_args_into_dataclasses() + ( + model_args, + data_args, + training_args, + adapter_args, + ) = parser.parse_args_into_dataclasses() if ( os.path.exists(training_args.output_dir) @@ -153,7 +167,9 @@ def main(): language = adapter_args.language model = AutoModelWithHeads.from_pretrained( - model_args.model_name_or_path, config=config, cache_dir=model_args.cache_dir, + model_args.model_name_or_path, + config=config, + cache_dir=model_args.cache_dir, ) model.add_dependency_parsing_head( task_name, diff --git a/examples/dependency-parsing/utils_udp.py b/examples/dependency-parsing/utils_udp.py index e7d2406e00..f333cfcf65 100644 --- a/examples/dependency-parsing/utils_udp.py +++ b/examples/dependency-parsing/utils_udp.py @@ -1,3 +1,8 @@ +""" +Code taken and modified from: https://github.com/Adapter-Hub/hgiyt. +Credits: "How Good is Your Tokenizer? On the Monolingual Performance of Multilingual Language Models" (Rust et al., 2021) +https://arxiv.org/abs/2012.15613 +""" import collections import logging import os @@ -210,7 +215,9 @@ def __init__( # torch.autograd.set_detect_anomaly(True) def evaluate( - self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None, + self, + eval_dataset: Optional[Dataset] = None, + prediction_loss_only: Optional[bool] = None, ) -> Dict[str, float]: """ Run evaluation and return metrics. @@ -241,9 +248,7 @@ def evaluate( return output.metrics - def predict( - self, test_dataset: Dataset - ) -> PredictionOutput: + def predict(self, test_dataset: Dataset) -> PredictionOutput: """ Run prediction and returns predictions and potential metrics. @@ -276,9 +281,7 @@ def predict( """ test_dataloader = self.get_test_dataloader(test_dataset) - output = self._prediction_loop( - test_dataloader, description="Prediction" - ) + output = self._prediction_loop(test_dataloader, description="Prediction") self.log(output.metrics) @@ -339,7 +342,7 @@ def _prediction_loop( inputs[k] = v.to(self.args.device) with torch.no_grad(): - step_eval_loss, rel_preds, arc_preds = model(**inputs) + step_eval_loss, rel_preds, arc_preds = model(**inputs, return_dict=False) eval_losses += [step_eval_loss.mean().item()] diff --git a/src/transformers/adapters/heads/dependency_parsing.py b/src/transformers/adapters/heads/dependency_parsing.py index adcf859eab..33e7661d77 100644 --- a/src/transformers/adapters/heads/dependency_parsing.py +++ b/src/transformers/adapters/heads/dependency_parsing.py @@ -1,15 +1,25 @@ """ -Code taken and modified from: https://github.com/Adapter-Hub/hgiyt. -Credits: "How Good is Your Tokenizer? 
On the Monolingual Performance of Multilingual Language Models" (Rust et al., 2021) -https://arxiv.org/abs/2012.15613 +Code taken and modified from: https://github.com/Adapter-Hub/hgiyt. Credits: "How Good is Your Tokenizer? On the +Monolingual Performance of Multilingual Language Models" (Rust et al., 2021) https://arxiv.org/abs/2012.15613 """ +from typing import Optional, Tuple + import torch from torch import nn from torch.nn import CrossEntropyLoss +from ...file_utils import ModelOutput from .base import PredictionHead +class DependencyParsingOutput(ModelOutput): + loss: Optional[torch.FloatTensor] = None + rel_preds: torch.FloatTensor = None + arc_preds: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + # Credit: # Class taken from https://github.com/yzhangcs/biaffine-parser class Biaffine(nn.Module): @@ -48,9 +58,8 @@ def forward(self, x, y): class BiaffineParsingHead(PredictionHead): """ - Credit: G. Glavaš & I. Vulić - Based on paper "Is Supervised Syntactic Parsing Beneficial for Language Understanding? An Empirical Investigation" - (https://arxiv.org/pdf/2008.06788.pdf) + Credit: G. Glavaš & I. Vulić Based on paper "Is Supervised Syntactic Parsing Beneficial for Language Understanding? + An Empirical Investigation" (https://arxiv.org/pdf/2008.06788.pdf) """ def __init__(self, model, head_name, num_labels=2, id2label=None): @@ -76,7 +85,15 @@ def build(self, model): self.train(model.training) # make sure training mode is consistent def forward( - self, outputs, cls_output=None, attention_mask=None, return_dict=False, word_starts=None, labels_arcs=None, labels_rels=None, **kwargs + self, + outputs, + cls_output=None, + attention_mask=None, + return_dict=False, + word_starts=None, + labels_arcs=None, + labels_rels=None, + **kwargs ): outs = self.dropout(outputs[0]) word_outputs_deps = self._merge_subword_tokens(outs, word_starts) @@ -88,20 +105,27 @@ def forward( arc_preds = self.biaffine_arcs(word_outputs_deps, word_outputs_heads) arc_preds = arc_preds.squeeze() - outputs = (arc_preds,) + if len(arc_preds.shape) == 2: + arc_preds = arc_preds.unsqueeze(0) rel_preds = self.biaffine_rels(word_outputs_deps, word_outputs_heads) rel_preds = rel_preds.permute(0, 2, 3, 1) - outputs = (rel_preds,) + outputs loss = self._get_loss(arc_preds, rel_preds, labels_arcs, labels_rels, self.loss_fn) - # TODO-AH return_dict - outputs = (loss,) + outputs - - if len(arc_preds.shape) == 2: - return loss, rel_preds, arc_preds.unsqueeze(0) - return outputs + if return_dict: + return DependencyParsingOutput( + loss=loss, + rel_preds=rel_preds, + arc_preds=arc_preds, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + outputs = (rel_preds, arc_preds) + if loss is not None: + outputs = (loss,) + outputs + return outputs def _merge_subword_tokens(self, subword_outputs, word_starts): instances = [] @@ -139,6 +163,8 @@ def _merge_subword_tokens(self, subword_outputs, word_starts): return w_tens def _get_loss(self, arc_preds, rel_preds, labels_arc, labels_rel, loss_fn): + if labels_arc is None or labels_rel is None: + return None if len(arc_preds.shape) == 2: arc_preds = arc_preds.unsqueeze(0) diff --git a/src/transformers/adapters/models/bert.py b/src/transformers/adapters/models/bert.py index f95811cb76..4afc31bdca 100644 --- a/src/transformers/adapters/models/bert.py +++ b/src/transformers/adapters/models/bert.py @@ -260,5 +260,15 @@ def add_qa_head( self.add_prediction_head(head, 
overwrite_ok) def add_dependency_parsing_head(self, head_name, num_labels=2, overwrite_ok=False, id2label=None): + """Adds a biaffine dependency parsing head on top of the model. + The parsing head uses the architecture described in "Is Supervised Syntactic Parsing Beneficial for Language Understanding? + An Empirical Investigation" (Glavaš & Vulić, 2021) (https://arxiv.org/pdf/2008.06788.pdf). + + Args: + head_name (str): The name of the head. + num_labels (int, optional): Number of labels. Defaults to 2. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + id2label (dict, optional): Mapping from label ids to labels. Defaults to None. + """ head = BiaffineParsingHead(self, head_name, num_labels, id2label) self.add_prediction_head(head, overwrite_ok) From ea044e4756248bc0d6d5b5eb5bd94bb8af95df04 Mon Sep 17 00:00:00 2001 From: calpt <36051308+calpt@users.noreply.github.com> Date: Mon, 19 Jul 2021 12:01:12 +0200 Subject: [PATCH 05/11] Some more fixes --- examples/dependency-parsing/preprocessing.py | 7 ++++++- examples/dependency-parsing/run_udp.py | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/dependency-parsing/preprocessing.py b/examples/dependency-parsing/preprocessing.py index 9fc631bb14..2188aab4fb 100644 --- a/examples/dependency-parsing/preprocessing.py +++ b/examples/dependency-parsing/preprocessing.py @@ -1,3 +1,8 @@ +""" +Code taken and modified from: https://github.com/Adapter-Hub/hgiyt. +Credits: "How Good is Your Tokenizer? On the Monolingual Performance of Multilingual Language Models" (Rust et al., 2021) +https://arxiv.org/abs/2012.15613 +""" from collections import defaultdict from typing import List @@ -39,7 +44,7 @@ def encode_batch(examples): continue encoding = tokenizer( - tokens_merged, + words, add_special_tokens=True, padding="max_length", truncation=True, diff --git a/examples/dependency-parsing/run_udp.py b/examples/dependency-parsing/run_udp.py index 9f771a84ea..a8d8c4779d 100644 --- a/examples/dependency-parsing/run_udp.py +++ b/examples/dependency-parsing/run_udp.py @@ -157,6 +157,7 @@ def main(): cache_dir=model_args.cache_dir, use_fast=model_args.use_fast, do_lower_case=model_args.do_lower_case, + add_prefix_space=True, # Used e.g. 
for RoBERTa mecab_kwargs={"mecab_option": f"-r {model_args.mecab_dir} -d {model_args.mecab_dic_dir}"} if model_args.is_japanese else None, From cb84e2ed4b6e01b96f8751b584ab63ace787363f Mon Sep 17 00:00:00 2001 From: calpt <36051308+calpt@users.noreply.github.com> Date: Mon, 19 Jul 2021 21:34:17 +0200 Subject: [PATCH 06/11] Style --- examples/dependency-parsing/run_udp.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/dependency-parsing/run_udp.py b/examples/dependency-parsing/run_udp.py index a8d8c4779d..1112ab5d1f 100644 --- a/examples/dependency-parsing/run_udp.py +++ b/examples/dependency-parsing/run_udp.py @@ -11,7 +11,7 @@ from datasets import load_dataset -import transformers.adapters.composition as AC +import transformers.adapters.composition as ac from preprocessing import preprocess_dataset from transformers import ( AdapterConfig, @@ -228,7 +228,7 @@ def main(): model.train_adapter([task_name]) # Set the adapters to be used in every forward pass if lang_adapter_name: - model.set_active_adapters(AC.Stack(lang_adapter_name, task_name)) + model.set_active_adapters(ac.Stack(lang_adapter_name, task_name)) else: model.set_active_adapters(task_name) else: @@ -243,6 +243,7 @@ def main(): dataset = preprocess_dataset(dataset, tokenizer, labels, data_args, pad_token_id=-1) # Initialize our Trainer + # HACK: Set this attribute to False to prevent label columns from being deleted training_args.remove_unused_columns = False trainer = DependencyParsingTrainer( model=model, @@ -313,7 +314,7 @@ def main(): leave_out=leave_out, ) if language: - model.set_active_adapters(AC.Stack(lang_adapter_name, task_name)) + model.set_active_adapters(ac.Stack(lang_adapter_name, task_name)) else: model.set_active_adapters(task_name) model.to(training_args.device) From ad1eef6c5ee1d729892af5fb1e8202fc25dc9287 Mon Sep 17 00:00:00 2001 From: calpt <36051308+calpt@users.noreply.github.com> Date: Thu, 22 Jul 2021 22:10:56 +0200 Subject: [PATCH 07/11] Init flex LM heads. 
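
This commit introduces the first flexible language modeling heads (causal LM, seq2seq LM, and BERT-style masked/causal LM) together with an `add_causal_lm_head()` method for GPT-2 models with heads. As an illustrative sketch only (not part of the diff below; the model identifier and the adapter/head name "poem_lm" are placeholders, and `add_adapter()`/`train_adapter()` are the pre-existing adapter-transformers APIs), a causal LM head can be combined with a task adapter roughly like this:

```python
from transformers import AutoModelWithHeads

# Load a GPT-2 base model with flexible-heads support.
model = AutoModelWithHeads.from_pretrained("gpt2")

# Add a task adapter and a matching causal LM head (introduced in this commit).
model.add_adapter("poem_lm")
model.add_causal_lm_head("poem_lm")

# Activate the adapter for training; this freezes the pre-trained model weights.
model.train_adapter("poem_lm")
```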
--- src/transformers/adapters/heads/__init__.py | 1 + src/transformers/adapters/heads/base.py | 79 +++++++++- .../adapters/heads/language_modeling.py | 147 ++++++++++++++++++ src/transformers/adapters/models/gpt2.py | 14 +- 4 files changed, 232 insertions(+), 9 deletions(-) create mode 100644 src/transformers/adapters/heads/language_modeling.py diff --git a/src/transformers/adapters/heads/__init__.py b/src/transformers/adapters/heads/__init__.py index e12be6a821..44d01013f9 100644 --- a/src/transformers/adapters/heads/__init__.py +++ b/src/transformers/adapters/heads/__init__.py @@ -1,3 +1,4 @@ # flake8: noqa from .base import * from .dependency_parsing import * +from .language_modeling import BertStyleCausalLMHead, BertStyleMaskedLMHead, CausalLMHead, Seq2SeqLMHead diff --git a/src/transformers/adapters/heads/base.py b/src/transformers/adapters/heads/base.py index ab40f9669d..ea6158e3f2 100644 --- a/src/transformers/adapters/heads/base.py +++ b/src/transformers/adapters/heads/base.py @@ -33,13 +33,18 @@ def __init__(self, name): def build(self, model): model_config = model.config pred_head = [] + dropout_prob = self.config.get("dropout_prob", model_config.hidden_dropout_prob) + with_layer_norm = self.config.get("layer_norm", False) bias = self.config.get("bias", True) - for l in range(self.config["layers"]): - pred_head.append(nn.Dropout(model_config.hidden_dropout_prob)) - if l < self.config["layers"] - 1: - pred_head.append(nn.Linear(model_config.hidden_size, model_config.hidden_size, bias=bias)) + for l_id in range(self.config["layers"]): + if dropout_prob > 0: + pred_head.append(nn.Dropout(dropout_prob)) + if l_id < self.config["layers"] - 1: + pred_head.append(nn.Linear(model_config.hidden_size, model_config.hidden_size)) if self.config["activation_function"]: pred_head.append(Activation_Function_Class(self.config["activation_function"])) + if with_layer_norm: + pred_head.append(nn.LayerNorm(model_config.hidden_size, eps=model_config.layer_norm_eps)) else: if "num_labels" in self.config: pred_head.append(nn.Linear(model_config.hidden_size, self.config["num_labels"], bias=bias)) @@ -55,6 +60,9 @@ def build(self, model): self.apply(model._init_weights) self.train(model.training) # make sure training mode is consistent + def get_output_embeddings(self): + return None # override for heads with output embeddings + class ClassificationHead(PredictionHead): def __init__( @@ -393,6 +401,59 @@ def _init_head_modules(self): for head_name, config in self.config.prediction_heads.items(): self.add_prediction_head_from_config(head_name, config) + # The following methods are required for handling LM heads + + def get_output_embeddings(self): + all_output_embeddings = {} + + for head_name, head in self.heads.items(): + output_embeddings = head.get_output_embeddings() + if output_embeddings is not None: + all_output_embeddings[head_name] = output_embeddings + + return all_output_embeddings + + def set_output_embeddings(self, new_embeddings): + for head_name, head in self.heads.items(): + if head.get_output_embeddings() is not None: + head.set_output_embeddings(new_embeddings) + + def tie_weights(self): + """ + Tie the weights between the input embeddings and the output embeddings. + + If the :obj:`torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning + the weights instead. 
+ """ + for head_name, head in self.heads.items(): + output_embeddings = head.get_output_embeddings() + if output_embeddings is not None and self.config.tie_word_embeddings: + self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) + + if self.config.is_encoder_decoder and self.config.tie_encoder_decoder: + if hasattr(self, self.base_model_prefix): + self = getattr(self, self.base_model_prefix) + self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix) + + return self.get_input_embeddings() + + def _resize_token_embeddings(self, new_num_tokens): + old_embeddings = self.get_input_embeddings() + new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) + self.set_input_embeddings(new_embeddings) + + # if word embeddings are not tied, make sure that lm head is resized as well + if not self.config.tie_word_embeddings: + for head in self.heads.values(): + old_lm_head = self.get_output_embeddings() + if old_lm_head is not None: + new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens) + self.set_output_embeddings(new_lm_head) + + return self.get_input_embeddings() + + # Methods for managing prediction heads + def add_prediction_head_from_config(self, head_name, config, overwrite_ok=False): head_type = config.pop("head_type") # handle cases when id2label, label2id or both are available @@ -453,7 +514,7 @@ def active_head(self, head_name_or_list: Union[str, List[str], AdapterCompositio self._active_heads = [head_name_or_list] if head_name_or_list else None # If we set a single head, also switch the label mapping. For multiple head, that doesn't make sense? if head_name_or_list: - self.config.label2id = self.heads[head_name_or_list].config["label2id"] + self.config.label2id = self.heads[head_name_or_list].config.get("label2id", None) self.config.id2label = self.get_labels_dict(head_name_or_list) else: self._active_heads = head_name_or_list @@ -512,21 +573,23 @@ def add_prediction_head( head: PredictionHead, overwrite_ok: bool = False, ): - if head.name not in self.heads or overwrite_ok: self.heads[head.name] = head # add reference to model config to save all head configs too self.config.prediction_heads[head.name] = head.config - if "label2id" not in head.config.keys() or head.config["label2id"] is None: + # Set a default label2id map if not given + if "label2id" in head.config.keys() and head.config["label2id"] is None: if "num_labels" in head.config.keys(): head.config["label2id"] = {"LABEL_" + str(num): num for num in range(head.config["num_labels"])} if "num_choices" in head.config.keys(): head.config["label2id"] = {"LABEL_" + str(num): num for num in range(head.config["num_choices"])} + # In case the added head has tied weights, tie them here. + self.tie_weights() + logger.info(f"Adding head '{head.name}' with config {head.config}.") self.active_head = head.name - else: raise ValueError( f"Model already contains a head with name '{head.name}'. Use overwrite_ok=True to force overwrite." 
diff --git a/src/transformers/adapters/heads/language_modeling.py b/src/transformers/adapters/heads/language_modeling.py new file mode 100644 index 0000000000..9917138b33 --- /dev/null +++ b/src/transformers/adapters/heads/language_modeling.py @@ -0,0 +1,147 @@ +import torch.nn as nn + +from ...modeling_outputs import CausalLMOutput, MaskedLMOutput, Seq2SeqLMOutput +from .base import PredictionHead + + +class CausalLMHead(PredictionHead): + def __init__( + self, + model, + head_name, + shift_labels=True, + ): + super().__init__(head_name) + self.config = { + "head_type": "causal_lm", + "num_labels": model.config.vocab_size, + "layers": 1, + "activation_function": None, + "dropout_prob": 0, + "bias": False, + "shift_labels": shift_labels, + } + self.build(model) + + def get_output_embeddings(self): + # The last child is our embedding layer + return self._modules[next(reversed(self._modules))] + + def set_output_embeddings(self, new_embeddings): + # The last child is our embedding layer + self._modules[next(reversed(self._modules))] = new_embeddings + + @staticmethod + def _create_model_output(loss, logits, base_outputs): + return CausalLMOutput( + loss=loss, + logits=logits, + hidden_states=base_outputs.hidden_states, + attentions=base_outputs.attentions, + ) + + def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs): + lm_logits = super().forward(outputs[0]) + + loss = None + labels = kwargs.pop("labels", None) + if labels is not None: + loss_fct = nn.CrossEntropyLoss() + if self.config["shift_labels"]: + logits_for_loss = lm_logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + else: + logits_for_loss = lm_logits + loss = loss_fct(logits_for_loss.view(-1, self.config["num_labels"]), labels.view(-1)) + + if return_dict: + return self._create_model_output(loss, lm_logits, outputs) + else: + outputs = (lm_logits,) + outputs[1:] + if loss is not None: + outputs = (loss,) + outputs + return outputs + + +class Seq2SeqLMHead(CausalLMHead): + def __init__( + self, + model, + head_name, + ): + super().__init__(head_name) + self.config = { + "head_type": "seq2seq_lm", + "num_labels": model.config.vocab_size, + "layers": 1, + "activation_function": None, + "dropout_prob": 0, + "bias": False, + "shift_labels": False, + } + self.build(model) + + @staticmethod + def _create_model_output(self, loss, logits, base_outputs): + return Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=base_outputs.past_key_values, + decoder_hidden_states=base_outputs.decoder_hidden_states, + decoder_attentions=base_outputs.decoder_attentions, + cross_attentions=base_outputs.cross_attentions, + encoder_last_hidden_state=base_outputs.encoder_last_hidden_state, + encoder_hidden_states=base_outputs.encoder_hidden_states, + encoder_attentions=base_outputs.encoder_attentions, + ) + + +class BertStyleMaskedLMHead(CausalLMHead): + def __init__( + self, + model, + head_name, + activation_function="gelu", + ): + super().__init__(head_name) + self.config = { + "head_type": "masked_lm", + "num_labels": model.config.vocab_size, + "layers": 2, + "activation_function": activation_function, + "dropout_prob": 0, + "layer_norm": True, + "bias": False, + "shift_labels": False, + } + self.build(model) + + @staticmethod + def _create_model_output(self, loss, logits, base_outputs): + return MaskedLMOutput( + loss=loss, + logits=logits, + hidden_states=base_outputs.hidden_states, + attentions=base_outputs.attentions, + ) + + +class BertStyleCausalLMHead(CausalLMHead): + def 
__init__( + self, + model, + head_name, + activation_function="gelu", + ): + super().__init__(head_name) + self.config = { + "head_type": "causal_lm", + "num_labels": model.config.vocab_size, + "layers": 2, + "activation_function": activation_function, + "dropout_prob": 0, + "layer_norm": True, + "bias": False, + "shift_labels": True, + } + self.build(model) diff --git a/src/transformers/adapters/models/gpt2.py b/src/transformers/adapters/models/gpt2.py index efc186b78f..f72aa05957 100644 --- a/src/transformers/adapters/models/gpt2.py +++ b/src/transformers/adapters/models/gpt2.py @@ -4,7 +4,7 @@ from torch import nn from ..composition import AdapterCompositionBlock, parse_composition -from ..heads import ClassificationHead, MultiLabelClassificationHead +from ..heads import CausalLMHead, ClassificationHead, MultiLabelClassificationHead from ..model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin from .bert import ( BertEncoderAdaptersMixin, @@ -176,6 +176,7 @@ class GPT2ModelHeadsMixin(ModelWithFlexibleHeadsAdaptersMixin): head_types = { "classification": ClassificationHead, "multilabel_classification": MultiLabelClassificationHead, + "causal_lm": CausalLMHead, } def add_classification_head( @@ -205,3 +206,14 @@ def add_classification_head( else: head = ClassificationHead(self, head_name, num_labels, layers, activation_function, id2label) self.add_prediction_head(head, overwrite_ok) + + def add_causal_lm_head(self, head_name, overwrite_ok=False): + """ + Adds a causal language modeling head on top of the model. + + Args: + head_name (str): The name of the head. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + """ + head = CausalLMHead(self, head_name) + self.add_prediction_head(head, overwrite_ok=overwrite_ok) From c98507c4465acb6855ab0fbba6c089de39335d7a Mon Sep 17 00:00:00 2001 From: calpt <36051308+calpt@users.noreply.github.com> Date: Fri, 23 Jul 2021 15:34:57 +0200 Subject: [PATCH 08/11] Finished flex LM head implementations. Added tests for all possible head conversions. --- src/transformers/adapters/head_utils.py | 154 +++++++++++++++--- src/transformers/adapters/heads/__init__.py | 2 +- src/transformers/adapters/heads/base.py | 11 +- .../adapters/heads/language_modeling.py | 110 ++++++++----- src/transformers/adapters/modeling.py | 2 + src/transformers/adapters/models/bart.py | 17 ++ src/transformers/adapters/models/bert.py | 28 ++++ tests/test_adapter.py | 140 ++++++++++++++-- tests/test_adapter_common.py | 19 +-- tests/test_adapter_conversion.py | 138 ++++++++++++++++ tests/test_adapter_heads.py | 34 ++++ 11 files changed, 544 insertions(+), 111 deletions(-) create mode 100644 tests/test_adapter_conversion.py diff --git a/src/transformers/adapters/head_utils.py b/src/transformers/adapters/head_utils.py index 916d5fa782..208afe97c5 100644 --- a/src/transformers/adapters/head_utils.py +++ b/src/transformers/adapters/head_utils.py @@ -6,6 +6,8 @@ logger = logging.getLogger(__name__) +# The "layers" attributes in the configs below map from static head module names to flex head module names. +# In this context, "None" refers to a flex-head layer without weights (e.g. dropout, acts). 
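+# Example: for "BertForSequenceClassification" below, the static `classifier` module maps to index 1 of the
+# flex classification head, because index 0 (the head's dropout layer) holds no weights and is marked None.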
STATIC_TO_FLEX_HEAD_MAP = { # BERT "BertForSequenceClassification": { @@ -15,7 +17,7 @@ "activation_function": None, "use_pooler": True, }, - "layers": ["classifier"], + "layers": [None, "classifier"], }, "BertForMultipleChoice": { "config": { @@ -24,7 +26,7 @@ "activation_function": None, "use_pooler": True, }, - "layers": ["classifier"], + "layers": [None, "classifier"], }, "BertForTokenClassification": { "config": { @@ -32,7 +34,7 @@ "layers": 1, "activation_function": None, }, - "layers": ["classifier"], + "layers": [None, "classifier"], }, "BertForQuestionAnswering": { "config": { @@ -40,7 +42,37 @@ "layers": 1, "activation_function": None, }, - "layers": ["qa_outputs"], + "layers": [None, "qa_outputs"], + }, + "BertForMaskedLM": { + "config": { + "head_type": "masked_lm", + "layers": 2, + "activation_function": "gelu_orig", + "layer_norm": True, + "bias": True, + }, + "layers": [ + "cls.predictions.transform.dense", + None, + "cls.predictions.transform.LayerNorm", + "cls.predictions.decoder", + ], + }, + "BertLMHeadModel": { + "config": { + "head_type": "causal_lm", + "layers": 2, + "activation_function": "gelu_orig", + "layer_norm": True, + "bias": True, + }, + "layers": [ + "cls.predictions.transform.dense", + None, + "cls.predictions.transform.LayerNorm", + "cls.predictions.decoder", + ], }, # RoBERTa "RobertaForSequenceClassification": { @@ -50,7 +82,7 @@ "activation_function": "tanh", "use_pooler": False, }, - "layers": ["classifier.dense", "classifier.out_proj"], + "layers": [None, "classifier.dense", None, None, "classifier.out_proj"], }, "RobertaForMultipleChoice": { "config": { @@ -59,7 +91,7 @@ "activation_function": None, "use_pooler": True, }, - "layers": ["classifier"], + "layers": [None, "classifier"], }, "RobertaForTokenClassification": { "config": { @@ -67,7 +99,7 @@ "layers": 1, "activation_function": None, }, - "layers": ["classifier"], + "layers": [None, "classifier"], }, "RobertaForQuestionAnswering": { "config": { @@ -75,7 +107,27 @@ "layers": 1, "activation_function": None, }, - "layers": ["qa_outputs"], + "layers": [None, "qa_outputs"], + }, + "RobertaForMaskedLM": { + "config": { + "head_type": "masked_lm", + "layers": 2, + "activation_function": "gelu_orig", + "layer_norm": True, + "bias": True, + }, + "layers": ["lm_head.dense", None, "lm_head.layer_norm", "lm_head.decoder"], + }, + "RobertaForCausalLM": { + "config": { + "head_type": "causal_lm", + "layers": 2, + "activation_function": "gelu_orig", + "layer_norm": True, + "bias": True, + }, + "layers": ["lm_head.dense", None, "lm_head.layer_norm", "lm_head.decoder"], }, # XLM-RoBERTa "XLMRobertaForSequenceClassification": { @@ -85,7 +137,7 @@ "activation_function": "tanh", "use_pooler": False, }, - "layers": ["classifier.dense", "classifier.out_proj"], + "layers": [None, "classifier.dense", None, None, "classifier.out_proj"], }, "XLMRobertaForMultipleChoice": { "config": { @@ -94,7 +146,7 @@ "activation_function": None, "use_pooler": True, }, - "layers": ["classifier"], + "layers": [None, "classifier"], }, "XLMRobertaForTokenClassification": { "config": { @@ -102,7 +154,7 @@ "layers": 1, "activation_function": None, }, - "layers": ["classifier"], + "layers": [None, "classifier"], }, "XLMRobertaForQuestionAnswering": { "config": { @@ -110,7 +162,27 @@ "layers": 1, "activation_function": None, }, - "layers": ["qa_outputs"], + "layers": [None, "qa_outputs"], + }, + "XLMRobertaForMaskedLM": { + "config": { + "head_type": "masked_lm", + "layers": 2, + "activation_function": "gelu_orig", + "layer_norm": True, + 
"bias": True, + }, + "layers": ["lm_head.dense", "lm_head.layer_norm", "lm_head.decoder"], + }, + "XLMRobertaForCausalLM": { + "config": { + "head_type": "causal_lm", + "layers": 2, + "activation_function": "gelu_orig", + "layer_norm": True, + "bias": True, + }, + "layers": ["lm_head.dense", None, "lm_head.layer_norm", "lm_head.decoder"], }, # BART "BartForSequenceClassification": { @@ -119,7 +191,7 @@ "layers": 2, "activation_function": "tanh", }, - "layers": ["classification_head.dense", "classification_head.out_proj"], + "layers": [None, "classification_head.dense", None, None, "classification_head.out_proj"], }, "BartForQuestionAnswering": { "config": { @@ -127,7 +199,13 @@ "layers": 1, "activation_function": None, }, - "layers": ["qa_outputs"], + "layers": [None, "qa_outputs"], + }, + "BartForConditionalGeneration": { + "config": { + "head_type": "seq2seq_lm", + }, + "layers": ["lm_head"], }, # MBART "MBartForSequenceClassification": { @@ -136,7 +214,7 @@ "layers": 2, "activation_function": "tanh", }, - "layers": ["classification_head.dense", "classification_head.out_proj"], + "layers": [None, "classification_head.dense", None, None, "classification_head.out_proj"], }, "MBartForQuestionAnswering": { "config": { @@ -144,7 +222,13 @@ "layers": 1, "activation_function": None, }, - "layers": ["qa_outputs"], + "layers": [None, "qa_outputs"], + }, + "MBartForConditionalGeneration": { + "config": { + "head_type": "seq2seq_lm", + }, + "layers": ["lm_head"], }, # DistilBERT "DistilBertForSequenceClassification": { @@ -153,7 +237,7 @@ "layers": 2, "activation_function": "relu", }, - "layers": ["pre_classifier", "classifier"], + "layers": [None, "pre_classifier", None, None, "classifier"], }, "DistilBertForMultipleChoice": { "config": { @@ -161,7 +245,7 @@ "layers": 2, "activation_function": "relu", }, - "layers": ["pre_classifier", "classifier"], + "layers": [None, "pre_classifier", None, None, "classifier"], }, "DistilBertForTokenClassification": { "config": { @@ -169,7 +253,7 @@ "layers": 1, "activation_function": None, }, - "layers": ["classifier"], + "layers": [None, "classifier"], }, "DistilBertForQuestionAnswering": { "config": { @@ -177,7 +261,17 @@ "layers": 1, "activation_function": None, }, - "layers": ["qa_outputs"], + "layers": [None, "qa_outputs"], + }, + "DistilBertForMaskedLM": { + "config": { + "head_type": "masked_lm", + "layers": 2, + "activation_function": "gelu_orig", + "layer_norm": True, + "bias": True, + }, + "layers": ["vocab_transform", None, "vocab_layer_norm", "vocab_projector"], }, # GPT-2 "GPT2ForSequenceClassification": { @@ -187,7 +281,13 @@ "activation_function": None, "bias": False, }, - "layers": ["score"], + "layers": [None, "score"], + }, + "GPT2LMHeadModel": { + "config": { + "head_type": "causal_lm", + }, + "layers": ["lm_head"], }, } @@ -213,16 +313,18 @@ def get_head_config_and_rename_list(model_class_name, head_name, label2id, num_l config = copy.deepcopy(data["config"]) if config["head_type"] == "multiple_choice": config["num_choices"] = num_labels - else: + config["label2id"] = label2id + elif config["head_type"] not in ["causal_lm", "masked_lm", "seq2seq_lm"]: config["num_labels"] = num_labels - config["label2id"] = label2id + config["label2id"] = label2id # rename rename_list = [] i = 0 for name in data["layers"]: - escaped_name = re.escape(name) - rename_list.append((rf"{escaped_name}\.(\S+)", f"heads.{head_name}.{i+1}.{{0}}")) - i += 3 if config["activation_function"] else 2 # there's always a dropout layer in between + if name is not None: + 
escaped_name = re.escape(name) + rename_list.append((rf"{escaped_name}\.(\S+)", f"heads.{head_name}.{i}.{{0}}")) + i += 1 rename_func = lambda k, rename_list=rename_list: _regex_list_rename_func(k, rename_list) return config, rename_func diff --git a/src/transformers/adapters/heads/__init__.py b/src/transformers/adapters/heads/__init__.py index 44d01013f9..26deac2d0a 100644 --- a/src/transformers/adapters/heads/__init__.py +++ b/src/transformers/adapters/heads/__init__.py @@ -1,4 +1,4 @@ # flake8: noqa from .base import * from .dependency_parsing import * -from .language_modeling import BertStyleCausalLMHead, BertStyleMaskedLMHead, CausalLMHead, Seq2SeqLMHead +from .language_modeling import BertStyleMaskedLMHead, CausalLMHead, Seq2SeqLMHead diff --git a/src/transformers/adapters/heads/base.py b/src/transformers/adapters/heads/base.py index ea6158e3f2..a4c3e05a8d 100644 --- a/src/transformers/adapters/heads/base.py +++ b/src/transformers/adapters/heads/base.py @@ -34,7 +34,6 @@ def build(self, model): model_config = model.config pred_head = [] dropout_prob = self.config.get("dropout_prob", model_config.hidden_dropout_prob) - with_layer_norm = self.config.get("layer_norm", False) bias = self.config.get("bias", True) for l_id in range(self.config["layers"]): if dropout_prob > 0: @@ -43,8 +42,6 @@ def build(self, model): pred_head.append(nn.Linear(model_config.hidden_size, model_config.hidden_size)) if self.config["activation_function"]: pred_head.append(Activation_Function_Class(self.config["activation_function"])) - if with_layer_norm: - pred_head.append(nn.LayerNorm(model_config.hidden_size, eps=model_config.layer_norm_eps)) else: if "num_labels" in self.config: pred_head.append(nn.Linear(model_config.hidden_size, self.config["num_labels"], bias=bias)) @@ -465,9 +462,13 @@ def add_prediction_head_from_config(self, head_name, config, overwrite_ok=False) else: # don't pass label2id to head_class config.pop("label2id", None) + # re-add id2label map to config + if id2label is not None: + config["id2label"] = id2label + if head_type in self.head_types: head_class = self.head_types[head_type] - head = head_class(self, head_name, id2label=id2label, **config) + head = head_class(self, head_name, **config) self.add_prediction_head(head, overwrite_ok=overwrite_ok) elif head_type in self.config.custom_heads: # we have to re-add the head type for custom heads @@ -692,7 +693,7 @@ def get_labels_dict(self, head_name=None): head_name = self.active_head if head_name is None: raise ValueError("No head name given and no active head in the model") - if "label2id" in self.heads[head_name].config.keys(): + if "label2id" in self.heads[head_name].config.keys() and self.heads[head_name].config["label2id"] is not None: return {id_: label for label, id_ in self.heads[head_name].config["label2id"].items()} else: return None diff --git a/src/transformers/adapters/heads/language_modeling.py b/src/transformers/adapters/heads/language_modeling.py index 9917138b33..6026d7d6c1 100644 --- a/src/transformers/adapters/heads/language_modeling.py +++ b/src/transformers/adapters/heads/language_modeling.py @@ -1,6 +1,7 @@ import torch.nn as nn from ...modeling_outputs import CausalLMOutput, MaskedLMOutput, Seq2SeqLMOutput +from ..modeling import Activation_Function_Class from .base import PredictionHead @@ -9,20 +10,50 @@ def __init__( self, model, head_name, + vocab_size=None, + layers=1, + activation_function=None, + layer_norm=False, + bias=False, shift_labels=True, ): - super().__init__(head_name) + super(CausalLMHead, 
self).__init__(head_name) self.config = { "head_type": "causal_lm", - "num_labels": model.config.vocab_size, - "layers": 1, - "activation_function": None, - "dropout_prob": 0, - "bias": False, + "vocab_size": vocab_size or model.config.vocab_size, + "layers": layers, + "activation_function": activation_function, + "layer_norm": layer_norm, + "bias": bias, "shift_labels": shift_labels, + "label2id": None, } self.build(model) + def build(self, model): + model_config = model.config + # Additional FC layers + pred_head = [] + with_layer_norm = self.config.get("layer_norm", False) + for l_id in range(self.config["layers"] - 1): + pred_head.append(nn.Linear(model_config.hidden_size, model_config.hidden_size)) + if self.config["activation_function"]: + pred_head.append(Activation_Function_Class(self.config["activation_function"])) + if with_layer_norm: + eps = getattr(model_config, "layer_norm_eps", 1e-12) + pred_head.append(nn.LayerNorm(model_config.hidden_size, eps=eps)) + for i, module in enumerate(pred_head): + self.add_module(str(i), module) + + # Final embedding layer + self.add_module( + str(len(pred_head)), + nn.Linear(model_config.hidden_size, self.config["vocab_size"], bias=self.config["bias"]), + ) + + self.apply(model._init_weights) + self.train(model.training) # make sure training mode is consistent + def get_output_embeddings(self): # The last child is our embedding layer return self._modules[next(reversed(self._modules))] @@ -52,7 +83,7 @@ def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=Fal labels = labels[..., 1:].contiguous() else: logits_for_loss = lm_logits - loss = loss_fct(logits_for_loss.view(-1, self.config["num_labels"]), labels.view(-1)) + loss = loss_fct(logits_for_loss.view(-1, self.config["vocab_size"]), labels.view(-1)) if return_dict: return self._create_model_output(loss, lm_logits, outputs) @@ -68,21 +99,28 @@ def __init__( self, model, head_name, + vocab_size=None, + layers=1, + activation_function=None, + layer_norm=False, + bias=False, + shift_labels=False, ): - super().__init__(head_name) + super(CausalLMHead, self).__init__(head_name) self.config = { "head_type": "seq2seq_lm", - "num_labels": model.config.vocab_size, - "layers": 1, - "activation_function": None, - "dropout_prob": 0, - "bias": False, - "shift_labels": False, + "vocab_size": vocab_size or model.config.vocab_size, + "layers": layers, + "activation_function": activation_function, + "layer_norm": layer_norm, + "bias": bias, + "shift_labels": shift_labels, + "label2id": None, } self.build(model) @staticmethod - def _create_model_output(self, loss, logits, base_outputs): + def _create_model_output(loss, logits, base_outputs): return Seq2SeqLMOutput( loss=loss, logits=logits, @@ -101,47 +139,31 @@ def __init__( self, model, head_name, + vocab_size=None, + layers=2, activation_function="gelu", + layer_norm=True, + bias=True, + shift_labels=False, ): - super().__init__(head_name) + super(CausalLMHead, self).__init__(head_name) self.config = { "head_type": "masked_lm", - "num_labels": model.config.vocab_size, - "layers": 2, + "vocab_size": vocab_size or model.config.vocab_size, + "layers": layers, "activation_function": activation_function, - "dropout_prob": 0, - "layer_norm": True, - "bias": False, - "shift_labels": False, + "layer_norm": layer_norm, + "bias": bias, + "shift_labels": shift_labels, + "label2id": None, } self.build(model) @staticmethod - def _create_model_output(self, loss, logits, base_outputs): + def _create_model_output(loss, logits, base_outputs): return 
MaskedLMOutput( loss=loss, logits=logits, hidden_states=base_outputs.hidden_states, attentions=base_outputs.attentions, ) - - -class BertStyleCausalLMHead(CausalLMHead): - def __init__( - self, - model, - head_name, - activation_function="gelu", - ): - super().__init__(head_name) - self.config = { - "head_type": "causal_lm", - "num_labels": model.config.vocab_size, - "layers": 2, - "activation_function": activation_function, - "dropout_prob": 0, - "layer_norm": True, - "bias": False, - "shift_labels": True, - } - self.build(model) diff --git a/src/transformers/adapters/modeling.py b/src/transformers/adapters/modeling.py index 2cb9ac3a71..6cd8058bd8 100644 --- a/src/transformers/adapters/modeling.py +++ b/src/transformers/adapters/modeling.py @@ -31,6 +31,8 @@ def gelu_new(x): return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) self.f = gelu_new + elif hidden_act.lower() == "gelu_orig": + self.f = nn.functional.gelu elif hidden_act.lower() == "leakyrelu": self.f = nn.functional.leaky_relu diff --git a/src/transformers/adapters/models/bart.py b/src/transformers/adapters/models/bart.py index e9824c1ed6..0f31ab396c 100644 --- a/src/transformers/adapters/models/bart.py +++ b/src/transformers/adapters/models/bart.py @@ -9,6 +9,7 @@ ModelWithFlexibleHeadsAdaptersMixin, MultiLabelClassificationHead, QuestionAnsweringHead, + Seq2SeqLMHead, ) from ..layer import AdapterLayerBaseMixin from ..model_mixin import ModelAdaptersMixin @@ -294,6 +295,7 @@ class BartModelHeadsMixin(ModelWithFlexibleHeadsAdaptersMixin): "classification": ClassificationHead, "multilabel_classification": MultiLabelClassificationHead, "question_answering": QuestionAnsweringHead, + "seq2seq_lm": Seq2SeqLMHead, } def add_classification_head( @@ -335,3 +337,18 @@ def add_qa_head( ): head = QuestionAnsweringHead(self, head_name, num_labels, layers, activation_function, id2label) self.add_prediction_head(head, overwrite_ok) + + def add_seq2seq_lm_head( + self, + head_name, + overwrite_ok=False, + ): + """ + Adds a sequence-to-sequence language modeling head on top of the model. + + Args: + head_name (str): The name of the head. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + """ + head = Seq2SeqLMHead(self, head_name) + self.add_prediction_head(head, overwrite_ok=overwrite_ok) diff --git a/src/transformers/adapters/models/bert.py b/src/transformers/adapters/models/bert.py index 4afc31bdca..ecafc0e81b 100644 --- a/src/transformers/adapters/models/bert.py +++ b/src/transformers/adapters/models/bert.py @@ -5,7 +5,9 @@ from ..composition import AdapterCompositionBlock, parse_composition from ..heads import ( + BertStyleMaskedLMHead, BiaffineParsingHead, + CausalLMHead, ClassificationHead, ModelWithFlexibleHeadsAdaptersMixin, MultiLabelClassificationHead, @@ -181,6 +183,8 @@ class BertModelHeadsMixin(ModelWithFlexibleHeadsAdaptersMixin): "multiple_choice": MultipleChoiceHead, "question_answering": QuestionAnsweringHead, "dependency_parsing": BiaffineParsingHead, + "masked_lm": BertStyleMaskedLMHead, + "causal_lm": CausalLMHead, } def add_classification_head( @@ -272,3 +276,27 @@ def add_dependency_parsing_head(self, head_name, num_labels=2, overwrite_ok=Fals """ head = BiaffineParsingHead(self, head_name, num_labels, id2label) self.add_prediction_head(head, overwrite_ok) + + def add_masked_lm_head(self, head_name, activation_function="gelu", overwrite_ok=False): + """Adds a masked language modeling head on top of the model. 
+ + Args: + head_name (str): The name of the head. + activation_function (str, optional): Activation function. Defaults to 'gelu'. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + """ + head = BertStyleMaskedLMHead(self, head_name, activation_function=activation_function) + self.add_prediction_head(head, overwrite_ok=overwrite_ok) + + def add_causal_lm_head(self, head_name, activation_function="gelu", overwrite_ok=False): + """Adds a causal language modeling head on top of the model. + + Args: + head_name (str): The name of the head. + activation_function (str, optional): Activation function. Defaults to 'gelu'. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + """ + head = CausalLMHead( + self, head_name, layers=2, activation_function=activation_function, layer_norm=True, bias=True + ) + self.add_prediction_head(head, overwrite_ok=overwrite_ok) diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 0d212d76f4..a55301835f 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -1,10 +1,14 @@ +import random import unittest +import torch + from transformers import BartConfig, BertConfig, DistilBertConfig, GPT2Config, MBartConfig, RobertaConfig -from transformers.testing_utils import require_torch +from transformers.testing_utils import require_torch, torch_device from .test_adapter_common import AdapterModelTestMixin from .test_adapter_composition import ParallelAdapterInferenceTestMixin +from .test_adapter_conversion import ModelClassConversionTestMixin from .test_adapter_fusion_common import AdapterFusionModelTestMixin from .test_adapter_heads import PredictionHeadModelTestMixin from .test_adapter_training import AdapterTrainingTestMixin @@ -14,15 +18,25 @@ def make_config(config_class, **kwargs): return staticmethod(lambda: config_class(**kwargs)) -@require_torch -class BertAdapterTest( - AdapterModelTestMixin, - AdapterFusionModelTestMixin, - PredictionHeadModelTestMixin, - AdapterTrainingTestMixin, - ParallelAdapterInferenceTestMixin, - unittest.TestCase, -): +class AdapterTestBase: + def get_input_samples(self, shape, vocab_size=5000, config=None): + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(random.randint(0, vocab_size - 1)) + input_ids = torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous() + # this is needed e.g. 
for BART + if config and config.eos_token_id is not None: + input_ids[input_ids == config.eos_token_id] = random.randint(0, config.eos_token_id - 1) + input_ids[:, -1] = config.eos_token_id + + return input_ids + + +class BertAdapterTestBase(AdapterTestBase): config_class = BertConfig config = make_config( BertConfig, @@ -35,13 +49,28 @@ class BertAdapterTest( @require_torch -class RobertaAdapterTest( +class BertAdapterTest( AdapterModelTestMixin, AdapterFusionModelTestMixin, PredictionHeadModelTestMixin, + AdapterTrainingTestMixin, ParallelAdapterInferenceTestMixin, + BertAdapterTestBase, + unittest.TestCase, +): + pass + + +@require_torch +class BertClassConversionTest( + ModelClassConversionTestMixin, + BertAdapterTestBase, unittest.TestCase, ): + pass + + +class RobertaAdapterTestBase(AdapterTestBase): config_class = RobertaConfig config = make_config( RobertaConfig, @@ -53,14 +82,27 @@ class RobertaAdapterTest( @require_torch -class DistilBertAdapterTest( +class RobertaAdapterTest( AdapterModelTestMixin, AdapterFusionModelTestMixin, PredictionHeadModelTestMixin, - AdapterTrainingTestMixin, ParallelAdapterInferenceTestMixin, + RobertaAdapterTestBase, + unittest.TestCase, +): + pass + + +@require_torch +class RobertaClassConversionTest( + ModelClassConversionTestMixin, + RobertaAdapterTestBase, unittest.TestCase, ): + pass + + +class DistilBertAdapterTestBase(AdapterTestBase): config_class = DistilBertConfig config = make_config( DistilBertConfig, @@ -73,14 +115,28 @@ class DistilBertAdapterTest( @require_torch -class BartAdapterTest( +class DistilBertAdapterTest( AdapterModelTestMixin, AdapterFusionModelTestMixin, PredictionHeadModelTestMixin, AdapterTrainingTestMixin, ParallelAdapterInferenceTestMixin, + DistilBertAdapterTestBase, unittest.TestCase, ): + pass + + +@require_torch +class DistilBertClassConversionTest( + ModelClassConversionTestMixin, + DistilBertAdapterTestBase, + unittest.TestCase, +): + pass + + +class BartAdapterTestBase(AdapterTestBase): config_class = BartConfig config = make_config( BartConfig, @@ -96,13 +152,28 @@ class BartAdapterTest( @require_torch -class MBartAdapterTest( +class BartAdapterTest( AdapterModelTestMixin, AdapterFusionModelTestMixin, PredictionHeadModelTestMixin, + AdapterTrainingTestMixin, ParallelAdapterInferenceTestMixin, + BartAdapterTestBase, + unittest.TestCase, +): + pass + + +@require_torch +class BartClassConversionTest( + ModelClassConversionTestMixin, + BartAdapterTestBase, unittest.TestCase, ): + pass + + +class MBartAdapterTestBase(AdapterTestBase): config_class = MBartConfig config = make_config( MBartConfig, @@ -117,14 +188,27 @@ class MBartAdapterTest( @require_torch -class GPT2AdapterTest( +class MBartAdapterTest( AdapterModelTestMixin, AdapterFusionModelTestMixin, PredictionHeadModelTestMixin, - AdapterTrainingTestMixin, ParallelAdapterInferenceTestMixin, + MBartAdapterTestBase, unittest.TestCase, ): + pass + + +@require_torch +class MBartClassConversionTest( + ModelClassConversionTestMixin, + MBartAdapterTestBase, + unittest.TestCase, +): + pass + + +class GPT2AdapterTestBase(AdapterTestBase): config_class = GPT2Config config = make_config( GPT2Config, @@ -135,3 +219,25 @@ class GPT2AdapterTest( pad_token_id=50256, ) tokenizer_name = "gpt2" + + +@require_torch +class GPT2AdapterTest( + AdapterModelTestMixin, + AdapterFusionModelTestMixin, + PredictionHeadModelTestMixin, + AdapterTrainingTestMixin, + ParallelAdapterInferenceTestMixin, + GPT2AdapterTestBase, + unittest.TestCase, +): + pass + + +@require_torch +class 
GPT2ClassConversionTest( + ModelClassConversionTestMixin, + GPT2AdapterTestBase, + unittest.TestCase, +): + pass diff --git a/tests/test_adapter_common.py b/tests/test_adapter_common.py index fecca08b1d..076a013e7e 100644 --- a/tests/test_adapter_common.py +++ b/tests/test_adapter_common.py @@ -1,5 +1,4 @@ import copy -import random import tempfile import torch @@ -14,7 +13,7 @@ PfeifferConfig, PfeifferInvConfig, ) -from transformers.testing_utils import require_torch, torch_device +from transformers.testing_utils import require_torch def create_twin_models(model_class, config_creator=None): @@ -33,22 +32,6 @@ def create_twin_models(model_class, config_creator=None): @require_torch class AdapterModelTestMixin: - def get_input_samples(self, shape, vocab_size=5000, config=None): - total_dims = 1 - for dim in shape: - total_dims *= dim - - values = [] - for _ in range(total_dims): - values.append(random.randint(0, vocab_size - 1)) - input_ids = torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous() - # this is needed e.g. for BART - if config and config.eos_token_id is not None: - input_ids[input_ids == config.eos_token_id] = random.randint(0, config.eos_token_id - 1) - input_ids[:, -1] = config.eos_token_id - - return input_ids - def test_add_adapter(self): model = AutoModel.from_config(self.config()) model.eval() diff --git a/tests/test_adapter_conversion.py b/tests/test_adapter_conversion.py new file mode 100644 index 0000000000..25976ba0e2 --- /dev/null +++ b/tests/test_adapter_conversion.py @@ -0,0 +1,138 @@ +import inspect +import re +import tempfile + +import torch + +from transformers import ( + MODEL_FOR_CAUSAL_LM_MAPPING, + MODEL_FOR_MASKED_LM_MAPPING, + MODEL_FOR_MULTIPLE_CHOICE_MAPPING, + MODEL_FOR_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, + AutoModelWithHeads, + BertPreTrainedModel, + RobertaPreTrainedModel, +) +from transformers.testing_utils import require_torch, torch_device + + +@require_torch +class ModelClassConversionTestMixin: + + batch_size = 1 + seq_length = 128 + + def run_test(self, static_model, input_shape=None, label_dict=None): + flex_model = AutoModelWithHeads.from_pretrained( + None, config=self.config(), state_dict=static_model.state_dict() + ) + static_model.eval() + flex_model.eval() + if static_model.base_model.__class__ != flex_model.base_model.__class__: + self.skipTest("Skipping as base model classes are different.") + + with tempfile.TemporaryDirectory() as temp_dir: + static_model.save_head(temp_dir) + + loading_info = {} + flex_model.load_head(temp_dir, load_as="test", loading_info=loading_info) + + self.assertEqual( + 0, len(loading_info["missing_keys"]), "Missing keys: {}".format(", ".join(loading_info["missing_keys"])) + ) + # We don't need to convert some of the weights, so remove them for the check + unexpected_keys = loading_info["unexpected_keys"] + if static_model._keys_to_ignore_on_load_missing is not None: + for pat in static_model._keys_to_ignore_on_load_missing: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + # HACK for bert-based models + if isinstance(static_model, BertPreTrainedModel): + unexpected_keys = [k for k in unexpected_keys if "cls.predictions.bias" not in k] + elif isinstance(static_model, RobertaPreTrainedModel): + unexpected_keys = [k for k in unexpected_keys if "lm_head.bias" not in k] + self.assertEqual(0, len(unexpected_keys), "Unexpected keys: 
{}".format(", ".join(unexpected_keys))) + + # adapter and head were loaded + self.assertIn("test", flex_model.heads) + + # check equal output + input_shape = input_shape or (self.batch_size, self.seq_length) + in_data = {"input_ids": self.get_input_samples(input_shape, config=flex_model.config)} + if label_dict: + for k, v in label_dict.items(): + in_data[k] = v + output1 = static_model(**in_data) + output2 = flex_model(**in_data) + self.assertTrue(torch.allclose(output1.loss, output2.loss)) + self.assertTrue(torch.allclose(output1[1], output2[1])) # it's not called "logits" for all classes + + def test_conversion_causal_lm_model(self): + if self.config_class not in MODEL_FOR_CAUSAL_LM_MAPPING: + self.skipTest("No causal language modeling class.") + + model = MODEL_FOR_CAUSAL_LM_MAPPING[self.config_class](self.config()) + label_dict = {} + label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device) + self.run_test(model, label_dict=label_dict) + + def test_conversion_masked_lm_model(self): + if self.config_class not in MODEL_FOR_MASKED_LM_MAPPING: + self.skipTest("No masked language modeling class.") + + model = MODEL_FOR_MASKED_LM_MAPPING[self.config_class](self.config()) + label_dict = {} + label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device) + # for encoder-decoder models such as BART, we additionally pass the decoder input ids + if "decoder_input_ids" in inspect.signature(model.forward).parameters: + label_dict["decoder_input_ids"] = label_dict["labels"].clone() + self.run_test(model, label_dict=label_dict) + + def test_conversion_seq2seq_lm_model(self): + if self.config_class not in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: + self.skipTest("No seq2seq language modeling class.") + + model = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[self.config_class](self.config()) + label_dict = {} + label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device) + label_dict["decoder_input_ids"] = label_dict["labels"].clone() + self.run_test(model, label_dict=label_dict) + + def test_conversion_classification_model(self): + if self.config_class not in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING: + self.skipTest("No sequence classification class.") + + model = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[self.config_class](self.config()) + label_dict = {} + label_dict["labels"] = torch.zeros(self.batch_size, dtype=torch.long, device=torch_device) + self.run_test(model, label_dict=label_dict) + + def test_conversion_question_answering_model(self): + if self.config_class not in MODEL_FOR_QUESTION_ANSWERING_MAPPING: + self.skipTest("No question answering class.") + + model = MODEL_FOR_QUESTION_ANSWERING_MAPPING[self.config_class](self.config()) + label_dict = {} + label_dict["start_positions"] = torch.zeros(self.batch_size, dtype=torch.long, device=torch_device) + label_dict["end_positions"] = torch.zeros(self.batch_size, dtype=torch.long, device=torch_device) + self.run_test(model, label_dict=label_dict) + + def test_conversion_token_classification_model(self): + if self.config_class not in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING: + self.skipTest("No token classification class.") + + model = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING[self.config_class](self.config()) + label_dict = {} + label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device) + self.run_test(model, label_dict=label_dict) + + def 
test_conversion_multiple_choice_model(self): + if self.config_class not in MODEL_FOR_MULTIPLE_CHOICE_MAPPING: + self.skipTest("No multiple choice class.") + + model = MODEL_FOR_MULTIPLE_CHOICE_MAPPING[self.config_class](self.config()) + label_dict = {} + label_dict["labels"] = torch.ones(self.batch_size, dtype=torch.long, device=torch_device) + self.run_test(model, input_shape=(self.batch_size, 2, self.seq_length), label_dict=label_dict) diff --git a/tests/test_adapter_heads.py b/tests/test_adapter_heads.py index 0ff5e99422..ca5f541588 100644 --- a/tests/test_adapter_heads.py +++ b/tests/test_adapter_heads.py @@ -96,6 +96,40 @@ def test_qa_head(self): model1, model2, "dummy", output_shape=(1, self.seq_length), label_dict=label_dict ) + def test_causal_or_seq2seq_lm_head(self): + if not hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_causal_lm_head"): + if hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_seq2seq_lm_head"): + seq2seq_head = True + else: + self.skipTest("No causal or seq2seq language model head") + else: + seq2seq_head = False + + model1, model2 = create_twin_models(AutoModelWithHeads, self.config) + + if seq2seq_head: + model1.add_seq2seq_lm_head("dummy") + else: + model1.add_causal_lm_head("dummy") + label_dict = {} + label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device) + self.run_prediction_head_test( + model1, model2, "dummy", output_shape=(1, self.seq_length, model1.config.vocab_size), label_dict=label_dict + ) + + def test_masked_lm_head(self): + if not hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_masked_lm_head"): + self.skipTest("No masked language model head") + + model1, model2 = create_twin_models(AutoModelWithHeads, self.config) + + model1.add_masked_lm_head("dummy") + label_dict = {} + label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device) + self.run_prediction_head_test( + model1, model2, "dummy", output_shape=(1, self.seq_length, model1.config.vocab_size), label_dict=label_dict + ) + def test_dependency_parsing_head(self): if not hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_dependency_parsing_head"): self.skipTest("No dependency parsing head") From 040a5cb3d446240177b2cb92152e284e1d317e6d Mon Sep 17 00:00:00 2001 From: calpt <36051308+calpt@users.noreply.github.com> Date: Mon, 26 Jul 2021 17:46:46 +0200 Subject: [PATCH 09/11] Invertible adapters in flex LM heads.
--- src/transformers/adapters/heads/base.py | 5 +++ .../adapters/heads/language_modeling.py | 13 ++++++- src/transformers/adapters/models/bart.py | 1 + tests/test_adapter_heads.py | 35 +++++++++++++++++++ 4 files changed, 53 insertions(+), 1 deletion(-) diff --git a/src/transformers/adapters/heads/base.py b/src/transformers/adapters/heads/base.py index a4c3e05a8d..80336fe102 100644 --- a/src/transformers/adapters/heads/base.py +++ b/src/transformers/adapters/heads/base.py @@ -637,6 +637,11 @@ def _get_head_input(outputs, cls_out, batch): cls_input = None return inputs, cls_input + # Pass invertible adapter if we have one + inv_adapter = self.base_model.get_invertible_adapter() + if inv_adapter: + kwargs["invertible_adapter"] = inv_adapter + for head in used_heads: if head not in self.heads: raise ValueError("Unknown head_name '{}'".format(head)) diff --git a/src/transformers/adapters/heads/language_modeling.py b/src/transformers/adapters/heads/language_modeling.py index 6026d7d6c1..fd20d9d40a 100644 --- a/src/transformers/adapters/heads/language_modeling.py +++ b/src/transformers/adapters/heads/language_modeling.py @@ -72,7 +72,18 @@ def _create_model_output(loss, logits, base_outputs): ) def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs): - lm_logits = super().forward(outputs[0]) + # First, pass through all layers except the last embedding layer + seq_outputs = outputs[0] + for i in range(len(self) - 1): + seq_outputs = self[i](seq_outputs) + + # Now, pass through an invertible adapter if available + inv_adapter = kwargs.pop("invertible_adapter", None) + if inv_adapter is not None: + seq_outputs = inv_adapter(seq_outputs, rev=True) + + # Finally, pass through the last embedding layer + lm_logits = self[len(self) - 1](seq_outputs) loss = None labels = kwargs.pop("labels", None) diff --git a/src/transformers/adapters/models/bart.py b/src/transformers/adapters/models/bart.py index 0f31ab396c..38d54c87ae 100644 --- a/src/transformers/adapters/models/bart.py +++ b/src/transformers/adapters/models/bart.py @@ -184,6 +184,7 @@ def _init_adapter_modules(self): self.invertible_adapters = self.encoder.invertible_adapters self.add_invertible_adapter = self.encoder.add_invertible_adapter self.get_invertible_adapter = self.encoder.get_invertible_adapter + self.invertible_adapters_forward = self.encoder.invertible_adapters_forward def train_adapter(self, adapter_setup: Union[list, AdapterCompositionBlock]): """Sets the model into mode for training the given adapters.""" diff --git a/tests/test_adapter_heads.py b/tests/test_adapter_heads.py index ca5f541588..fb5055c35e 100644 --- a/tests/test_adapter_heads.py +++ b/tests/test_adapter_heads.py @@ -278,3 +278,38 @@ def test_reload_static_to_flex_head(self): output1 = static_head_model(in_data, adapter_names=["test"]) output2 = flex_head_model(in_data, adapter_names=["test"]) self.assertTrue(torch.all(torch.isclose(output1.logits, output2.logits))) + + def test_invertible_adapter_with_head(self): + if not hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_masked_lm_head"): + if hasattr(MODEL_WITH_HEADS_MAPPING[self.config_class], "add_causal_lm_head"): + causal_lm_head = True + else: + self.skipTest("No masked or causal language model head") + else: + causal_lm_head = False + + model = AutoModelWithHeads.from_config(self.config()) + model.add_adapter("test", config="pfeiffer+inv") + if causal_lm_head: + model.add_causal_lm_head("test") + else: + model.add_masked_lm_head("test") +
model.set_active_adapters("test") + + # Set a hook before the invertible adapter to make sure it's actually called twice: + # Once after the embedding layer and once in the prediction head. + calls = 0 + + def forward_pre_hook(module, input): + nonlocal calls + calls += 1 + + inv_adapter = model.base_model.get_invertible_adapter() + self.assertIsNotNone(inv_adapter) + inv_adapter.register_forward_pre_hook(forward_pre_hook) + + in_data = self.get_input_samples((self.batch_size, self.seq_length), config=model.config) + out = model(in_data) + + self.assertEqual((self.batch_size, self.seq_length, model.config.vocab_size), out[0].shape) + self.assertEqual(2, calls) From 035e9a4c28eb5bc882bc1c350709684ce1b8dbf2 Mon Sep 17 00:00:00 2001 From: calpt <36051308+calpt@users.noreply.github.com> Date: Mon, 26 Jul 2021 17:55:42 +0200 Subject: [PATCH 10/11] Fix output_embedding method implementation for XModelWithHeads --- src/transformers/adapters/heads/base.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/transformers/adapters/heads/base.py b/src/transformers/adapters/heads/base.py index 80336fe102..ab2fbf681e 100644 --- a/src/transformers/adapters/heads/base.py +++ b/src/transformers/adapters/heads/base.py @@ -401,17 +401,17 @@ def _init_head_modules(self): # The following methods are required for handling LM heads def get_output_embeddings(self): - all_output_embeddings = {} - - for head_name, head in self.heads.items(): - output_embeddings = head.get_output_embeddings() - if output_embeddings is not None: - all_output_embeddings[head_name] = output_embeddings - - return all_output_embeddings + # Only gets the output embeddings for the currently active head + if self.active_head in self.heads: + head = self.heads[self.active_head] + return head.get_output_embeddings() + else: + return None def set_output_embeddings(self, new_embeddings): - for head_name, head in self.heads.items(): + # Only sets the output embeddings for the currently active head + if self.active_head in self.heads: + head = self.heads[self.active_head] if head.get_output_embeddings() is not None: head.set_output_embeddings(new_embeddings) From 54d45906f62ed23e0683fc463da4373df638da68 Mon Sep 17 00:00:00 2001 From: calpt <36051308+calpt@users.noreply.github.com> Date: Mon, 26 Jul 2021 20:04:40 +0200 Subject: [PATCH 11/11] hacked fix for GPT-2 pad_token_id problem --- src/transformers/models/gpt2/modeling_gpt2.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index ed19dcb343..73662a0a6e 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -1355,6 +1355,12 @@ def forward( """ The GPT2 Model that allows the loading of different heads for different tasks. This enables a flexible use of the models and adapters. + +Since this class does classification on the last token, it requires to know the position of the last token. If a +:obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each +row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot +guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (take +the last value in each row of the batch).
""", GPT2_START_DOCSTRING, ) @@ -1403,10 +1409,8 @@ def forward( batch_size = outputs[0].shape[0] - assert ( - self.config.pad_token_id is not None or batch_size == 1 - ), "Cannot handle batch sizes > 1 if no padding token is defined." if self.config.pad_token_id is None: + # TODO-AH: this may result in unexpected behavior for classification. Find a better way to do this? sequence_lengths = -1 else: if input_ids is not None: