From e004af87be87422d53d7360cfe30da2bafca65e7 Mon Sep 17 00:00:00 2001 From: jeswan <57466294+jeswan@users.noreply.github.com> Date: Thu, 25 Feb 2021 10:25:14 -0500 Subject: [PATCH] Switch to task model/head factories instead of embedded if-else statements (#1268) * Use jiant transformers model wrapper instead of if-else. Use taskmodel and head factory instead of if-else. * switch to ModelArchitectures enum instead of strings --- jiant/proj/main/modeling/heads.py | 141 ++++++++++-- jiant/proj/main/modeling/model_setup.py | 267 ++--------------------- jiant/proj/main/modeling/primary.py | 126 ++++++++++- jiant/proj/main/modeling/taskmodels.py | 189 ++++++++++------ jiant/shared/model_resolution.py | 111 +--------- jiant/tasks/core.py | 24 +- tests/tasks/lib/test_mlm_premasked.py | 2 +- tests/tasks/lib/test_mlm_pretokenized.py | 4 +- tests/tasks/lib/test_mnli.py | 6 +- tests/tasks/lib/test_spr1.py | 2 +- 10 files changed, 427 insertions(+), 445 deletions(-) diff --git a/jiant/proj/main/modeling/heads.py b/jiant/proj/main/modeling/heads.py index b68588282..fca8a7e28 100644 --- a/jiant/proj/main/modeling/heads.py +++ b/jiant/proj/main/modeling/heads.py @@ -1,10 +1,17 @@ +from __future__ import annotations + import abc import torch import torch.nn as nn import transformers + from jiant.ext.allennlp import SelfAttentiveSpanExtractor +from jiant.shared.model_resolution import ModelArchitectures +from jiant.tasks.core import TaskTypes +from typing import Callable +from typing import List """ In HuggingFace/others, these heads differ slightly across different encoder models. @@ -12,18 +19,75 @@ """ +class JiantHeadFactory: + """This factory is used to create task-specific heads for the supported Transformer encoders. + + Attributes: + registry (dict): Dynamic registry mapping task types to task heads + """ + + registry = {} + + @classmethod + def register(cls, task_type_list: List[TaskTypes]) -> Callable: + """Register each TaskType in task_type_list as a key mapping to a BaseHead task head + + Args: + task_type_list (List[TaskType]): List of TaskTypes that are associated to a + BaseHead task head + + Returns: + Callable: inner_wrapper() wrapping task head constructor or task head factory + """ + + def inner_wrapper(wrapped_class: BaseHead) -> Callable: + """Summary + + Args: + wrapped_class (BaseHead): Task head class + + Returns: + Callable: Task head constructor or factory + """ + for task_type in task_type_list: + assert task_type not in cls.registry + cls.registry[task_type] = wrapped_class + return wrapped_class + + return inner_wrapper + + def __call__(self, task, **kwargs) -> BaseHead: + """Summary + + Args: + task (Task): A task head will be created based on the task type + **kwargs: Arguments required for task head initialization + + Returns: + BaseHead: Initialized task head + """ + head_class = self.registry[task.TASK_TYPE] + head = head_class(task, **kwargs) + return head + + class BaseHead(nn.Module, metaclass=abc.ABCMeta): - pass + """Absract class for task heads""" + + @abc.abstractmethod + def __init__(self): + super().__init__() +@JiantHeadFactory.register([TaskTypes.CLASSIFICATION]) class ClassificationHead(BaseHead): - def __init__(self, hidden_size, hidden_dropout_prob, num_labels): + def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs): """From RobertaClassificationHead""" super().__init__() self.dense = nn.Linear(hidden_size, hidden_size) self.dropout = nn.Dropout(hidden_dropout_prob) - self.out_proj = nn.Linear(hidden_size, num_labels) - self.num_labels = 
num_labels + self.out_proj = nn.Linear(hidden_size, task.num_labels) + self.num_labels = len(task.LABELS) def forward(self, pooled): x = self.dropout(pooled) @@ -34,8 +98,9 @@ def forward(self, pooled): return logits +@JiantHeadFactory.register([TaskTypes.REGRESSION, TaskTypes.MULTIPLE_CHOICE]) class RegressionHead(BaseHead): - def __init__(self, hidden_size, hidden_dropout_prob): + def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs): """From RobertaClassificationHead""" super().__init__() self.dense = nn.Linear(hidden_size, hidden_size) @@ -51,12 +116,13 @@ def forward(self, pooled): return scores +@JiantHeadFactory.register([TaskTypes.SPAN_COMPARISON_CLASSIFICATION]) class SpanComparisonHead(BaseHead): - def __init__(self, hidden_size, hidden_dropout_prob, num_spans, num_labels): + def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs): """From RobertaForSpanComparisonClassification""" super().__init__() - self.num_spans = num_spans - self.num_labels = num_labels + self.num_spans = task.num_spans + self.num_labels = len(task.LABELS) self.hidden_size = hidden_size self.dropout = nn.Dropout(hidden_dropout_prob) self.span_attention_extractor = SelfAttentiveSpanExtractor(hidden_size) @@ -70,13 +136,14 @@ def forward(self, unpooled, spans): return logits +@JiantHeadFactory.register([TaskTypes.TAGGING]) class TokenClassificationHead(BaseHead): - def __init__(self, hidden_size, num_labels, hidden_dropout_prob): + def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs): """From RobertaForTokenClassification""" super().__init__() - self.num_labels = num_labels + self.num_labels = len(task.LABELS) self.dropout = nn.Dropout(hidden_dropout_prob) - self.classifier = nn.Linear(hidden_size, num_labels) + self.classifier = nn.Linear(hidden_size, self.num_labels) def forward(self, unpooled): unpooled = self.dropout(unpooled) @@ -84,8 +151,9 @@ def forward(self, unpooled): return logits +@JiantHeadFactory.register([TaskTypes.SQUAD_STYLE_QA]) class QAHead(BaseHead): - def __init__(self, hidden_size): + def __init__(self, task, hidden_size, **kwargs): """From RobertaForQuestionAnswering""" super().__init__() self.qa_outputs = nn.Linear(hidden_size, 2) @@ -98,10 +166,55 @@ def forward(self, unpooled): return logits +@JiantHeadFactory.register([TaskTypes.MASKED_LANGUAGE_MODELING]) +class JiantMLMHeadFactory: + """This factory is used to create masked language modeling (MLM) task heads. + This is required due to Transformers implementing different MLM heads for + different encoders. + + Attributes: + registry (dict): Dynamic registry mapping model architectures to MLM task heads + """ + + registry = {} + + @classmethod + def register(cls, model_arch_list: List[ModelArchitectures]) -> Callable: + """Registers the ModelArchitectures in model_arch_list as keys mapping to a MLMHead + + Args: + model_arch_list (List[ModelArchitectures]): List of ModelArchitectures mapping to + an MLM task head. 
+ + Returns: + Callable: MLMHead class + """ + + def inner_wrapper(wrapped_class: BaseMLMHead) -> Callable: + for model_arch in model_arch_list: + assert model_arch not in cls.registry + cls.registry[model_arch] = wrapped_class + return wrapped_class + + return inner_wrapper + + def __call__(self, task, **kwargs): + """Summary + + Args: + task (Task): Task used to initialize task head + **kwargs: Additional arguments required to initialize task head + """ + mlm_head_class = self.registry[task.TASK_TYPE] + mlm_head = mlm_head_class(task, **kwargs) + return mlm_head + + class BaseMLMHead(BaseHead, metaclass=abc.ABCMeta): pass +@JiantMLMHeadFactory.register([ModelArchitectures.BERT]) class BertMLMHead(BaseMLMHead): """From BertOnlyMLMHead, BertLMPredictionHead, BertPredictionHeadTransform""" @@ -128,6 +241,7 @@ def forward(self, unpooled): return logits +@JiantMLMHeadFactory.register([ModelArchitectures.ROBERTA, ModelArchitectures.XLM_ROBERTA]) class RobertaMLMHead(BaseMLMHead): """From RobertaLMHead""" @@ -155,7 +269,8 @@ def forward(self, unpooled): return logits -class AlbertMLMHead(nn.Module): +@JiantMLMHeadFactory.register([ModelArchitectures.ALBERT]) +class AlbertMLMHead(BaseMLMHead): """From AlbertMLMHead""" def __init__(self, hidden_size, embedding_size, vocab_size, hidden_act="gelu"): diff --git a/jiant/proj/main/modeling/model_setup.py b/jiant/proj/main/modeling/model_setup.py index 6ff546e72..f176e56cf 100644 --- a/jiant/proj/main/modeling/model_setup.py +++ b/jiant/proj/main/modeling/model_setup.py @@ -2,7 +2,6 @@ from typing import Any from typing import Dict from typing import List -from typing import Optional import torch import torch.nn as nn @@ -10,14 +9,14 @@ import jiant.proj.main.components.container_setup as container_setup -import jiant.proj.main.modeling.heads as heads import jiant.proj.main.modeling.primary as primary -import jiant.proj.main.modeling.taskmodels as taskmodels import jiant.utils.python.strings as strings +from jiant.proj.main.modeling.heads import JiantHeadFactory +from jiant.proj.main.modeling.taskmodels import JiantTaskModelFactory, Taskmodel, MLMModel + from jiant.shared.model_resolution import ModelArchitectures from jiant.tasks import Task -from jiant.tasks import TaskTypes def setup_jiant_model( @@ -52,19 +51,13 @@ def setup_jiant_model( JiantModel nn.Module. 
""" - model = transformers.AutoModel.from_pretrained(hf_pretrained_model_name_or_path) - model_arch = ModelArchitectures.from_model_type(model.base_model_prefix) - transformers_class_spec = TRANSFORMERS_CLASS_SPEC_DICT[model_arch] + hf_model = transformers.AutoModel.from_pretrained(hf_pretrained_model_name_or_path) tokenizer = transformers.AutoTokenizer.from_pretrained(hf_pretrained_model_name_or_path) - ancestor_model = get_ancestor_model( - transformers_class_spec=transformers_class_spec, model_config_path=model_config_path, - ) - encoder = get_encoder(model_arch=model_arch, ancestor_model=ancestor_model) + jiant_transformers_model = primary.JiantTransformersModelFactory()(hf_model) taskmodels_dict = { taskmodel_name: create_taskmodel( task=task_dict[task_name_list[0]], # Take the first task - model_arch=model_arch, - encoder=encoder, + jiant_transformers_model=jiant_transformers_model, taskmodel_kwargs=taskmodels_config.get_taskmodel_kwargs(taskmodel_name), ) for taskmodel_name, task_name_list in get_taskmodel_and_task_names( @@ -73,7 +66,7 @@ def setup_jiant_model( } return primary.JiantModel( task_dict=task_dict, - encoder=encoder, + encoder=jiant_transformers_model, taskmodels_dict=taskmodels_dict, task_to_taskmodel_map=taskmodels_config.task_to_taskmodel_map, tokenizer=tokenizer, @@ -162,7 +155,7 @@ def load_encoder_from_transformers_weights( remainder_weights_dict = {} load_weights_dict = {} model_arch = ModelArchitectures.from_encoder(encoder=encoder) - encoder_prefix = MODEL_PREFIX[model_arch] + "." + encoder_prefix = model_arch.value + "." # Encoder for k, v in weights_dict.items(): if k.startswith(encoder_prefix): @@ -200,7 +193,7 @@ def load_lm_heads_from_transformers_weights(jiant_model, weights_dict): raise KeyError(model_arch) missed = set() for taskmodel_name, taskmodel in jiant_model.taskmodels_dict.items(): - if not isinstance(taskmodel, taskmodels.MLMModel): + if not isinstance(taskmodel, MLMModel): continue mismatch = taskmodel.mlm_head.load_state_dict(mlm_weights_dict) assert not mismatch.missing_keys @@ -273,196 +266,34 @@ def load_partial_heads( return result -def create_taskmodel( - task, model_arch, encoder, taskmodel_kwargs: Optional[Dict] = None -) -> taskmodels.Taskmodel: +def create_taskmodel(task, jiant_transformers_model, **taskmodel_kwargs) -> Taskmodel: """Creates, initializes and returns the task model for a given task type and encoder. Args: task (Task): Task object associated with the taskmodel being created. - model_arch (ModelArchitectures.Any): Model architecture (e.g., ModelArchitectures.BERT). - encoder (PreTrainedModel): Transformer w/o heads (embedding layer + self-attention layer). - taskmodel_kwargs (Optional[Dict]): map containing any kwargs needed for taskmodel setup. + jiant_transformers_model (JiantTransformersModel): Transformer w/o heads + (embedding layer + self-attention layer). + **taskmodel_kwargs: Additional args for taskmodel setup Raises: KeyError if task does not have valid TASK_TYPE. Returns: - Taskmodel (e.g., ClassificationModel) appropriate for the task type and encoder. 
+ Taskmodel """ - if model_arch in [ - ModelArchitectures.BERT, - ModelArchitectures.ROBERTA, - ModelArchitectures.ALBERT, - ModelArchitectures.XLM_ROBERTA, - ModelArchitectures.ELECTRA, - ]: - hidden_size = encoder.config.hidden_size - hidden_dropout_prob = encoder.config.hidden_dropout_prob - elif model_arch in [ - ModelArchitectures.BART, - ModelArchitectures.MBART, - ]: - hidden_size = encoder.config.d_model - hidden_dropout_prob = encoder.config.dropout - else: - raise KeyError() - - if task.TASK_TYPE == TaskTypes.CLASSIFICATION: - assert taskmodel_kwargs is None - classification_head = heads.ClassificationHead( - hidden_size=hidden_size, - hidden_dropout_prob=hidden_dropout_prob, - num_labels=len(task.LABELS), - ) - taskmodel = taskmodels.ClassificationModel( - encoder=encoder, classification_head=classification_head, - ) - elif task.TASK_TYPE == TaskTypes.REGRESSION: - assert taskmodel_kwargs is None - regression_head = heads.RegressionHead( - hidden_size=hidden_size, hidden_dropout_prob=hidden_dropout_prob, - ) - taskmodel = taskmodels.RegressionModel(encoder=encoder, regression_head=regression_head) - elif task.TASK_TYPE == TaskTypes.MULTIPLE_CHOICE: - assert taskmodel_kwargs is None - choice_scoring_head = heads.RegressionHead( - hidden_size=hidden_size, hidden_dropout_prob=hidden_dropout_prob, - ) - taskmodel = taskmodels.MultipleChoiceModel( - encoder=encoder, num_choices=task.NUM_CHOICES, choice_scoring_head=choice_scoring_head, - ) - elif task.TASK_TYPE == TaskTypes.SPAN_PREDICTION: - assert taskmodel_kwargs is None - span_prediction_head = heads.TokenClassificationHead( - hidden_size=hidden_size, - hidden_dropout_prob=encoder.config.hidden_dropout_prob, - num_labels=2, - ) - taskmodel = taskmodels.SpanPredictionModel( - encoder=encoder, span_prediction_head=span_prediction_head, - ) - elif task.TASK_TYPE == TaskTypes.SPAN_COMPARISON_CLASSIFICATION: - assert taskmodel_kwargs is None - span_comparison_head = heads.SpanComparisonHead( - hidden_size=hidden_size, - hidden_dropout_prob=hidden_dropout_prob, - num_spans=task.num_spans, - num_labels=len(task.LABELS), - ) - taskmodel = taskmodels.SpanComparisonModel( - encoder=encoder, span_comparison_head=span_comparison_head, - ) - elif task.TASK_TYPE == TaskTypes.MULTI_LABEL_SPAN_CLASSIFICATION: - assert taskmodel_kwargs is None - span_comparison_head = heads.SpanComparisonHead( - hidden_size=hidden_size, - hidden_dropout_prob=hidden_dropout_prob, - num_spans=task.num_spans, - num_labels=len(task.LABELS), - ) - taskmodel = taskmodels.MultiLabelSpanComparisonModel( - encoder=encoder, span_comparison_head=span_comparison_head, - ) - elif task.TASK_TYPE == TaskTypes.TAGGING: - assert taskmodel_kwargs is None - token_classification_head = heads.TokenClassificationHead( - hidden_size=hidden_size, - hidden_dropout_prob=hidden_dropout_prob, - num_labels=len(task.LABELS), - ) - taskmodel = taskmodels.TokenClassificationModel( - encoder=encoder, token_classification_head=token_classification_head, - ) - elif task.TASK_TYPE == TaskTypes.SQUAD_STYLE_QA: - assert taskmodel_kwargs is None - qa_head = heads.QAHead(hidden_size=hidden_size) - taskmodel = taskmodels.QAModel(encoder=encoder, qa_head=qa_head) - elif task.TASK_TYPE == TaskTypes.MASKED_LANGUAGE_MODELING: - assert taskmodel_kwargs is None - if model_arch == ModelArchitectures.BERT: - mlm_head = heads.BertMLMHead( - hidden_size=hidden_size, - vocab_size=encoder.config.vocab_size, - layer_norm_eps=encoder.config.layer_norm_eps, - hidden_act=encoder.config.hidden_act, - ) - elif model_arch 
== ModelArchitectures.ROBERTA: - mlm_head = heads.RobertaMLMHead( - hidden_size=hidden_size, - vocab_size=encoder.config.vocab_size, - layer_norm_eps=encoder.config.layer_norm_eps, - ) - elif model_arch == ModelArchitectures.ALBERT: - mlm_head = heads.AlbertMLMHead( - hidden_size=hidden_size, - embedding_size=encoder.config.embedding_size, - vocab_size=encoder.config.vocab_size, - hidden_act=encoder.config.hidden_act, - ) - elif model_arch == ModelArchitectures.XLM_ROBERTA: - mlm_head = heads.RobertaMLMHead( - hidden_size=hidden_size, - vocab_size=encoder.config.vocab_size, - layer_norm_eps=encoder.config.layer_norm_eps, - ) - elif model_arch in ( - ModelArchitectures.BART, - ModelArchitectures.MBART, - ModelArchitectures.ELECTRA, - ): - raise NotImplementedError() - else: - raise KeyError(model_arch) - taskmodel = taskmodels.MLMModel(encoder=encoder, mlm_head=mlm_head) - elif task.TASK_TYPE == TaskTypes.EMBEDDING: - if taskmodel_kwargs["pooler_type"] == "mean": - pooler_head = heads.MeanPoolerHead() - elif taskmodel_kwargs["pooler_type"] == "first": - pooler_head = heads.FirstPoolerHead() - else: - raise KeyError(taskmodel_kwargs["pooler_type"]) - taskmodel = taskmodels.EmbeddingModel( - encoder=encoder, pooler_head=pooler_head, layer=taskmodel_kwargs["layer"], - ) - else: - raise KeyError(task.TASK_TYPE) - return taskmodel - - -def get_encoder(model_arch, ancestor_model): - """From model architecture, get the encoder (encoder = embedding layer + self-attention layer). - - This function will return the "The bare Bert Model transformer outputting raw hidden-states - without any specific head on top", when provided with ModelArchitectures and BertForPreTraining - model. See Hugging Face's BertForPreTraining and BertModel documentation for more info. + head_kwargs = {} + head_kwargs["hidden_size"] = jiant_transformers_model.get_hidden_size() + head_kwargs["hidden_dropout_prob"] = jiant_transformers_model.get_hidden_dropout_prob() + head_kwargs["vocab_size"] = jiant_transformers_model.config.vocab_size + head_kwargs["layer_norm_eps"] = (jiant_transformers_model.config.layer_norm_eps,) + head_kwargs["hidden_act"] = jiant_transformers_model.config.hidden_act + head_kwargs["model_arch"] = ModelArchitectures(jiant_transformers_model.config.model_type) - Args: - model_arch: Model architecture. - ancestor_model: Model with pretraining heads attached. - - Raises: - KeyError if ModelArchitectures - - Returns: - Bare pretrained model outputting raw hidden-states without a specific head on top. 
+ head = JiantHeadFactory()(task, **head_kwargs) - """ - if model_arch == ModelArchitectures.BERT: - return ancestor_model.bert - elif model_arch == ModelArchitectures.ROBERTA: - return ancestor_model.roberta - elif model_arch == ModelArchitectures.ALBERT: - return ancestor_model.albert - elif model_arch == ModelArchitectures.XLM_ROBERTA: - return ancestor_model.roberta - elif model_arch in (ModelArchitectures.BART, ModelArchitectures.MBART): - return ancestor_model.model - elif model_arch == ModelArchitectures.ELECTRA: - return ancestor_model.electra - else: - raise KeyError(model_arch) + taskmodel = JiantTaskModelFactory()(task, jiant_transformers_model, head, **taskmodel_kwargs) + return taskmodel @dataclass @@ -472,45 +303,6 @@ class TransformersClassSpec: model_class: Any -TRANSFORMERS_CLASS_SPEC_DICT = { - ModelArchitectures.BERT: TransformersClassSpec( - config_class=transformers.BertConfig, - tokenizer_class=transformers.BertTokenizer, - model_class=transformers.BertForPreTraining, - ), - ModelArchitectures.ROBERTA: TransformersClassSpec( - config_class=transformers.RobertaConfig, - tokenizer_class=transformers.RobertaTokenizer, - model_class=transformers.RobertaForMaskedLM, - ), - ModelArchitectures.ALBERT: TransformersClassSpec( - config_class=transformers.AlbertConfig, - tokenizer_class=transformers.AlbertTokenizer, - model_class=transformers.AlbertForMaskedLM, - ), - ModelArchitectures.XLM_ROBERTA: TransformersClassSpec( - config_class=transformers.XLMRobertaConfig, - tokenizer_class=transformers.XLMRobertaTokenizer, - model_class=transformers.XLMRobertaForMaskedLM, - ), - ModelArchitectures.BART: TransformersClassSpec( - config_class=transformers.BartConfig, - tokenizer_class=transformers.BartTokenizer, - model_class=transformers.BartForConditionalGeneration, - ), - ModelArchitectures.MBART: TransformersClassSpec( - config_class=transformers.BartConfig, - tokenizer_class=transformers.MBartTokenizer, - model_class=transformers.BartForConditionalGeneration, - ), - ModelArchitectures.ELECTRA: TransformersClassSpec( - config_class=transformers.ElectraConfig, - tokenizer_class=transformers.ElectraTokenizer, - model_class=transformers.ElectraForPreTraining, - ), -} - - def get_taskmodel_and_task_names(task_to_taskmodel_map: Dict[str, str]) -> Dict[str, List[str]]: """Get mapping from task model name to the list of task names associated with that task model. @@ -533,17 +325,6 @@ def get_model_arch_from_jiant_model(jiant_model: nn.Module) -> ModelArchitecture return ModelArchitectures.from_encoder(encoder=jiant_model.encoder) -MODEL_PREFIX = { - ModelArchitectures.BERT: "bert", - ModelArchitectures.ROBERTA: "roberta", - ModelArchitectures.ALBERT: "albert", - ModelArchitectures.XLM_ROBERTA: "xlm-roberta", - ModelArchitectures.BART: "model", - ModelArchitectures.MBART: "model", - ModelArchitectures.ELECTRA: "electra", -} - - def get_ancestor_model(transformers_class_spec, model_config_path): """Load the model config from a file, configure the model, and return the model. 
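
Note on the mechanism (illustrative, not part of the patch): the factories introduced above — JiantHeadFactory, JiantMLMHeadFactory, and the JiantTaskModelFactory now used by create_taskmodel — all follow the same register-by-decorator pattern. The self-contained sketch below uses toy stand-ins (HeadFactory, ToyClassificationHead, and ToyTask are invented names, not jiant classes) to show how registration and dispatch replace the old embedded if-else chains; it approximates the idea rather than reproducing jiant's implementation.

    from enum import Enum
    from typing import Callable, Dict, List, Type


    class TaskTypes(Enum):
        CLASSIFICATION = "classification"
        REGRESSION = "regression"


    class HeadFactory:
        # Toy equivalent of JiantHeadFactory: maps task types to head classes.
        registry: Dict[TaskTypes, Type] = {}

        @classmethod
        def register(cls, task_type_list: List[TaskTypes]) -> Callable:
            # Used as a class decorator: each listed task type maps to the decorated class.
            def inner_wrapper(wrapped_class):
                for task_type in task_type_list:
                    assert task_type not in cls.registry
                    cls.registry[task_type] = wrapped_class
                return wrapped_class

            return inner_wrapper

        def __call__(self, task, **kwargs):
            # Dispatch on the task's TASK_TYPE instead of an embedded if-else chain.
            return self.registry[task.TASK_TYPE](task, **kwargs)


    @HeadFactory.register([TaskTypes.CLASSIFICATION])
    class ToyClassificationHead:
        def __init__(self, task, hidden_size, **kwargs):
            self.num_labels = len(task.LABELS)
            self.hidden_size = hidden_size


    class ToyTask:
        TASK_TYPE = TaskTypes.CLASSIFICATION
        LABELS = ["entailment", "neutral", "contradiction"]


    head = HeadFactory()(ToyTask(), hidden_size=768)
    assert (head.num_labels, head.hidden_size) == (3, 768)
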
diff --git a/jiant/proj/main/modeling/primary.py b/jiant/proj/main/modeling/primary.py index bf5f398a4..fbefe1ab7 100644 --- a/jiant/proj/main/modeling/primary.py +++ b/jiant/proj/main/modeling/primary.py @@ -1,10 +1,17 @@ -from typing import Dict, Union +import abc + +from typing import Dict +from typing import Union +from typing import Callable import torch.nn as nn +import transformers import jiant.proj.main.modeling.taskmodels as taskmodels import jiant.tasks as tasks + from jiant.proj.main.components.outputs import construct_output_from_dict +from jiant.shared.model_resolution import ModelArchitectures class JiantModel(nn.Module): @@ -49,7 +56,7 @@ def forward(self, batch: tasks.BatchMixin, task: tasks.Task, compute_loss: bool taskmodel_key = self.task_to_taskmodel_map[task_name] taskmodel = self.taskmodels_dict[taskmodel_key] return taskmodel( - batch=batch, task=task, tokenizer=self.tokenizer, compute_loss=compute_loss, + batch=batch, tokenizer=self.tokenizer, compute_loss=compute_loss, ).to_dict() @@ -85,3 +92,118 @@ def wrap_jiant_forward( if is_multi_gpu and compute_loss: model_output.loss = model_output.loss.mean() return model_output + + +class JiantTransformersModelFactory: + """This factory is used to create JiantTransformersModels based on Huggingface's models. + A wrapper class around Huggingface's Transformer models is used to abstract any inconsistencies + in the classes. + + Attributes: + registry (dict): Dynamic registry mapping ModelArchitectures to JiantTransformersModels + """ + + registry = {} + + @classmethod + def register(cls, model_arch: ModelArchitectures) -> Callable: + """Register model_arch as a key mapping to a TaskModel + + Args: + model_arch (ModelArchitectures): ModelArchitecture key mapping to a + JiantTransformersModel + + Returns: + Callable: inner_wrapper() wrapping TaskModel constructor + """ + + def inner_wrapper(wrapped_class: JiantTransformersModel) -> Callable: + assert model_arch not in cls.registry + cls.registry[model_arch] = wrapped_class + return wrapped_class + + return inner_wrapper + + def __call__(cls, hf_model): + """Returns the JiantTransformersModel wrapper class for the corresponding Hugging Face + Transformer model. 
+ + Args: + hf_model (PreTrainedModel): Hugging Face model to convert to JiantTransformersModel + + Returns: + JiantTransformersModel: Jiant wrapper class for Hugging Face model + """ + jiant_transformers_model_class = cls.registry[ + ModelArchitectures(hf_model.config.model_type) + ] + jiant_transformers_model = jiant_transformers_model_class(hf_model) + return jiant_transformers_model + + +class JiantTransformersModel(metaclass=abc.ABCMeta): + def __init__(self, baseObject): + self.__class__ = type( + baseObject.__class__.__name__, (self.__class__, baseObject.__class__), {} + ) + self.__dict__ = baseObject.__dict__ + + def get_hidden_size(self): + return self.config.hidden_size + + def get_hidden_dropout_prob(self): + return self.config.hidden_dropout_prob + + +@JiantTransformersModelFactory.register(ModelArchitectures.BERT) +class JiantBertModel(JiantTransformersModel): + def __init__(self, baseObject): + super().__init__(baseObject) + self.hf_pretrained_encoder_with_pretrained_head = transformers.BertForPreTraining + + +@JiantTransformersModelFactory.register(ModelArchitectures.ROBERTA) +class JiantRobertaModel(JiantTransformersModel): + def __init__(self, baseObject): + super().__init__(baseObject) + self.hf_pretrained_encoder_with_pretrained_head = transformers.RobertaForMaskedLM + + +@JiantTransformersModelFactory.register(ModelArchitectures.ALBERT) +class JiantAlbertModel(JiantTransformersModel): + def __init__(self, baseObject): + super().__init__(baseObject) + self.hf_pretrained_encoder_with_pretrained_head = transformers.AlbertForMaskedLM + + +@JiantTransformersModelFactory.register(ModelArchitectures.XLM_ROBERTA) +class JiantXLMRobertaModel(JiantTransformersModel): + def __init__(self, baseObject): + super().__init__(baseObject) + self.hf_pretrained_encoder_with_pretrained_head = transformers.XLMRobertaForMaskedLM + + +@JiantTransformersModelFactory.register(ModelArchitectures.ELECTRA) +class JiantElectraModel(JiantTransformersModel): + def __init__(self, baseObject): + super().__init__(baseObject) + self.hf_pretrained_encoder_with_pretrained_head = transformers.ElectraForPreTraining + + +@JiantTransformersModelFactory.register(ModelArchitectures.BART) +class JiantBartModel(JiantTransformersModel): + def __init__(self, baseObject): + super().__init__(baseObject) + self.hf_pretrained_encoder_with_pretrained_head = transformers.BartForConditionalGeneration + + def get_hidden_size(self): + return self.config.d_model + + def get_hidden_dropout_prob(self): + return self.config.dropout + + +@JiantTransformersModelFactory.register(ModelArchitectures.MBART) +class JiantMBartModel(JiantBartModel): + def __init__(self, baseObject): + super().__init__(baseObject) diff --git a/jiant/proj/main/modeling/taskmodels.py b/jiant/proj/main/modeling/taskmodels.py index c6507708f..9662f315b 100644 --- a/jiant/proj/main/modeling/taskmodels.py +++ b/jiant/proj/main/modeling/taskmodels.py @@ -1,6 +1,8 @@ import abc + from dataclasses import dataclass from typing import Any +from typing import Callable import torch import torch.nn as nn @@ -9,44 +11,93 @@ from jiant.proj.main.components.outputs import LogitsOutput, LogitsAndLossOutput from jiant.utils.python.datastructures import take_one from jiant.shared.model_resolution import ModelArchitectures +from jiant.tasks.core import TaskTypes + + +class JiantTaskModelFactory: + """This factory is used to create task models bundling the task, + encoder, and task head within the task model. 
+ + Attributes: + registry (dict): Dynamic registry mapping task types to task models + """ + + registry = {} + + @classmethod + def register(cls, task_type: TaskTypes) -> Callable: + """Register task_type as a key mapping to a TaskModel + + Args: + task_type (TaskTypes): TaskType key mapping to a BaseHead task head + + Returns: + Callable: inner_wrapper() wrapping TaskModel constructor + """ + + def inner_wrapper(wrapped_class: Taskmodel) -> Callable: + assert task_type not in cls.registry + cls.registry[task_type] = wrapped_class + return wrapped_class + + return inner_wrapper + + def __call__(cls, task, jiant_transformers_model, head, **kwargs): + """This creates the TaskModel corresponding to the Task, abc.abstractmethod, + and encoder used. + + Args: + task (Task): Task + jiant_transformers_model (JiantTransformersModel): Encoder + head (BaseHead): Task head + **kwargs: Additional arguments for initializing TaskModel + + Returns: + TaskModel: Initialized task model bundling task, encoder, and head + """ + taskmodel_class = cls.registry[task.TASK_TYPE] + taskmodel = taskmodel_class(task, jiant_transformers_model, head, **kwargs) + return taskmodel class Taskmodel(nn.Module, metaclass=abc.ABCMeta): - def __init__(self, encoder): + def __init__(self, task, encoder, head): super().__init__() + self.task = task self.encoder = encoder + self.head = head - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): raise NotImplementedError +@JiantTaskModelFactory.register(TaskTypes.CLASSIFICATION) class ClassificationModel(Taskmodel): - def __init__(self, encoder, classification_head: heads.ClassificationHead): - super().__init__(encoder=encoder) - self.classification_head = classification_head + def __init__(self, task, encoder, head: heads.ClassificationHead, **kwargs): + + super().__init__(task=task, encoder=encoder, head=head) + + def forward(self, batch, tokenizer, compute_loss: bool = False): - def forward(self, batch, task, tokenizer, compute_loss: bool = False): encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) - logits = self.classification_head(pooled=encoder_output.pooled) + logits = self.head(pooled=encoder_output.pooled) if compute_loss: loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - logits.view(-1, self.classification_head.num_labels), batch.label_id.view(-1), - ) + loss = loss_fct(logits.view(-1, self.head.num_labels), batch.label_id.view(-1),) return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other) else: return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.REGRESSION) class RegressionModel(Taskmodel): - def __init__(self, encoder, regression_head: heads.RegressionHead): - super().__init__(encoder=encoder) - self.regression_head = regression_head + def __init__(self, task, encoder, head: heads.RegressionHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) # TODO: Abuse of notation - these aren't really logits (issue #1187) - logits = self.regression_head(pooled=encoder_output.pooled) + logits = self.head(pooled=encoder_output.pooled) if compute_loss: loss_fct = nn.MSELoss() loss = loss_fct(logits.view(-1), batch.label.view(-1)) @@ -55,13 
+106,13 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.MULTIPLE_CHOICE) class MultipleChoiceModel(Taskmodel): - def __init__(self, encoder, num_choices: int, choice_scoring_head: heads.RegressionHead): - super().__init__(encoder=encoder) - self.num_choices = num_choices - self.choice_scoring_head = choice_scoring_head + def __init__(self, task, encoder, head: heads.RegressionHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) + self.num_choices = task.NUM_CHOICES - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): input_ids = batch.input_ids segment_ids = batch.segment_ids input_mask = batch.input_mask @@ -75,7 +126,7 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): segment_ids=segment_ids[:, i], input_mask=input_mask[:, i], ) - choice_score = self.choice_scoring_head(pooled=encoder_output.pooled) + choice_score = self.head(pooled=encoder_output.pooled) choice_score_list.append(choice_score) encoder_output_other_ls.append(encoder_output.other) @@ -102,36 +153,34 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): return LogitsOutput(logits=logits, other=reshaped_outputs) +@JiantTaskModelFactory.register(TaskTypes.SPAN_COMPARISON_CLASSIFICATION) class SpanComparisonModel(Taskmodel): - def __init__(self, encoder, span_comparison_head: heads.SpanComparisonHead): - super().__init__(encoder=encoder) - self.span_comparison_head = span_comparison_head + def __init__(self, task, encoder, head: heads.SpanComparisonHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) - logits = self.span_comparison_head(unpooled=encoder_output.unpooled, spans=batch.spans) + logits = self.head(unpooled=encoder_output.unpooled, spans=batch.spans) if compute_loss: loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - logits.view(-1, self.span_comparison_head.num_labels), batch.label_id.view(-1), - ) + loss = loss_fct(logits.view(-1, self.head.num_labels), batch.label_id.view(-1),) return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other) else: return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.SPAN_PREDICTION) class SpanPredictionModel(Taskmodel): - def __init__(self, encoder, span_prediction_head: heads.TokenClassificationHead): - super().__init__(encoder=encoder) + def __init__(self, task, encoder, head: heads.TokenClassificationHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) self.offset_margin = 1000 # 1000 is a big enough number that exp(-1000) will be strict 0 in float32. # So that if we add 1000 to the valid dimensions in the input of softmax, # we can guarantee the output distribution will only be non-zero at those dimensions. 
- self.span_prediction_head = span_prediction_head - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) - logits = self.span_prediction_head(unpooled=encoder_output.unpooled) + logits = self.head(unpooled=encoder_output.unpooled) # Ensure logits in valid range is at least self.offset_margin higher than others logits_offset = logits.max() - logits.min() + self.offset_margin logits = logits + logits_offset * batch.selection_token_mask.unsqueeze(dim=2) @@ -145,38 +194,36 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.MULTI_LABEL_SPAN_CLASSIFICATION) class MultiLabelSpanComparisonModel(Taskmodel): - def __init__(self, encoder, span_comparison_head: heads.SpanComparisonHead): - super().__init__(encoder=encoder) - self.span_comparison_head = span_comparison_head + def __init__(self, task, encoder, head: heads.SpanComparisonHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) - logits = self.span_comparison_head(unpooled=encoder_output.unpooled, spans=batch.spans) + logits = self.head(unpooled=encoder_output.unpooled, spans=batch.spans) if compute_loss: loss_fct = nn.BCEWithLogitsLoss() - loss = loss_fct( - logits.view(-1, self.span_comparison_head.num_labels), batch.label_ids.float(), - ) + loss = loss_fct(logits.view(-1, self.head.num_labels), batch.label_ids.float(),) return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other) else: return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.TAGGING) class TokenClassificationModel(Taskmodel): """From RobertaForTokenClassification""" - def __init__(self, encoder, token_classification_head: heads.TokenClassificationHead): - super().__init__(encoder=encoder) - self.token_classification_head = token_classification_head + def __init__(self, task, encoder, head: heads.TokenClassificationHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) - logits = self.token_classification_head(unpooled=encoder_output.unpooled) + logits = self.head(unpooled=encoder_output.unpooled) if compute_loss: loss_fct = nn.CrossEntropyLoss() active_loss = batch.label_mask.view(-1) == 1 - active_logits = logits.view(-1, self.token_classification_head.num_labels)[active_loss] + active_logits = logits.view(-1, self.head.num_labels)[active_loss] active_labels = batch.label_ids.view(-1)[active_loss] loss = loss_fct(active_logits, active_labels) return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other) @@ -184,14 +231,14 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.SQUAD_STYLE_QA) class QAModel(Taskmodel): - def __init__(self, encoder, qa_head: heads.QAHead): - 
super().__init__(encoder=encoder) - self.qa_head = qa_head + def __init__(self, task, encoder, head: heads.QAHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) - logits = self.qa_head(unpooled=encoder_output.unpooled) + logits = self.head(unpooled=encoder_output.unpooled) if compute_loss: loss = compute_qa_loss( logits=logits, @@ -203,14 +250,16 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.MASKED_LANGUAGE_MODELING) class MLMModel(Taskmodel): - def __init__(self, encoder, mlm_head: heads.BaseMLMHead): - super().__init__(encoder=encoder) - self.mlm_head = mlm_head + def __init__(self, task, encoder, head: heads.BaseMLMHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): masked_batch = batch.get_masked( - mlm_probability=task.mlm_probability, tokenizer=tokenizer, do_mask=task.do_mask, + mlm_probability=self.task.mlm_probability, + tokenizer=tokenizer, + do_mask=self.task.do_mask, ) encoder_output = get_output_from_encoder( encoder=self.encoder, @@ -218,7 +267,7 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): segment_ids=masked_batch.segment_ids, input_mask=masked_batch.input_mask, ) - logits = self.mlm_head(unpooled=encoder_output.unpooled) + logits = self.head(unpooled=encoder_output.unpooled) if compute_loss: loss = compute_mlm_loss(logits=logits, masked_lm_labels=masked_batch.masked_lm_labels) return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other) @@ -226,26 +275,27 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.EMBEDDING) class EmbeddingModel(Taskmodel): - def __init__(self, encoder, pooler_head: heads.AbstractPoolerHead, layer): - super().__init__(encoder=encoder) - self.pooler_head = pooler_head - self.layer = layer + def __init__(self, task, encoder, head: heads.AbstractPoolerHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) + self.layer = kwargs["layer"] def forward(self, batch, task, tokenizer, compute_loss: bool = False): encoder_output = get_output_from_encoder_and_batch( - encoder=self.encoder, batch=batch, output_hidden_states=True + encoder_output=self.encoder, batch=batch, output_hidden_states=True ) + # A tuple of layers of hidden states hidden_states = take_one(encoder_output.other) layer_hidden_states = hidden_states[self.layer] - if isinstance(self.pooler_head, heads.MeanPoolerHead): - logits = self.pooler_head(unpooled=layer_hidden_states, input_mask=batch.input_mask) - elif isinstance(self.pooler_head, heads.FirstPoolerHead): - logits = self.pooler_head(layer_hidden_states) + if isinstance(self.head, heads.MeanPoolerHead): + logits = self.head(unpooled=layer_hidden_states, input_mask=batch.input_mask) + elif isinstance(self.head, heads.FirstPoolerHead): + logits = self.head(layer_hidden_states) else: - raise TypeError(type(self.pooler_head)) + raise TypeError(type(self.head)) # TODO: Abuse of notation - these aren't really logits 
(issue #1187) if compute_loss: @@ -261,6 +311,7 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): @dataclass class EncoderOutput: + pooled: torch.Tensor unpooled: torch.Tensor other: Any = None diff --git a/jiant/shared/model_resolution.py b/jiant/shared/model_resolution.py index 7ec446137..c8a2ab783 100644 --- a/jiant/shared/model_resolution.py +++ b/jiant/shared/model_resolution.py @@ -7,105 +7,14 @@ class ModelArchitectures(Enum): - BERT = 1 - XLM = 2 - ROBERTA = 3 - ALBERT = 4 - XLM_ROBERTA = 5 - BART = 6 - MBART = 7 - ELECTRA = 8 - - @classmethod - def from_model_type(cls, model_type: str): - """Get the model architecture for the provided shortcut name. - - Args: - model_type (str): model shortcut name. - - Returns: - Model architecture associated with the provided shortcut name. - - """ - if model_type.startswith("bert"): - return cls.BERT - elif model_type.startswith("xlm") and not model_type.startswith("xlm-roberta"): - return cls.XLM - elif model_type.startswith("roberta"): - return cls.ROBERTA - elif model_type.startswith("albert"): - return cls.ALBERT - elif model_type == "glove_lstm": - return cls.GLOVE_LSTM - elif model_type.startswith("xlm-roberta"): - return cls.XLM_ROBERTA - elif model_type.startswith("bart"): - return cls.BART - elif model_type.startswith("mbart"): - return cls.MBART - elif model_type.startswith("electra"): - return cls.ELECTRA - else: - raise KeyError(model_type) - - @classmethod - def from_transformers_model(cls, transformers_model): - if isinstance( - transformers_model, transformers.BertPreTrainedModel - ) and transformers_model.__class__.__name__.startswith("Bert"): - return cls.BERT - elif isinstance(transformers_model, transformers.XLMPreTrainedModel): - return cls.XLM - elif isinstance( - transformers_model, transformers.BertPreTrainedModel - ) and transformers_model.__class__.__name__.startswith("Robert"): - return cls.ROBERTA - elif isinstance( - transformers_model, transformers.BertPreTrainedModel - ) and transformers_model.__class__.__name__.startswith("XLMRoberta"): - return cls.XLM_ROBERTA - elif isinstance(transformers_model, transformers.modeling_albert.AlbertPreTrainedModel): - return cls.ALBERT - elif isinstance(transformers_model, transformers.modeling_bart.PretrainedBartModel): - return bart_or_mbart_model_heuristic(model_config=transformers_model.config) - elif isinstance(transformers_model, transformers.modeling_electra.ElectraPreTrainedModel): - return cls.ELECTRA - else: - raise KeyError(str(transformers_model)) - - @classmethod - def from_tokenizer_class(cls, tokenizer_class): - if isinstance(tokenizer_class, transformers.BertTokenizer): - return cls.BERT - elif isinstance(tokenizer_class, transformers.XLMTokenizer): - return cls.XLM - elif isinstance(tokenizer_class, transformers.RobertaTokenizer): - return cls.ROBERTA - elif isinstance(tokenizer_class, transformers.XLMRobertaTokenizer): - return cls.XLM_ROBERTA - elif isinstance(tokenizer_class, transformers.AlbertTokenizer): - return cls.ALBERT - elif isinstance(tokenizer_class, transformers.BartTokenizer): - return cls.BART - elif isinstance(tokenizer_class, transformers.MBartTokenizer): - return cls.MBART - elif isinstance(tokenizer_class, transformers.ElectraTokenizer): - return cls.ELECTRA - else: - raise KeyError(str(tokenizer_class)) - - @classmethod - def is_transformers_model_arch(cls, model_arch): - return model_arch in [ - cls.BERT, - cls.XLM, - cls.ROBERTA, - cls.ALBERT, - cls.XLM_ROBERTA, - cls.BART, - cls.MBART, - cls.ELECTRA, - ] + BERT = 
"bert" + XLM = "xlm" + ROBERTA = "roberta" + ALBERT = "albert" + XLM_ROBERTA = "xlm-roberta" + BART = "bart" + MBART = "mbart" + ELECTRA = "electra" @classmethod def from_encoder(cls, encoder): @@ -155,7 +64,7 @@ class ModelClassSpec: def build_featurization_spec(model_type, max_seq_length): - model_arch = ModelArchitectures.from_model_type(model_type) + model_arch = ModelArchitectures(model_type) if model_arch == ModelArchitectures.BERT: return FeaturizationSpec( max_seq_length=max_seq_length, @@ -303,7 +212,7 @@ def resolve_tokenizer_class(model_type): Tokenizer associated with the given model. """ - return TOKENIZER_CLASS_DICT[ModelArchitectures.from_model_type(model_type)] + return TOKENIZER_CLASS_DICT[ModelArchitectures(model_type)] def resolve_is_lower_case(tokenizer): diff --git a/jiant/tasks/core.py b/jiant/tasks/core.py index 59471f163..85cc0e0e4 100644 --- a/jiant/tasks/core.py +++ b/jiant/tasks/core.py @@ -85,18 +85,18 @@ def data_row_collate_fn(batch): class TaskTypes(Enum): - CLASSIFICATION = 1 - REGRESSION = 2 - SPAN_COMPARISON_CLASSIFICATION = 3 - MULTIPLE_CHOICE = 4 - SPAN_CHOICE_PROB_TASK = 5 - SQUAD_STYLE_QA = 6 - TAGGING = 7 - MASKED_LANGUAGE_MODELING = 8 - EMBEDDING = 9 - MULTI_LABEL_SPAN_CLASSIFICATION = 10 - SPAN_PREDICTION = 11 - UNDEFINED = -1 + CLASSIFICATION = "classification" + REGRESSION = "regression" + SPAN_COMPARISON_CLASSIFICATION = "span_comparison_classification" + MULTIPLE_CHOICE = "multiple_choice" + SPAN_CHOICE_PROB_TASK = "span_choice_prob_task" + SQUAD_STYLE_QA = "squad_style_qa" + TAGGING = "tagging" + MASKED_LANGUAGE_MODELING = "masked_language_modeling" + EMBEDDING = "embedding" + MULTI_LABEL_SPAN_CLASSIFICATION = "multi_label_span_classification" + SPAN_PREDICTION = "span_prediction" + UNDEFINED = "undefined" class BatchTuple(NamedTuple): diff --git a/tests/tasks/lib/test_mlm_premasked.py b/tests/tasks/lib/test_mlm_premasked.py index 17a5dd0fe..48d25405f 100644 --- a/tests/tasks/lib/test_mlm_premasked.py +++ b/tests/tasks/lib/test_mlm_premasked.py @@ -35,7 +35,7 @@ def test_tokenization_and_featurization(): data_row = tokenized_example.featurize( tokenizer=tokenizer, feat_spec=model_resolution.build_featurization_spec( - model_type="roberta-base", max_seq_length=16, + model_type="roberta", max_seq_length=16, ), ) assert list(data_row.masked_input_ids) == [ diff --git a/tests/tasks/lib/test_mlm_pretokenized.py b/tests/tasks/lib/test_mlm_pretokenized.py index 2b682f234..55230ee29 100644 --- a/tests/tasks/lib/test_mlm_pretokenized.py +++ b/tests/tasks/lib/test_mlm_pretokenized.py @@ -3,6 +3,8 @@ import jiant.shared.model_resolution as model_resolution import jiant.tasks as tasks +from jiant.shared.model_resolution import ModelArchitectures + def test_tokenization_and_featurization(): task = tasks.MLMPretokenizedTask(name="mlm_pretokenized", path_dict={}) @@ -37,7 +39,7 @@ def test_tokenization_and_featurization(): data_row = tokenized_example.featurize( tokenizer=tokenizer, feat_spec=model_resolution.build_featurization_spec( - model_type="roberta-base", max_seq_length=16, + model_type=ModelArchitectures.ROBERTA.value, max_seq_length=16, ), ) assert list(data_row.masked_input_ids) == [ diff --git a/tests/tasks/lib/test_mnli.py b/tests/tasks/lib/test_mnli.py index 2b564e41d..d4740dd26 100644 --- a/tests/tasks/lib/test_mnli.py +++ b/tests/tasks/lib/test_mnli.py @@ -1,8 +1,10 @@ +import numpy as np import os + from collections import Counter -import numpy as np from jiant.shared import model_resolution +from jiant.shared.model_resolution import 
ModelArchitectures from jiant.tasks import create_task_from_config_path from jiant.utils.testing.tokenizer import SimpleSpaceTokenizer @@ -301,7 +303,7 @@ def test_featurization_of_task_data(): tokenized_examples[0].hypothesis ) feat_spec = model_resolution.build_featurization_spec( - model_type="bert-", max_seq_length=train_example_0_length + model_type=ModelArchitectures.BERT.value, max_seq_length=train_example_0_length ) featurized_examples = [ tokenized_example.featurize(tokenizer=tokenizer, feat_spec=feat_spec) diff --git a/tests/tasks/lib/test_spr1.py b/tests/tasks/lib/test_spr1.py index fb8aefb2a..d3c394a58 100644 --- a/tests/tasks/lib/test_spr1.py +++ b/tests/tasks/lib/test_spr1.py @@ -315,7 +315,7 @@ def test_featurization_of_task_data(): # Testing conversion of a tokenized example to a featurized example train_example_0_length = len(tokenized_examples[0].tokens) + 4 feat_spec = model_resolution.build_featurization_spec( - model_type="bert-", max_seq_length=train_example_0_length + model_type="bert", max_seq_length=train_example_0_length ) featurized_examples = [ tokenized_example.featurize(tokenizer=tokenizer, feat_spec=feat_spec)
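
One further mechanism in this patch that may not be obvious from the diff alone is the JiantTransformersModel wrapper in primary.py: its __init__ rebuilds the instance's class as a subclass of both the wrapper and the wrapped Hugging Face model's class, and shares the wrapped model's __dict__. The self-contained sketch below (Wrapper and FakeHFModel are invented stand-ins, not jiant or transformers classes) demonstrates the effect under those assumptions: the wrapped object still passes isinstance checks against the original class while gaining the wrapper's uniform accessors.

    from types import SimpleNamespace


    class Wrapper:
        """Toy version of the JiantTransformersModel wrapping trick."""

        def __init__(self, base_object):
            # Re-create this instance's class as a subclass of (Wrapper, wrapped class)
            # and share the wrapped object's attribute dict.
            self.__class__ = type(
                base_object.__class__.__name__, (self.__class__, base_object.__class__), {}
            )
            self.__dict__ = base_object.__dict__

        def get_hidden_size(self):
            # Uniform accessor, analogous to JiantTransformersModel.get_hidden_size().
            return self.config.hidden_size


    class FakeHFModel:
        def __init__(self):
            self.config = SimpleNamespace(hidden_size=768, model_type="bert")


    wrapped = Wrapper(FakeHFModel())
    assert isinstance(wrapped, FakeHFModel)  # still looks like the original model class
    assert wrapped.get_hidden_size() == 768  # plus the wrapper's uniform accessor
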