diff --git a/jiant/proj/main/modeling/heads.py b/jiant/proj/main/modeling/heads.py index 38b637c99..11cdcd6e9 100644 --- a/jiant/proj/main/modeling/heads.py +++ b/jiant/proj/main/modeling/heads.py @@ -4,7 +4,13 @@ import torch.nn as nn import transformers + from jiant.ext.allennlp import SelfAttentiveSpanExtractor +from jiant.shared.model_resolution import ModelArchitectures +from jiant.tasks.core import TaskTypes +from typing import Callable +from typing import List + """ In HuggingFace/others, these heads differ slightly across different encoder models. @@ -12,18 +18,41 @@ """ +class JiantHeadFactory: + # Internal registry for available task models + registry = {} + + @classmethod + def register(cls, task_type_list: List[TaskTypes]) -> Callable: + def inner_wrapper(wrapped_class: BaseHead) -> Callable: + for task_type in task_type_list: + assert task_type not in cls.registry + cls.registry[task_type] = wrapped_class + return wrapped_class + + return inner_wrapper + + def __call__(self, task, **kwargs): + head_class = self.registry[task.TASK_TYPE] + head = head_class(task, **kwargs) + return head + + class BaseHead(nn.Module, metaclass=abc.ABCMeta): - pass + @abc.abstractmethod + def __init__(self): + super().__init__() +@JiantHeadFactory.register([TaskTypes.CLASSIFICATION]) class ClassificationHead(BaseHead): - def __init__(self, hidden_size, hidden_dropout_prob, num_labels): + def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs): """From RobertaClassificationHead""" super().__init__() self.dense = nn.Linear(hidden_size, hidden_size) self.dropout = nn.Dropout(hidden_dropout_prob) - self.out_proj = nn.Linear(hidden_size, num_labels) - self.num_labels = num_labels + self.out_proj = nn.Linear(hidden_size, len(task.LABELS)) + self.num_labels = len(task.LABELS) def forward(self, pooled): x = self.dropout(pooled) @@ -34,8 +63,9 @@ def forward(self, pooled): return logits +@JiantHeadFactory.register([TaskTypes.REGRESSION, TaskTypes.MULTIPLE_CHOICE]) class RegressionHead(BaseHead): - def __init__(self, hidden_size, hidden_dropout_prob): + def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs): """From RobertaClassificationHead""" super().__init__() self.dense = nn.Linear(hidden_size, hidden_size) @@ -51,12 +81,13 @@ def forward(self, pooled): return scores +@JiantHeadFactory.register([TaskTypes.SPAN_COMPARISON_CLASSIFICATION]) class SpanComparisonHead(BaseHead): - def __init__(self, hidden_size, hidden_dropout_prob, num_spans, num_labels): + def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs): """From RobertaForSpanComparisonClassification""" super().__init__() - self.num_spans = num_spans - self.num_labels = num_labels + self.num_spans = task.num_spans + self.num_labels = len(task.LABELS) self.hidden_size = hidden_size self.dropout = nn.Dropout(hidden_dropout_prob) self.span_attention_extractor = SelfAttentiveSpanExtractor(hidden_size) @@ -70,13 +101,14 @@ def forward(self, unpooled, spans): return logits +@JiantHeadFactory.register([TaskTypes.TAGGING]) class TokenClassificationHead(BaseHead): - def __init__(self, hidden_size, num_labels, hidden_dropout_prob): + def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs): """From RobertaForTokenClassification""" super().__init__() - self.num_labels = num_labels + self.num_labels = len(task.LABELS) self.dropout = nn.Dropout(hidden_dropout_prob) - self.classifier = nn.Linear(hidden_size, num_labels) + self.classifier = nn.Linear(hidden_size, self.num_labels) def forward(self, unpooled): 
unpooled = self.dropout(unpooled) @@ -84,8 +116,9 @@ def forward(self, unpooled): return logits +@JiantHeadFactory.register([TaskTypes.SQUAD_STYLE_QA]) class QAHead(BaseHead): - def __init__(self, hidden_size): + def __init__(self, task, hidden_size, **kwargs): """From RobertaForQuestionAnswering""" super().__init__() self.qa_outputs = nn.Linear(hidden_size, 2) @@ -98,10 +131,37 @@ def forward(self, unpooled): return logits +@JiantHeadFactory.register([TaskTypes.MASKED_LANGUAGE_MODELING]) +class JiantMLMHeadFactory: + # Internal registry for available task models + registry = {} + + @classmethod + def register(cls, model_arch_list: List[ModelArchitectures]) -> Callable: + def inner_wrapper(wrapped_class: BaseMLMHead) -> Callable: + for model_arch in model_arch_list: + assert model_arch not in cls.registry + cls.registry[model_arch] = wrapped_class + return wrapped_class + + return inner_wrapper + + def __call__( + self, + task, + **kwargs + # task_type: TaskTypes, model_arch: ModelArchitectures, hidden_size, hidden_dropout_prob + ): + mlm_head_class = self.registry[task.TASK_TYPE] + mlm_head = mlm_head_class(task, **kwargs) + return mlm_head + + class BaseMLMHead(BaseHead, metaclass=abc.ABCMeta): pass +@JiantMLMHeadFactory.register([ModelArchitectures.BERT]) class BertMLMHead(BaseMLMHead): """From BertOnlyMLMHead, BertLMPredictionHead, BertPredictionHeadTransform""" @@ -126,6 +186,7 @@ def forward(self, unpooled): return logits +@JiantMLMHeadFactory.register([ModelArchitectures.ROBERTA, ModelArchitectures.XLM_ROBERTA]) class RobertaMLMHead(BaseMLMHead): """From RobertaLMHead""" @@ -151,7 +212,8 @@ def forward(self, unpooled): return logits -class AlbertMLMHead(nn.Module): +@JiantMLMHeadFactory.register([ModelArchitectures.ALBERT]) +class AlbertMLMHead(BaseMLMHead): """From AlbertMLMHead""" def __init__(self, hidden_size, embedding_size, vocab_size, hidden_act="gelu"): diff --git a/jiant/proj/main/modeling/model_setup.py b/jiant/proj/main/modeling/model_setup.py index 6ff546e72..4a6865a19 100644 --- a/jiant/proj/main/modeling/model_setup.py +++ b/jiant/proj/main/modeling/model_setup.py @@ -2,7 +2,6 @@ from typing import Any from typing import Dict from typing import List -from typing import Optional import torch import torch.nn as nn @@ -10,14 +9,14 @@ import jiant.proj.main.components.container_setup as container_setup -import jiant.proj.main.modeling.heads as heads import jiant.proj.main.modeling.primary as primary -import jiant.proj.main.modeling.taskmodels as taskmodels import jiant.utils.python.strings as strings +from jiant.proj.main.modeling.heads import JiantHeadFactory +from jiant.proj.main.modeling.taskmodels import JiantTaskModelFactory, Taskmodel, MLMModel + from jiant.shared.model_resolution import ModelArchitectures from jiant.tasks import Task -from jiant.tasks import TaskTypes def setup_jiant_model( @@ -52,19 +51,13 @@ def setup_jiant_model( JiantModel nn.Module. 
""" - model = transformers.AutoModel.from_pretrained(hf_pretrained_model_name_or_path) - model_arch = ModelArchitectures.from_model_type(model.base_model_prefix) - transformers_class_spec = TRANSFORMERS_CLASS_SPEC_DICT[model_arch] + hf_model = transformers.AutoModel.from_pretrained(hf_pretrained_model_name_or_path) tokenizer = transformers.AutoTokenizer.from_pretrained(hf_pretrained_model_name_or_path) - ancestor_model = get_ancestor_model( - transformers_class_spec=transformers_class_spec, model_config_path=model_config_path, - ) - encoder = get_encoder(model_arch=model_arch, ancestor_model=ancestor_model) + jiant_transformers_model = primary.JiantTransformersModelFactory.create_model(hf_model) taskmodels_dict = { taskmodel_name: create_taskmodel( task=task_dict[task_name_list[0]], # Take the first task - model_arch=model_arch, - encoder=encoder, + jiant_transformers_model=jiant_transformers_model, taskmodel_kwargs=taskmodels_config.get_taskmodel_kwargs(taskmodel_name), ) for taskmodel_name, task_name_list in get_taskmodel_and_task_names( @@ -73,7 +66,7 @@ def setup_jiant_model( } return primary.JiantModel( task_dict=task_dict, - encoder=encoder, + encoder=jiant_transformers_model, taskmodels_dict=taskmodels_dict, task_to_taskmodel_map=taskmodels_config.task_to_taskmodel_map, tokenizer=tokenizer, @@ -162,7 +155,7 @@ def load_encoder_from_transformers_weights( remainder_weights_dict = {} load_weights_dict = {} model_arch = ModelArchitectures.from_encoder(encoder=encoder) - encoder_prefix = MODEL_PREFIX[model_arch] + "." + encoder_prefix = model_arch.value + "." # Encoder for k, v in weights_dict.items(): if k.startswith(encoder_prefix): @@ -200,7 +193,7 @@ def load_lm_heads_from_transformers_weights(jiant_model, weights_dict): raise KeyError(model_arch) missed = set() for taskmodel_name, taskmodel in jiant_model.taskmodels_dict.items(): - if not isinstance(taskmodel, taskmodels.MLMModel): + if not isinstance(taskmodel, MLMModel): continue mismatch = taskmodel.mlm_head.load_state_dict(mlm_weights_dict) assert not mismatch.missing_keys @@ -273,15 +266,14 @@ def load_partial_heads( return result -def create_taskmodel( - task, model_arch, encoder, taskmodel_kwargs: Optional[Dict] = None -) -> taskmodels.Taskmodel: +def create_taskmodel(task, jiant_transformers_model, **taskmodel_kwargs) -> Taskmodel: """Creates, initializes and returns the task model for a given task type and encoder. Args: task (Task): Task object associated with the taskmodel being created. model_arch (ModelArchitectures.Any): Model architecture (e.g., ModelArchitectures.BERT). - encoder (PreTrainedModel): Transformer w/o heads (embedding layer + self-attention layer). + jiant_transformers_model (JiantTransformersModelWrapper): Transformer w/o heads + (embedding layer + self-attention layer). taskmodel_kwargs (Optional[Dict]): map containing any kwargs needed for taskmodel setup. Raises: @@ -291,178 +283,20 @@ def create_taskmodel( Taskmodel (e.g., ClassificationModel) appropriate for the task type and encoder. 
""" - if model_arch in [ - ModelArchitectures.BERT, - ModelArchitectures.ROBERTA, - ModelArchitectures.ALBERT, - ModelArchitectures.XLM_ROBERTA, - ModelArchitectures.ELECTRA, - ]: - hidden_size = encoder.config.hidden_size - hidden_dropout_prob = encoder.config.hidden_dropout_prob - elif model_arch in [ - ModelArchitectures.BART, - ModelArchitectures.MBART, - ]: - hidden_size = encoder.config.d_model - hidden_dropout_prob = encoder.config.dropout - else: - raise KeyError() - - if task.TASK_TYPE == TaskTypes.CLASSIFICATION: - assert taskmodel_kwargs is None - classification_head = heads.ClassificationHead( - hidden_size=hidden_size, - hidden_dropout_prob=hidden_dropout_prob, - num_labels=len(task.LABELS), - ) - taskmodel = taskmodels.ClassificationModel( - encoder=encoder, classification_head=classification_head, - ) - elif task.TASK_TYPE == TaskTypes.REGRESSION: - assert taskmodel_kwargs is None - regression_head = heads.RegressionHead( - hidden_size=hidden_size, hidden_dropout_prob=hidden_dropout_prob, - ) - taskmodel = taskmodels.RegressionModel(encoder=encoder, regression_head=regression_head) - elif task.TASK_TYPE == TaskTypes.MULTIPLE_CHOICE: - assert taskmodel_kwargs is None - choice_scoring_head = heads.RegressionHead( - hidden_size=hidden_size, hidden_dropout_prob=hidden_dropout_prob, - ) - taskmodel = taskmodels.MultipleChoiceModel( - encoder=encoder, num_choices=task.NUM_CHOICES, choice_scoring_head=choice_scoring_head, - ) - elif task.TASK_TYPE == TaskTypes.SPAN_PREDICTION: - assert taskmodel_kwargs is None - span_prediction_head = heads.TokenClassificationHead( - hidden_size=hidden_size, - hidden_dropout_prob=encoder.config.hidden_dropout_prob, - num_labels=2, - ) - taskmodel = taskmodels.SpanPredictionModel( - encoder=encoder, span_prediction_head=span_prediction_head, - ) - elif task.TASK_TYPE == TaskTypes.SPAN_COMPARISON_CLASSIFICATION: - assert taskmodel_kwargs is None - span_comparison_head = heads.SpanComparisonHead( - hidden_size=hidden_size, - hidden_dropout_prob=hidden_dropout_prob, - num_spans=task.num_spans, - num_labels=len(task.LABELS), - ) - taskmodel = taskmodels.SpanComparisonModel( - encoder=encoder, span_comparison_head=span_comparison_head, - ) - elif task.TASK_TYPE == TaskTypes.MULTI_LABEL_SPAN_CLASSIFICATION: - assert taskmodel_kwargs is None - span_comparison_head = heads.SpanComparisonHead( - hidden_size=hidden_size, - hidden_dropout_prob=hidden_dropout_prob, - num_spans=task.num_spans, - num_labels=len(task.LABELS), - ) - taskmodel = taskmodels.MultiLabelSpanComparisonModel( - encoder=encoder, span_comparison_head=span_comparison_head, - ) - elif task.TASK_TYPE == TaskTypes.TAGGING: - assert taskmodel_kwargs is None - token_classification_head = heads.TokenClassificationHead( - hidden_size=hidden_size, - hidden_dropout_prob=hidden_dropout_prob, - num_labels=len(task.LABELS), - ) - taskmodel = taskmodels.TokenClassificationModel( - encoder=encoder, token_classification_head=token_classification_head, - ) - elif task.TASK_TYPE == TaskTypes.SQUAD_STYLE_QA: - assert taskmodel_kwargs is None - qa_head = heads.QAHead(hidden_size=hidden_size) - taskmodel = taskmodels.QAModel(encoder=encoder, qa_head=qa_head) - elif task.TASK_TYPE == TaskTypes.MASKED_LANGUAGE_MODELING: - assert taskmodel_kwargs is None - if model_arch == ModelArchitectures.BERT: - mlm_head = heads.BertMLMHead( - hidden_size=hidden_size, - vocab_size=encoder.config.vocab_size, - layer_norm_eps=encoder.config.layer_norm_eps, - hidden_act=encoder.config.hidden_act, - ) - elif model_arch == 
ModelArchitectures.ROBERTA: - mlm_head = heads.RobertaMLMHead( - hidden_size=hidden_size, - vocab_size=encoder.config.vocab_size, - layer_norm_eps=encoder.config.layer_norm_eps, - ) - elif model_arch == ModelArchitectures.ALBERT: - mlm_head = heads.AlbertMLMHead( - hidden_size=hidden_size, - embedding_size=encoder.config.embedding_size, - vocab_size=encoder.config.vocab_size, - hidden_act=encoder.config.hidden_act, - ) - elif model_arch == ModelArchitectures.XLM_ROBERTA: - mlm_head = heads.RobertaMLMHead( - hidden_size=hidden_size, - vocab_size=encoder.config.vocab_size, - layer_norm_eps=encoder.config.layer_norm_eps, - ) - elif model_arch in ( - ModelArchitectures.BART, - ModelArchitectures.MBART, - ModelArchitectures.ELECTRA, - ): - raise NotImplementedError() - else: - raise KeyError(model_arch) - taskmodel = taskmodels.MLMModel(encoder=encoder, mlm_head=mlm_head) - elif task.TASK_TYPE == TaskTypes.EMBEDDING: - if taskmodel_kwargs["pooler_type"] == "mean": - pooler_head = heads.MeanPoolerHead() - elif taskmodel_kwargs["pooler_type"] == "first": - pooler_head = heads.FirstPoolerHead() - else: - raise KeyError(taskmodel_kwargs["pooler_type"]) - taskmodel = taskmodels.EmbeddingModel( - encoder=encoder, pooler_head=pooler_head, layer=taskmodel_kwargs["layer"], - ) - else: - raise KeyError(task.TASK_TYPE) - return taskmodel - - -def get_encoder(model_arch, ancestor_model): - """From model architecture, get the encoder (encoder = embedding layer + self-attention layer). - - This function will return the "The bare Bert Model transformer outputting raw hidden-states - without any specific head on top", when provided with ModelArchitectures and BertForPreTraining - model. See Hugging Face's BertForPreTraining and BertModel documentation for more info. - - Args: - model_arch: Model architecture. - ancestor_model: Model with pretraining heads attached. - - Raises: - KeyError if ModelArchitectures + head_kwargs = {} + head_kwargs["hidden_size"] = jiant_transformers_model.get_hidden_size() + head_kwargs["hidden_dropout_prob"] = jiant_transformers_model.get_hidden_dropout_prob() + head_kwargs["vocab_size"] = jiant_transformers_model.config.vocab_size + head_kwargs["layer_norm_eps"] = jiant_transformers_model.config.layer_norm_eps + head_kwargs["hidden_act"] = jiant_transformers_model.config.hidden_act + head_kwargs["model_arch"] = ModelArchitectures.from_model_type( + jiant_transformers_model.config.model_type + ) - Returns: - Bare pretrained model outputting raw hidden-states without a specific head on top. 
+ head = JiantHeadFactory()(task, **head_kwargs) - """ - if model_arch == ModelArchitectures.BERT: - return ancestor_model.bert - elif model_arch == ModelArchitectures.ROBERTA: - return ancestor_model.roberta - elif model_arch == ModelArchitectures.ALBERT: - return ancestor_model.albert - elif model_arch == ModelArchitectures.XLM_ROBERTA: - return ancestor_model.roberta - elif model_arch in (ModelArchitectures.BART, ModelArchitectures.MBART): - return ancestor_model.model - elif model_arch == ModelArchitectures.ELECTRA: - return ancestor_model.electra - else: - raise KeyError(model_arch) + taskmodel = JiantTaskModelFactory()(task, jiant_transformers_model, head, **taskmodel_kwargs) + return taskmodel @dataclass @@ -472,45 +306,6 @@ class TransformersClassSpec: model_class: Any -TRANSFORMERS_CLASS_SPEC_DICT = { - ModelArchitectures.BERT: TransformersClassSpec( - config_class=transformers.BertConfig, - tokenizer_class=transformers.BertTokenizer, - model_class=transformers.BertForPreTraining, - ), - ModelArchitectures.ROBERTA: TransformersClassSpec( - config_class=transformers.RobertaConfig, - tokenizer_class=transformers.RobertaTokenizer, - model_class=transformers.RobertaForMaskedLM, - ), - ModelArchitectures.ALBERT: TransformersClassSpec( - config_class=transformers.AlbertConfig, - tokenizer_class=transformers.AlbertTokenizer, - model_class=transformers.AlbertForMaskedLM, - ), - ModelArchitectures.XLM_ROBERTA: TransformersClassSpec( - config_class=transformers.XLMRobertaConfig, - tokenizer_class=transformers.XLMRobertaTokenizer, - model_class=transformers.XLMRobertaForMaskedLM, - ), - ModelArchitectures.BART: TransformersClassSpec( - config_class=transformers.BartConfig, - tokenizer_class=transformers.BartTokenizer, - model_class=transformers.BartForConditionalGeneration, - ), - ModelArchitectures.MBART: TransformersClassSpec( - config_class=transformers.BartConfig, - tokenizer_class=transformers.MBartTokenizer, - model_class=transformers.BartForConditionalGeneration, - ), - ModelArchitectures.ELECTRA: TransformersClassSpec( - config_class=transformers.ElectraConfig, - tokenizer_class=transformers.ElectraTokenizer, - model_class=transformers.ElectraForPreTraining, - ), -} - - def get_taskmodel_and_task_names(task_to_taskmodel_map: Dict[str, str]) -> Dict[str, List[str]]: """Get mapping from task model name to the list of task names associated with that task model. @@ -533,17 +328,6 @@ def get_model_arch_from_jiant_model(jiant_model: nn.Module) -> ModelArchitecture return ModelArchitectures.from_encoder(encoder=jiant_model.encoder) -MODEL_PREFIX = { - ModelArchitectures.BERT: "bert", - ModelArchitectures.ROBERTA: "roberta", - ModelArchitectures.ALBERT: "albert", - ModelArchitectures.XLM_ROBERTA: "xlm-roberta", - ModelArchitectures.BART: "model", - ModelArchitectures.MBART: "model", - ModelArchitectures.ELECTRA: "electra", -} - - def get_ancestor_model(transformers_class_spec, model_config_path): """Load the model config from a file, configure the model, and return the model. 
diff --git a/jiant/proj/main/modeling/primary.py b/jiant/proj/main/modeling/primary.py index bf5f398a4..24e3e65b6 100644 --- a/jiant/proj/main/modeling/primary.py +++ b/jiant/proj/main/modeling/primary.py @@ -1,10 +1,14 @@ -from typing import Dict, Union +from typing import Dict +from typing import Union import torch.nn as nn +import transformers import jiant.proj.main.modeling.taskmodels as taskmodels import jiant.tasks as tasks + from jiant.proj.main.components.outputs import construct_output_from_dict +from jiant.shared.model_resolution import ModelArchitectures class JiantModel(nn.Module): @@ -49,7 +53,7 @@ def forward(self, batch: tasks.BatchMixin, task: tasks.Task, compute_loss: bool taskmodel_key = self.task_to_taskmodel_map[task_name] taskmodel = self.taskmodels_dict[taskmodel_key] return taskmodel( - batch=batch, task=task, tokenizer=self.tokenizer, compute_loss=compute_loss, + batch=batch, tokenizer=self.tokenizer, compute_loss=compute_loss, ).to_dict() @@ -85,3 +89,31 @@ def wrap_jiant_forward( if is_multi_gpu and compute_loss: model_output.loss = model_output.loss.mean() return model_output + + +class JiantTransformersModelFactory: + @staticmethod + def create_model(hf_model): + model_arch = ModelArchitectures.from_model_type(hf_model.config.model_type) + if model_arch == ModelArchitectures.BERT: + return JiantBERTModelWrapper(hf_model) + + +class JiantTransformersModelWrapper: + def __init__(self, baseObject): + self.__class__ = type( + baseObject.__class__.__name__, (self.__class__, baseObject.__class__), {} + ) + self.__dict__ = baseObject.__dict__ + + +class JiantBERTModelWrapper(JiantTransformersModelWrapper): + def __init__(self, baseObject): + super().__init__(baseObject) + self.hf_pretrained_encoder_with_pretrained_head = transformers.BertForPreTraining + + def get_hidden_size(self): + return self.config.hidden_size + + def get_hidden_dropout_prob(self): + return self.config.hidden_dropout_prob diff --git a/jiant/proj/main/modeling/taskmodels.py b/jiant/proj/main/modeling/taskmodels.py index 53d5e70c7..569fbe5ee 100644 --- a/jiant/proj/main/modeling/taskmodels.py +++ b/jiant/proj/main/modeling/taskmodels.py @@ -1,53 +1,78 @@ import abc + from dataclasses import dataclass from typing import Any +from typing import Callable import torch import torch.nn as nn import jiant.proj.main.modeling.heads as heads import jiant.utils.transformer_utils as transformer_utils -from jiant.proj.main.components.outputs import LogitsOutput, LogitsAndLossOutput -from jiant.utils.python.datastructures import take_one + +from jiant.proj.main.components.outputs import LogitsAndLossOutput +from jiant.proj.main.components.outputs import LogitsOutput from jiant.shared.model_resolution import ModelArchitectures +from jiant.tasks.core import TaskTypes +from jiant.utils.python.datastructures import take_one + + +class JiantTaskModelFactory: + + # Internal registry for available task heads + registry = {} + + @classmethod + def register(cls, task_type: TaskTypes) -> Callable: + def inner_wrapper(wrapped_class: Taskmodel) -> Callable: + assert task_type not in cls.registry + cls.registry[task_type] = wrapped_class + return wrapped_class + + return inner_wrapper + + def __call__(cls, task, jiant_transformers_model, head, **kwargs): + taskmodel_class = cls.registry[task.TASK_TYPE] + taskmodel = taskmodel_class(task, jiant_transformers_model, head, **kwargs) + return taskmodel class Taskmodel(nn.Module, metaclass=abc.ABCMeta): - def __init__(self, encoder): + def __init__(self, task, encoder, head): 
super().__init__() + self.task = task self.encoder = encoder + self.head = head - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): raise NotImplementedError +@JiantTaskModelFactory.register(TaskTypes.CLASSIFICATION) class ClassificationModel(Taskmodel): - def __init__(self, encoder, classification_head: heads.ClassificationHead): - super().__init__(encoder=encoder) - self.classification_head = classification_head + def __init__(self, task, encoder, head: heads.ClassificationHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) - logits = self.classification_head(pooled=encoder_output.pooled) + logits = self.head(pooled=encoder_output.pooled) if compute_loss: loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - logits.view(-1, self.classification_head.num_labels), batch.label_id.view(-1), - ) + loss = loss_fct(logits.view(-1, self.head.num_labels), batch.label_id.view(-1),) return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other) else: return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.REGRESSION) class RegressionModel(Taskmodel): - def __init__(self, encoder, regression_head: heads.RegressionHead): - super().__init__(encoder=encoder) - self.regression_head = regression_head + def __init__(self, task, encoder, head: heads.RegressionHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) # TODO: Abuse of notation - these aren't really logits (issue #1187) - logits = self.regression_head(pooled=encoder_output.pooled) + logits = self.head(pooled=encoder_output.pooled) if compute_loss: loss_fct = nn.MSELoss() loss = loss_fct(logits.view(-1), batch.label.view(-1)) @@ -56,13 +81,13 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.MULTIPLE_CHOICE) class MultipleChoiceModel(Taskmodel): - def __init__(self, encoder, num_choices: int, choice_scoring_head: heads.RegressionHead): - super().__init__(encoder=encoder) - self.num_choices = num_choices - self.choice_scoring_head = choice_scoring_head + def __init__(self, task, encoder, head: heads.RegressionHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) + self.num_choices = task.NUM_CHOICES - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): input_ids = batch.input_ids segment_ids = batch.segment_ids input_mask = batch.input_mask @@ -76,7 +101,7 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): segment_ids=segment_ids[:, i], input_mask=input_mask[:, i], ) - choice_score = self.choice_scoring_head(pooled=encoder_output.pooled) + choice_score = self.head(pooled=encoder_output.pooled) choice_score_list.append(choice_score) encoder_output_other_ls.append(encoder_output.other) @@ -103,36 +128,34 @@ def forward(self, batch, task, tokenizer, 
compute_loss: bool = False): return LogitsOutput(logits=logits, other=reshaped_outputs) +@JiantTaskModelFactory.register(TaskTypes.SPAN_COMPARISON_CLASSIFICATION) class SpanComparisonModel(Taskmodel): - def __init__(self, encoder, span_comparison_head: heads.SpanComparisonHead): - super().__init__(encoder=encoder) - self.span_comparison_head = span_comparison_head + def __init__(self, task, encoder, head: heads.SpanComparisonHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) - logits = self.span_comparison_head(unpooled=encoder_output.unpooled, spans=batch.spans) + logits = self.head(unpooled=encoder_output.unpooled, spans=batch.spans) if compute_loss: loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - logits.view(-1, self.span_comparison_head.num_labels), batch.label_id.view(-1), - ) + loss = loss_fct(logits.view(-1, self.head.num_labels), batch.label_id.view(-1),) return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other) else: return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.SPAN_PREDICTION) class SpanPredictionModel(Taskmodel): - def __init__(self, encoder, span_prediction_head: heads.TokenClassificationHead): - super().__init__(encoder=encoder) + def __init__(self, task, encoder, head: heads.TokenClassificationHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) self.offset_margin = 1000 # 1000 is a big enough number that exp(-1000) will be strict 0 in float32. # So that if we add 1000 to the valid dimensions in the input of softmax, # we can guarantee the output distribution will only be non-zero at those dimensions. 
- self.span_prediction_head = span_prediction_head - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) - logits = self.span_prediction_head(unpooled=encoder_output.unpooled) + logits = self.head(unpooled=encoder_output.unpooled) # Ensure logits in valid range is at least self.offset_margin higher than others logits_offset = logits.max() - logits.min() + self.offset_margin logits = logits + logits_offset * batch.selection_token_mask.unsqueeze(dim=2) @@ -146,38 +169,36 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.MULTI_LABEL_SPAN_CLASSIFICATION) class MultiLabelSpanComparisonModel(Taskmodel): - def __init__(self, encoder, span_comparison_head: heads.SpanComparisonHead): - super().__init__(encoder=encoder) - self.span_comparison_head = span_comparison_head + def __init__(self, task, encoder, head: heads.SpanComparisonHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) - logits = self.span_comparison_head(unpooled=encoder_output.unpooled, spans=batch.spans) + logits = self.head(unpooled=encoder_output.unpooled, spans=batch.spans) if compute_loss: loss_fct = nn.BCEWithLogitsLoss() - loss = loss_fct( - logits.view(-1, self.span_comparison_head.num_labels), batch.label_ids.float(), - ) + loss = loss_fct(logits.view(-1, self.head.num_labels), batch.label_ids.float(),) return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other) else: return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.TAGGING) class TokenClassificationModel(Taskmodel): """From RobertaForTokenClassification""" - def __init__(self, encoder, token_classification_head: heads.TokenClassificationHead): - super().__init__(encoder=encoder) - self.token_classification_head = token_classification_head + def __init__(self, task, encoder, head: heads.TokenClassificationHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) - logits = self.token_classification_head(unpooled=encoder_output.unpooled) + logits = self.head(unpooled=encoder_output.unpooled) if compute_loss: loss_fct = nn.CrossEntropyLoss() active_loss = batch.label_mask.view(-1) == 1 - active_logits = logits.view(-1, self.token_classification_head.num_labels)[active_loss] + active_logits = logits.view(-1, self.head.num_labels)[active_loss] active_labels = batch.label_ids.view(-1)[active_loss] loss = loss_fct(active_logits, active_labels) return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other) @@ -185,14 +206,14 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.SQUAD_STYLE_QA) class QAModel(Taskmodel): - def __init__(self, encoder, qa_head: heads.QAHead): - 
super().__init__(encoder=encoder) - self.qa_head = qa_head + def __init__(self, task, encoder, head: heads.QAHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) - logits = self.qa_head(unpooled=encoder_output.unpooled) + logits = self.head(unpooled=encoder_output.unpooled) if compute_loss: loss = compute_qa_loss( logits=logits, @@ -204,14 +225,16 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.MASKED_LANGUAGE_MODELING) class MLMModel(Taskmodel): - def __init__(self, encoder, mlm_head: heads.BaseMLMHead): - super().__init__(encoder=encoder) - self.mlm_head = mlm_head + def __init__(self, task, encoder, head: heads.BaseMLMHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): masked_batch = batch.get_masked( - mlm_probability=task.mlm_probability, tokenizer=tokenizer, do_mask=task.do_mask, + mlm_probability=self.task.mlm_probability, + tokenizer=tokenizer, + do_mask=self.task.do_mask, ) encoder_output = get_output_from_encoder( encoder=self.encoder, @@ -219,7 +242,7 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): segment_ids=masked_batch.segment_ids, input_mask=masked_batch.input_mask, ) - logits = self.mlm_head(unpooled=encoder_output.unpooled) + logits = self.head(unpooled=encoder_output.unpooled) if compute_loss: loss = compute_mlm_loss(logits=logits, masked_lm_labels=masked_batch.masked_lm_labels) return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other) @@ -227,25 +250,25 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.EMBEDDING) class EmbeddingModel(Taskmodel): - def __init__(self, encoder, pooler_head: heads.AbstractPoolerHead, layer): - super().__init__(encoder=encoder) - self.pooler_head = pooler_head - self.layer = layer + def __init__(self, task, encoder, head: heads.AbstractPoolerHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) + self.layer = kwargs["layer"] - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): with transformer_utils.output_hidden_states_context(self.encoder): encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) # A tuple of layers of hidden states hidden_states = take_one(encoder_output.other) layer_hidden_states = hidden_states[self.layer] - if isinstance(self.pooler_head, heads.MeanPoolerHead): - logits = self.pooler_head(unpooled=layer_hidden_states, input_mask=batch.input_mask) - elif isinstance(self.pooler_head, heads.FirstPoolerHead): - logits = self.pooler_head(layer_hidden_states) + if isinstance(self.head, heads.MeanPoolerHead): + logits = self.head(unpooled=layer_hidden_states, input_mask=batch.input_mask) + elif isinstance(self.head, heads.FirstPoolerHead): + logits = self.head(layer_hidden_states) else: - raise TypeError(type(self.pooler_head)) + raise TypeError(type(self.head)) # TODO: Abuse of 
notation - these aren't really logits (issue #1187) if compute_loss: diff --git a/jiant/shared/model_resolution.py b/jiant/shared/model_resolution.py index 7ec446137..5470d9f57 100644 --- a/jiant/shared/model_resolution.py +++ b/jiant/shared/model_resolution.py @@ -7,14 +7,14 @@ class ModelArchitectures(Enum): - BERT = 1 - XLM = 2 - ROBERTA = 3 - ALBERT = 4 - XLM_ROBERTA = 5 - BART = 6 - MBART = 7 - ELECTRA = 8 + BERT = "bert" + XLM = "xlm" + ROBERTA = "roberta" + ALBERT = "albert" + XLM_ROBERTA = "xlm-roberta" + BART = "bart" + MBART = "mbart" + ELECTRA = "electra" @classmethod def from_model_type(cls, model_type: str): @@ -27,86 +27,25 @@ def from_model_type(cls, model_type: str): Model architecture associated with the provided shortcut name. """ - if model_type.startswith("bert"): + if model_type == "bert": return cls.BERT - elif model_type.startswith("xlm") and not model_type.startswith("xlm-roberta"): + elif model_type == "xlm": return cls.XLM - elif model_type.startswith("roberta"): + elif model_type == "roberta": return cls.ROBERTA - elif model_type.startswith("albert"): + elif model_type == "albert": return cls.ALBERT - elif model_type == "glove_lstm": - return cls.GLOVE_LSTM - elif model_type.startswith("xlm-roberta"): + elif model_type == "xlm-roberta": return cls.XLM_ROBERTA - elif model_type.startswith("bart"): + elif model_type == "bart": return cls.BART - elif model_type.startswith("mbart"): + elif model_type == "mbart": return cls.MBART - elif model_type.startswith("electra"): + elif model_type == "electra": return cls.ELECTRA else: raise KeyError(model_type) - @classmethod - def from_transformers_model(cls, transformers_model): - if isinstance( - transformers_model, transformers.BertPreTrainedModel - ) and transformers_model.__class__.__name__.startswith("Bert"): - return cls.BERT - elif isinstance(transformers_model, transformers.XLMPreTrainedModel): - return cls.XLM - elif isinstance( - transformers_model, transformers.BertPreTrainedModel - ) and transformers_model.__class__.__name__.startswith("Robert"): - return cls.ROBERTA - elif isinstance( - transformers_model, transformers.BertPreTrainedModel - ) and transformers_model.__class__.__name__.startswith("XLMRoberta"): - return cls.XLM_ROBERTA - elif isinstance(transformers_model, transformers.modeling_albert.AlbertPreTrainedModel): - return cls.ALBERT - elif isinstance(transformers_model, transformers.modeling_bart.PretrainedBartModel): - return bart_or_mbart_model_heuristic(model_config=transformers_model.config) - elif isinstance(transformers_model, transformers.modeling_electra.ElectraPreTrainedModel): - return cls.ELECTRA - else: - raise KeyError(str(transformers_model)) - - @classmethod - def from_tokenizer_class(cls, tokenizer_class): - if isinstance(tokenizer_class, transformers.BertTokenizer): - return cls.BERT - elif isinstance(tokenizer_class, transformers.XLMTokenizer): - return cls.XLM - elif isinstance(tokenizer_class, transformers.RobertaTokenizer): - return cls.ROBERTA - elif isinstance(tokenizer_class, transformers.XLMRobertaTokenizer): - return cls.XLM_ROBERTA - elif isinstance(tokenizer_class, transformers.AlbertTokenizer): - return cls.ALBERT - elif isinstance(tokenizer_class, transformers.BartTokenizer): - return cls.BART - elif isinstance(tokenizer_class, transformers.MBartTokenizer): - return cls.MBART - elif isinstance(tokenizer_class, transformers.ElectraTokenizer): - return cls.ELECTRA - else: - raise KeyError(str(tokenizer_class)) - - @classmethod - def is_transformers_model_arch(cls, 
model_arch): - return model_arch in [ - cls.BERT, - cls.XLM, - cls.ROBERTA, - cls.ALBERT, - cls.XLM_ROBERTA, - cls.BART, - cls.MBART, - cls.ELECTRA, - ] - @classmethod def from_encoder(cls, encoder): if ( diff --git a/jiant/tasks/core.py b/jiant/tasks/core.py index 59471f163..85cc0e0e4 100644 --- a/jiant/tasks/core.py +++ b/jiant/tasks/core.py @@ -85,18 +85,18 @@ def data_row_collate_fn(batch): class TaskTypes(Enum): - CLASSIFICATION = 1 - REGRESSION = 2 - SPAN_COMPARISON_CLASSIFICATION = 3 - MULTIPLE_CHOICE = 4 - SPAN_CHOICE_PROB_TASK = 5 - SQUAD_STYLE_QA = 6 - TAGGING = 7 - MASKED_LANGUAGE_MODELING = 8 - EMBEDDING = 9 - MULTI_LABEL_SPAN_CLASSIFICATION = 10 - SPAN_PREDICTION = 11 - UNDEFINED = -1 + CLASSIFICATION = "classification" + REGRESSION = "regression" + SPAN_COMPARISON_CLASSIFICATION = "span_comparison_classification" + MULTIPLE_CHOICE = "multiple_choice" + SPAN_CHOICE_PROB_TASK = "span_choice_prob_task" + SQUAD_STYLE_QA = "squad_style_qa" + TAGGING = "tagging" + MASKED_LANGUAGE_MODELING = "masked_language_modeling" + EMBEDDING = "embedding" + MULTI_LABEL_SPAN_CLASSIFICATION = "multi_label_span_classification" + SPAN_PREDICTION = "span_prediction" + UNDEFINED = "undefined" class BatchTuple(NamedTuple): diff --git a/tests/tasks/lib/test_mlm_premasked.py b/tests/tasks/lib/test_mlm_premasked.py index 17a5dd0fe..48d25405f 100644 --- a/tests/tasks/lib/test_mlm_premasked.py +++ b/tests/tasks/lib/test_mlm_premasked.py @@ -35,7 +35,7 @@ def test_tokenization_and_featurization(): data_row = tokenized_example.featurize( tokenizer=tokenizer, feat_spec=model_resolution.build_featurization_spec( - model_type="roberta-base", max_seq_length=16, + model_type="roberta", max_seq_length=16, ), ) assert list(data_row.masked_input_ids) == [ diff --git a/tests/tasks/lib/test_mlm_pretokenized.py b/tests/tasks/lib/test_mlm_pretokenized.py index 2b682f234..4f52ea361 100644 --- a/tests/tasks/lib/test_mlm_pretokenized.py +++ b/tests/tasks/lib/test_mlm_pretokenized.py @@ -37,7 +37,7 @@ def test_tokenization_and_featurization(): data_row = tokenized_example.featurize( tokenizer=tokenizer, feat_spec=model_resolution.build_featurization_spec( - model_type="roberta-base", max_seq_length=16, + model_type="roberta", max_seq_length=16, ), ) assert list(data_row.masked_input_ids) == [ diff --git a/tests/tasks/lib/test_mnli.py b/tests/tasks/lib/test_mnli.py index 2b564e41d..e5d7a32dd 100644 --- a/tests/tasks/lib/test_mnli.py +++ b/tests/tasks/lib/test_mnli.py @@ -301,7 +301,7 @@ def test_featurization_of_task_data(): tokenized_examples[0].hypothesis ) feat_spec = model_resolution.build_featurization_spec( - model_type="bert-", max_seq_length=train_example_0_length + model_type="bert", max_seq_length=train_example_0_length ) featurized_examples = [ tokenized_example.featurize(tokenizer=tokenizer, feat_spec=feat_spec) diff --git a/tests/tasks/lib/test_spr1.py b/tests/tasks/lib/test_spr1.py index fb8aefb2a..d3c394a58 100644 --- a/tests/tasks/lib/test_spr1.py +++ b/tests/tasks/lib/test_spr1.py @@ -315,7 +315,7 @@ def test_featurization_of_task_data(): # Testing conversion of a tokenized example to a featurized example train_example_0_length = len(tokenized_examples[0].tokens) + 4 feat_spec = model_resolution.build_featurization_spec( - model_type="bert-", max_seq_length=train_example_0_length + model_type="bert", max_seq_length=train_example_0_length ) featurized_examples = [ tokenized_example.featurize(tokenizer=tokenizer, feat_spec=feat_spec)
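
Note (illustrative, not part of the patch): the diff above replaces the long if/elif dispatch in create_taskmodel with decorator-based registries (JiantHeadFactory, JiantMLMHeadFactory, JiantTaskModelFactory, JiantTransformersModelFactory). The minimal, self-contained Python sketch below shows that registry pattern in isolation; HeadFactory, ToyClassificationTask, and the label list are hypothetical stand-ins, not names from jiant.

from enum import Enum
from typing import Callable, Dict, List, Type


class TaskTypes(Enum):
    # Mirrors the string-valued enum this patch introduces in jiant/tasks/core.py
    CLASSIFICATION = "classification"
    REGRESSION = "regression"


class HeadFactory:
    # Class-level registry keyed by TaskTypes, analogous to JiantHeadFactory.registry
    registry: Dict[TaskTypes, Type] = {}

    @classmethod
    def register(cls, task_type_list: List[TaskTypes]) -> Callable:
        # Decorator: map each listed task type to the decorated head class
        def inner_wrapper(wrapped_class):
            for task_type in task_type_list:
                assert task_type not in cls.registry
                cls.registry[task_type] = wrapped_class
            return wrapped_class

        return inner_wrapper

    def __call__(self, task, **kwargs):
        # Dispatch on the task's TASK_TYPE instead of an if/elif chain
        head_class = self.registry[task.TASK_TYPE]
        return head_class(task, **kwargs)


@HeadFactory.register([TaskTypes.CLASSIFICATION])
class ClassificationHead:
    def __init__(self, task, hidden_size, **kwargs):
        self.num_labels = len(task.LABELS)
        self.hidden_size = hidden_size


class ToyClassificationTask:
    TASK_TYPE = TaskTypes.CLASSIFICATION
    LABELS = ["entailment", "neutral", "contradiction"]


head = HeadFactory()(ToyClassificationTask(), hidden_size=768)
print(type(head).__name__, head.num_labels)  # -> ClassificationHead 3

Registering classes against an enum key keeps construction logic in one place: supporting a new task type only requires defining the head or taskmodel class and decorating it, which is what the patch does for each existing head, taskmodel, and MLM head.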