diff --git a/src/transformers/adapter_bert.py b/src/transformers/adapter_bert.py index 9ceebb1faa..12081596fd 100644 --- a/src/transformers/adapter_bert.py +++ b/src/transformers/adapter_bert.py @@ -3,18 +3,18 @@ import torch from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from .adapter_config import DEFAULT_ADAPTER_CONFIG, AdapterType +from .adapter_heads import ( + ClassificationHead, + MultiLabelClassificationHead, + MultipleChoiceHead, + QuestionAnsweringHead, + TaggingHead, +) from .adapter_model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin, ModelWithHeadsAdaptersMixin -from .adapter_modeling import Activation_Function_Class, Adapter, BertFusion +from .adapter_modeling import Adapter, BertFusion from .adapter_utils import flatten_adapter_names, parse_adapter_names -from .modeling_outputs import ( - MultipleChoiceModelOutput, - QuestionAnsweringModelOutput, - SequenceClassifierOutput, - TokenClassifierOutput, -) logger = logging.getLogger(__name__) @@ -542,16 +542,88 @@ class BertModelHeadsMixin(ModelWithHeadsAdaptersMixin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - + if not hasattr(self.config, "custom_heads"): + self.config.custom_heads = {} self.active_head = None def _init_head_modules(self): if not hasattr(self.config, "prediction_heads"): self.config.prediction_heads = {} self.heads = nn.ModuleDict(dict()) + heads_to_add = self.config.prediction_heads + self.config.prediction_heads = {} # add modules for heads in config - for head_name in self.config.prediction_heads: - self._add_prediction_head_module(head_name) + for head_name, config in heads_to_add.items(): + self.add_prediction_head_from_config(head_name, config) + + def add_prediction_head_from_config(self, head_name, config, overwrite_ok=False): + id2label = ( + {id_: label for label, id_ in config["label2id"].items()} + if "label2id" in config.keys() and config["label2id"] + else None + ) + if config["head_type"] == "classification": + self.add_classification_head( + head_name, + config["num_labels"], + config["layers"], + config["activation_function"], + id2label=id2label, + overwrite_ok=overwrite_ok, + ) + elif config["head_type"] == "multilabel_classification": + self.add_classification_head( + head_name, + config["num_labels"], + config["layers"], + config["activation_function"], + multilabel=True, + id2label=id2label, + overwrite_ok=overwrite_ok, + ) + elif config["head_type"] == "tagging": + self.add_tagging_head( + head_name, + config["num_labels"], + config["layers"], + config["activation_function"], + id2label=id2label, + overwrite_ok=overwrite_ok, + ) + elif config["head_type"] == "multiple_choice": + self.add_multiple_choice_head( + head_name, + config["num_choices"], + config["layers"], + config["activation_function"], + id2label=id2label, + overwrite_ok=overwrite_ok, + ) + elif config["head_type"] == "question_answering": + self.add_qa_head( + head_name, + config["num_labels"], + config["layers"], + config["activation_function"], + id2label=id2label, + overwrite_ok=overwrite_ok, + ) + else: + if config["head_type"] in self.config.custom_heads: + self.add_custom_head(head_name, config, overwrite_ok=overwrite_ok) + else: + raise AttributeError("Please register the PredictionHead before loading the model") + + # self._add_prediction_head_module(head_name) + + def get_prediction_heads_config(self): + heads = {} + for head_name, head in self.config.prediction_heads.items(): + heads[head_name] = head.config + return heads + + def 
register_custom_head(self, identifier, head): + self.config.custom_heads[identifier] = head @property def active_head(self): @@ -561,7 +633,19 @@ def active_head(self): def active_head(self, head_name): self._active_head = head_name if head_name is not None and head_name in self.config.prediction_heads: - self.config.label2id = self.config.prediction_heads[head_name]["label2id"] + self.config.label2id = self.config.prediction_heads[head_name].config["label2id"] + + self.config.id2label = self.get_labels_dict(head_name) + + @property + def active_head(self): + return self._active_head + + @active_head.setter + def active_head(self, head_name): + self._active_head = head_name + if head_name is not None and head_name in self.config.prediction_heads: + self.config.label2id = self.config.prediction_heads[head_name].config["label2id"] self.config.id2label = self.get_labels_dict(head_name) def set_active_adapters(self, adapter_names: list): @@ -603,18 +687,12 @@ def add_classification_head( overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. multilabel (bool, optional): Enable multilabel classification setup. Defaults to False. """ + if multilabel: - head_type = "multilabel_classification" + head = MultiLabelClassificationHead(head_name, num_labels, layers, activation_function, id2label, self) else: - head_type = "classification" - config = { - "head_type": head_type, - "num_labels": num_labels, - "layers": layers, - "activation_function": activation_function, - "label2id": {label: id_ for id_, label in id2label.items()} if id2label else None, - } - self.add_prediction_head(head_name, config, overwrite_ok) + head = ClassificationHead(head_name, num_labels, layers, activation_function, id2label, self) + self.add_prediction_head(head, overwrite_ok) def add_multiple_choice_head( self, head_name, num_choices=2, layers=2, activation_function="tanh", overwrite_ok=False, id2label=None @@ -628,14 +706,8 @@ def add_multiple_choice_head( activation_function (str, optional): Activation function. Defaults to 'tanh'. overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. """ - config = { - "head_type": "multiple_choice", - "num_choices": num_choices, - "layers": layers, - "activation_function": activation_function, - "label2id": {label: id_ for id_, label in id2label.items()} if id2label else None, - } - self.add_prediction_head(head_name, config, overwrite_ok) + head = MultipleChoiceHead(head_name, num_choices, layers, activation_function, id2label, self) + self.add_prediction_head(head, overwrite_ok) def add_tagging_head( self, head_name, num_labels=2, layers=1, activation_function="tanh", overwrite_ok=False, id2label=None @@ -649,72 +721,51 @@ def add_tagging_head( activation_function (str, optional): Activation function. Defaults to 'tanh'. overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. 
""" - config = { - "head_type": "tagging", - "num_labels": num_labels, - "layers": layers, - "activation_function": activation_function, - "label2id": {label: id_ for id_, label in id2label.items()} if id2label else None, - } - self.add_prediction_head(head_name, config, overwrite_ok) + head = TaggingHead(head_name, num_labels, layers, activation_function, id2label, self) + self.add_prediction_head(head, overwrite_ok) def add_qa_head( self, head_name, num_labels=2, layers=1, activation_function="tanh", overwrite_ok=False, id2label=None ): - config = { - "head_type": "question_answering", - "num_labels": num_labels, - "layers": layers, - "activation_function": activation_function, - "label2id": {label: id_ for id_, label in id2label.items()} if id2label else None, - } - self.add_prediction_head(head_name, config, overwrite_ok) + head = QuestionAnsweringHead(head_name, num_labels, layers, activation_function, id2label, self) + self.add_prediction_head(head, overwrite_ok) + + def add_custom_head(self, head_name, config, overwrite_ok=False): + if config["head_type"] in self.config.custom_heads: + head = self.config.custom_heads[config["head_type"]](head_name, config, self) + self.add_prediction_head(head, overwrite_ok) + else: + raise AttributeError( + "The given head as a head_type that is not registered as a custom head yet." + " Please register the head first." + ) def add_prediction_head( self, - head_name, - config, + head, overwrite_ok=False, ): - if head_name not in self.config.prediction_heads or overwrite_ok: - self.config.prediction_heads[head_name] = config - if "label2id" not in config.keys() or config["label2id"] is None: - if "num_labels" in config.keys(): - config["label2id"] = {"LABEL_" + str(num): num for num in range(config["num_labels"])} - if "num_choices" in config.keys(): - config["label2id"] = {"LABEL_" + str(num): num for num in range(config["num_choices"])} + if head.name not in self.config.prediction_heads or overwrite_ok: + self.config.prediction_heads[head.name] = head - logger.info(f"Adding head '{head_name}' with config {config}.") - self._add_prediction_head_module(head_name) - self.active_head = head_name + if "label2id" not in head.config.keys() or head.config["label2id"] is None: + if "num_labels" in head.config.keys(): + head.config["label2id"] = {"LABEL_" + str(num): num for num in range(head.config["num_labels"])} + if "num_choices" in head.config.keys(): + head.config["label2id"] = {"LABEL_" + str(num): num for num in range(head.config["num_choices"])} + + logger.info(f"Adding head '{head.name}' with config {head.config}.") + # self._add_prediction_head_module(head.name) + self.active_head = head.name else: raise ValueError( - f"Model already contains a head with name '{head_name}'. Use overwrite_ok=True to force overwrite." + f"Model already contains a head with name '{head.name}'. Use overwrite_ok=True to force overwrite." 
) - def _add_prediction_head_module(self, head_name): - head_config = self.config.prediction_heads.get(head_name) - - pred_head = [] - for l in range(head_config["layers"]): - pred_head.append(nn.Dropout(self.config.hidden_dropout_prob)) - if l < head_config["layers"] - 1: - pred_head.append(nn.Linear(self.config.hidden_size, self.config.hidden_size)) - pred_head.append(Activation_Function_Class(head_config["activation_function"])) - else: - if "num_labels" in head_config: - pred_head.append(nn.Linear(self.config.hidden_size, head_config["num_labels"])) - else: # used for multiple_choice head - pred_head.append(nn.Linear(self.config.hidden_size, 1)) - - self.heads[head_name] = nn.Sequential(*pred_head) - - self.heads[head_name].apply(self._init_weights) - self.heads[head_name].train(self.training) # make sure training mode is consistent - def forward_head(self, outputs, head_name=None, attention_mask=None, return_dict=False, **kwargs): + head_name = head_name or self.active_head if not head_name: logger.debug("No prediction head is used.") @@ -722,151 +773,13 @@ def forward_head(self, outputs, head_name=None, attention_mask=None, return_dict if head_name not in self.config.prediction_heads: raise ValueError("Unknown head_name '{}'".format(head_name)) - + if not self.training: + self.config.prediction_heads[head_name].eval() + else: + self.config.prediction_heads[head_name].train() head = self.config.prediction_heads[head_name] - sequence_output = outputs[0] - loss = None - - if head["head_type"] == "classification": - logits = self.heads[head_name](sequence_output[:, 0]) - - outputs = (logits,) + outputs[2:] - labels = kwargs.pop("labels", None) - if labels is not None: - if head["num_labels"] == 1: - # We are doing regression - loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, head["num_labels"]), labels.view(-1)) - outputs = (loss,) + outputs - - if return_dict: - return SequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - else: - return outputs - - elif head["head_type"] == "multilabel_classification": - logits = self.heads[head_name](sequence_output[:, 0]) - - outputs = (logits,) + outputs[2:] - labels = kwargs.pop("labels", None) - if labels is not None: - loss_fct = BCEWithLogitsLoss() - if labels.dtype != torch.float32: - labels = labels.float() - loss = loss_fct(logits, labels) - outputs = (loss,) + outputs - - if return_dict: - return SequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - else: - return outputs - - elif head["head_type"] == "multiple_choice": - logits = self.heads[head_name](sequence_output[:, 0]) - logits = logits.view(-1, head["num_choices"]) - - outputs = (logits,) + outputs[2:] - labels = kwargs.pop("labels", None) - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits, labels) - outputs = (loss,) + outputs - - if return_dict: - return MultipleChoiceModelOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - else: - return outputs - - elif head["head_type"] == "tagging": - logits = self.heads[head_name](sequence_output) - - outputs = (logits,) + outputs[2:] - labels = kwargs.pop("labels", None) - if labels is not None: - loss_fct = CrossEntropyLoss() - # Only keep active parts of the loss - if attention_mask is not None: 
- active_loss = attention_mask.view(-1) == 1 - active_logits = logits.view(-1, self.num_labels) - active_labels = torch.where( - active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) - ) - loss = loss_fct(active_logits, active_labels) - else: - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - outputs = (loss,) + outputs - - if return_dict: - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - else: - return outputs - - elif head["head_type"] == "question_answering": - logits = self.heads[head_name](sequence_output) - - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - outputs = ( - start_logits, - end_logits, - ) + outputs[2:] - start_positions = kwargs.pop("start_positions", None) - end_positions = kwargs.pop("end_positions", None) - if start_positions is not None and end_positions is not None: - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions.clamp_(0, ignored_index) - end_positions.clamp_(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - outputs = (total_loss,) + outputs - - if return_dict: - return QuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - else: - return outputs - - else: - raise ValueError("Unknown head_type '{}'".format(head["head_type"])) + return head(outputs, attention_mask, return_dict, **kwargs) def get_labels_dict(self, head_name=None): """ @@ -882,8 +795,8 @@ def get_labels_dict(self, head_name=None): head_name = self.active_head if head_name is None: raise ValueError("No head name given and no active head in the model") - if "label2id" in self.config.prediction_heads[head_name].keys(): - return {id_: label for label, id_ in self.config.prediction_heads[head_name]["label2id"].items()} + if "label2id" in self.config.prediction_heads[head_name].config.keys(): + return {id_: label for label, id_ in self.config.prediction_heads[head_name].config["label2id"].items()} else: return None diff --git a/src/transformers/adapter_heads.py b/src/transformers/adapter_heads.py new file mode 100644 index 0000000000..3105b92c5e --- /dev/null +++ b/src/transformers/adapter_heads.py @@ -0,0 +1,261 @@ +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from .adapter_modeling import Activation_Function_Class +from .modeling_outputs import ( + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) + + +class PredictionHead(nn.Module): + def __init__(self, name): + super().__init__() + self.config = None + self.head = None + self.name = name + + def build( + self, + model, + ): # _init_weights): + model_config = model.config + pred_head = [] + for l in range(self.config["layers"]): + pred_head.append(nn.Dropout(model_config.hidden_dropout_prob)) + if l < self.config["layers"] - 1: + 
pred_head.append(nn.Linear(model_config.hidden_size, model_config.hidden_size)) + pred_head.append(Activation_Function_Class(self.config["activation_function"])) + else: + if "num_labels" in self.config: + pred_head.append(nn.Linear(model_config.hidden_size, self.config["num_labels"])) + else: # used for multiple_choice head + pred_head.append(nn.Linear(model_config.hidden_size, 1)) + self.head = nn.Sequential(*pred_head) + + self.head.apply(model._init_weights) + + def forward(self, outputs, attention_mask, return_dict, **kwargs): + raise NotImplementedError("Use a Prediction Head that inherits from this class") + + def save_head(self, path): + torch.save(self, path) + + @staticmethod + def load_head(path, load_as): + head = torch.load(path) + if load_as: + head.name = load_as + return head + + +class ClassificationHead(PredictionHead): + def __init__(self, head_name, num_labels, layers, activation_function, id2label, model): + super().__init__(head_name) + self.config = { + "head_type": "classification", + "num_labels": num_labels, + "layers": layers, + "activation_function": activation_function, + "label2id": {label: id_ for id_, label in id2label.items()} if id2label else None, + } + self.build(model) + + def forward(self, outputs, attention_mask, return_dict, **kwargs): + + logits = self.head(outputs[0][:, 0]) + loss = None + labels = kwargs.pop("labels", None) + if labels is not None: + if self.config["num_labels"] == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config["num_labels"]), labels.view(-1)) + + if return_dict: + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + outputs = (logits,) + outputs[2:] + if labels is not None: + outputs = (loss,) + outputs + return outputs + + +class MultiLabelClassificationHead(PredictionHead): + def __init__(self, head_name, num_labels, layers, activation_function, id2label, model): + super().__init__(head_name) + self.config = { + "head_type": "multilabel_classification", + "num_labels": num_labels, + "layers": layers, + "activation_function": activation_function, + "label2id": {label: id_ for id_, label in id2label.items()} if id2label else None, + } + self.build(model) + + def forward(self, outputs, attention_mask, return_dict, **kwargs): + logits = self.head(outputs[0][:, 0]) + loss = None + labels = kwargs.pop("labels", None) + if labels is not None: + loss_fct = BCEWithLogitsLoss() + if labels.dtype != torch.float32: + labels = labels.float() + loss = loss_fct(logits, labels) + + if return_dict: + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + outputs = (logits,) + outputs[2:] + if labels is not None: + outputs = (loss,) + outputs + return outputs + + +class MultipleChoiceHead(PredictionHead): + def __init__(self, head_name, num_choices, layers, activation_function, id2label, model): + super().__init__(head_name) + self.config = { + "head_type": "multiple_choice", + "num_choices": num_choices, + "layers": layers, + "activation_function": activation_function, + "label2id": {label: id_ for id_, label in id2label.items()} if id2label else None, + } + self.build(model) + + def forward(self, outputs, attention_mask, return_dict, **kwargs): + logits = self.head(outputs[0][:, 0]) + logits = logits.view(-1, 
self.config["num_choices"]) + loss = None + labels = kwargs.pop("labels", None) + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits, labels) + + if return_dict: + return MultipleChoiceModelOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + outputs = (logits,) + outputs[2:] + if labels is not None: + outputs = (loss,) + outputs + return outputs + + +class TaggingHead(PredictionHead): + def __init__(self, head_name, num_labels, layers, activation_function, id2label, model): + super().__init__(head_name) + self.config = { + "head_type": "tagging", + "num_labels": num_labels, + "layers": layers, + "activation_function": activation_function, + "label2id": {label: id_ for id_, label in id2label.items()} if id2label else None, + } + self.build(model) + + def forward(self, outputs, attention_mask, return_dict, **kwargs): + logits = self.head(outputs[0]) + loss = None + + labels = kwargs.pop("labels", None) + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if return_dict: + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + outputs = (logits,) + outputs[2:] + if labels is not None: + outputs = (loss,) + outputs + return outputs + + +class QuestionAnsweringHead(PredictionHead): + def __init__(self, head_name, num_labels, layers, activation_function, id2label, model): + super().__init__(head_name) + self.config = { + "head_type": "question_answering", + "num_labels": num_labels, + "layers": layers, + "activation_function": activation_function, + "label2id": {label: id_ for id_, label in id2label.items()} if id2label else None, + } + self.build(model) + + def forward(self, outputs, attention_mask, return_dict, **kwargs): + sequence_output = outputs[0] + logits = self.heads[self.name](sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + outputs = ( + start_logits, + end_logits, + ) + outputs[2:] + start_positions = kwargs.pop("start_positions", None) + end_positions = kwargs.pop("end_positions", None) + total_loss = None + if start_positions is not None and end_positions is not None: + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + outputs = (total_loss,) + outputs + + if return_dict: + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, 
+ ) + else: + return outputs diff --git a/src/transformers/adapter_model_mixin.py b/src/transformers/adapter_model_mixin.py index 1f1112f8e3..ea3b4528be 100644 --- a/src/transformers/adapter_model_mixin.py +++ b/src/transformers/adapter_model_mixin.py @@ -68,14 +68,17 @@ def save_weights_config(self, save_directory, config, meta_dict=None): json.dump(config, f, indent=2, sort_keys=True) logger.info("Configuration saved in {}".format(output_config_file)) - def save_weights(self, save_directory, filter_func): + def save_weights(self, save_directory, filter_func, head_name=None): if not exists(save_directory): mkdir(save_directory) else: assert isdir(save_directory), "Saving path should be a directory where the module weights can be saved." # Get the state of all adapter modules for this task - state_dict = self.state_dict(filter_func) + if head_name: + state_dict = self.model.config.prediction_heads[head_name].state_dict() + else: + state_dict = self.state_dict(filter_func) # Save the adapter weights output_file = join(save_directory, self.weights_name) torch.save(state_dict, output_file) @@ -120,7 +123,15 @@ def load(module, prefix=""): ) return missing_keys, unexpected_keys - def load_weights(self, save_directory, filter_func, rename_func=None, loading_info=None, in_base_model=False): + def load_weights( + self, + save_directory, + filter_func, + rename_func=None, + loading_info=None, + in_base_model=False, + head_name=None, + ): weights_file = join(save_directory, self.weights_name) # Load the weights of the adapter try: @@ -137,7 +148,10 @@ def load_weights(self, save_directory, filter_func, rename_func=None, loading_in # Add the weights to the model # Make sure we are able to load base models as well as derived models (with heads) start_prefix = "" - model_to_load = self.model + if head_name: + model_to_load = self.model.config.prediction_heads[head_name] + else: + model_to_load = self.model has_prefix_module = any(s.startswith(self.model.base_model_prefix) for s in state_dict.keys()) if not hasattr(self.model, self.model.base_model_prefix) and has_prefix_module: start_prefix = self.model.base_model_prefix + "." @@ -551,7 +565,8 @@ def save(self, save_directory: str, name: str = None): # if we use a custom head, save it if name and hasattr(self.model.config, "prediction_heads"): - head_config = self.model.config.prediction_heads[name] + head = self.model.config.prediction_heads[name] + head_config = head.config else: head_config = None @@ -567,8 +582,9 @@ def save(self, save_directory: str, name: str = None): self.weights_helper.save_weights_config(save_directory, config_dict) # Save head weights + filter_func = self.filter_func(name) - self.weights_helper.save_weights(save_directory, filter_func) + self.weights_helper.save_weights(save_directory, filter_func, head_name=name) def load(self, save_directory, load_as=None, loading_info=None): """Loads a prediction head module from the given directory. 
@@ -610,7 +626,7 @@ def load(self, save_directory, load_as=None, loading_info=None): head_name = load_as or config["name"] if head_name in self.model.config.prediction_heads: logger.warning("Overwriting existing head '{}'".format(head_name)) - self.model.add_prediction_head(head_name, config["config"], overwrite_ok=True) + self.model.add_prediction_head_from_config(head_name, config["config"], overwrite_ok=True) else: if "label2id" in config.keys(): self.model.config.id2label = {int(id_): label for label, id_ in config["label2id"].items()} @@ -621,9 +637,14 @@ def load(self, save_directory, load_as=None, loading_info=None): rename_func = self.rename_func(config["name"], load_as) else: rename_func = None - self.weights_helper.load_weights( - save_directory, filter_func, rename_func=rename_func, loading_info=loading_info - ) + if hasattr(self.model.config, "prediction_heads"): + self.weights_helper.load_weights( + save_directory, filter_func, rename_func=rename_func, loading_info=loading_info, head_name=head_name + ) + else: + self.weights_helper.load_weights( + save_directory, filter_func, rename_func=rename_func, loading_info=loading_info + ) return save_directory, head_name diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index f94e2c658f..d159116f45 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -21,6 +21,8 @@ import os from typing import Any, Dict, Tuple +from transformers.adapter_heads import PredictionHead + from .adapter_utils import DataclassJSONEncoder from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url from .utils import logging @@ -444,6 +446,8 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig": if hasattr(config, key): setattr(config, key, value) to_remove.append(key) + elif key == "custom_heads": + setattr(config, key, value) for key in to_remove: kwargs.pop(key, None) @@ -522,6 +526,12 @@ def to_dict(self) -> Dict[str, Any]: output["model_type"] = self.__class__.model_type if hasattr(self, "adapters") and not isinstance(output["adapters"], dict): output["adapters"] = self.adapters.to_dict() + if "prediction_heads" in output.keys(): + for head in output["prediction_heads"].keys(): + if isinstance(output["prediction_heads"][head], PredictionHead): + output["prediction_heads"][head] = output["prediction_heads"][head].config + if "custom_heads" in output.keys(): + del output["custom_heads"] return output def to_json_string(self, use_diff: bool = True) -> str: diff --git a/tests/test_adapter_common.py b/tests/test_adapter_common.py index 0d2698ce0e..42bd8be2db 100644 --- a/tests/test_adapter_common.py +++ b/tests/test_adapter_common.py @@ -185,7 +185,6 @@ def test_adapter_with_head(self): model2.load_adapter(temp_dir) model2.set_active_adapters(name) - # check equal output in_data = ids_tensor((1, 128), 1000) output1 = model1(in_data) @@ -223,14 +222,14 @@ def test_load_full_model(self): with self.subTest(model_class=model_class.__name__): model = model_class(model_class.config_class()) model.add_tagging_head("dummy") - true_config = model.config.prediction_heads + true_config = model.get_prediction_heads_config() with tempfile.TemporaryDirectory() as temp_dir: # save model.save_pretrained(temp_dir) # reload model = model_class.from_pretrained(temp_dir) self.assertIn("dummy", model.config.prediction_heads) - self.assertDictEqual(true_config, model.config.prediction_heads) + self.assertDictEqual(true_config, 
model.get_prediction_heads_config()) @require_torch diff --git a/tests/test_adapter_custom_head.py b/tests/test_adapter_custom_head.py new file mode 100644 index 0000000000..03c5ab7cb1 --- /dev/null +++ b/tests/test_adapter_custom_head.py @@ -0,0 +1,72 @@ +import tempfile +import unittest + +import torch + +from tests.test_modeling_common import ids_tensor +from transformers import AutoConfig, AutoModelWithHeads +from transformers.adapter_heads import PredictionHead + + +class CustomHead(PredictionHead): + def __init__(self, name, config, model): + super().__init__(name) + self.config = config + self.build(model=model) + + def forward(self, outputs, attention_mask, return_dict, **kwargs): + logits = self.head(outputs[0]) + outputs = (logits,) + outputs[2:] + return outputs + + +class AdapterCustomHeadTest(unittest.TestCase): + def test_add_custom_head(self): + model_name = "bert-base-uncased" + model = AutoModelWithHeads.from_pretrained(model_name) + model.register_custom_head("tag", CustomHead) + config = {"head_type": "tag", "num_labels": 3, "layers": 2, "activation_function": "tanh"} + model.add_custom_head("custom_head", config) + model.eval() + in_data = ids_tensor((1, 128), 1000) + output1 = model(in_data) + model.add_tagging_head("tagging_head", num_labels=3, layers=2) + output2 = model(in_data) + self.assertEqual(output1[0].size(), output2[0].size()) + + def test_custom_head_from_model_config(self): + model_name = "bert-base-uncased" + model_config = AutoConfig.from_pretrained(model_name, custom_heads={"tag": CustomHead}) + model = AutoModelWithHeads.from_pretrained(model_name, config=model_config) + config = {"head_type": "tag", "num_labels": 3, "layers": 2, "activation_function": "tanh"} + model.add_custom_head("custom_head", config) + model.eval() + in_data = ids_tensor((1, 128), 1000) + output1 = model(in_data) + model.add_tagging_head("tagging_head", num_labels=3, layers=2) + output2 = model(in_data) + self.assertEqual(output1[0].size(), output2[0].size()) + + def test_save_load_custom_head(self): + model_name = "bert-base-uncased" + model_config = AutoConfig.from_pretrained(model_name, custom_heads={"tag": CustomHead}) + model1 = AutoModelWithHeads.from_pretrained(model_name, config=model_config) + model2 = AutoModelWithHeads.from_pretrained(model_name, config=model_config) + config = {"head_type": "tag", "num_labels": 3, "layers": 2, "activation_function": "tanh"} + model1.add_custom_head("custom_head", config) + + with tempfile.TemporaryDirectory() as temp_dir: + model1.save_head(temp_dir, "custom_head") + model2.load_head(temp_dir) + + model1.eval() + model2.eval() + + in_data = ids_tensor((1, 128), 1000) + output1 = model1(in_data) + output2 = model2(in_data) + self.assertEqual(output1[0].size(), output2[0].size()) + state1 = model1.config.prediction_heads["custom_head"].state_dict() + state2 = model2.config.prediction_heads["custom_head"].state_dict() + for ((k1, v1), (k2, v2)) in zip(state1.items(), state2.items()): + self.assertTrue(torch.equal(v1, v2))
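
Usage sketch (a note outside the patch itself): a minimal example of how the refactored head API above might be exercised, assuming the methods introduced in this diff; "bert-base-uncased" is only a placeholder checkpoint and the head names are illustrative.

# a minimal sketch, assuming the AutoModelWithHeads API shown in this diff
from transformers import AutoModelWithHeads

model = AutoModelWithHeads.from_pretrained("bert-base-uncased")  # placeholder checkpoint

# heads are now PredictionHead modules stored in model.config.prediction_heads
model.add_classification_head("sentiment", num_labels=2)

# per-head configs can be read back and used to rebuild an equivalent head
head_configs = model.get_prediction_heads_config()
model.add_prediction_head_from_config("sentiment_copy", head_configs["sentiment"], overwrite_ok=True)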