From 07d3fdb0a6b7f07cf10c8047b1fbd0c63ffad1e6 Mon Sep 17 00:00:00 2001
From: Jesse Swanson
Date: Wed, 20 Jan 2021 16:08:00 -0800
Subject: [PATCH] add dynamic registry for taskmodels

---
 jiant/proj/main/modeling/primary.py    |   2 +-
 jiant/proj/main/modeling/taskmodels.py | 192 +++++++++++--------------
 jiant/tasks/core.py                    |  24 ++--
 3 files changed, 98 insertions(+), 120 deletions(-)

diff --git a/jiant/proj/main/modeling/primary.py b/jiant/proj/main/modeling/primary.py
index 0726cf4cb..24e3e65b6 100644
--- a/jiant/proj/main/modeling/primary.py
+++ b/jiant/proj/main/modeling/primary.py
@@ -53,7 +53,7 @@ def forward(self, batch: tasks.BatchMixin, task: tasks.Task, compute_loss: bool
         taskmodel_key = self.task_to_taskmodel_map[task_name]
         taskmodel = self.taskmodels_dict[taskmodel_key]
         return taskmodel(
-            batch=batch, task=task, tokenizer=self.tokenizer, compute_loss=compute_loss,
+            batch=batch, tokenizer=self.tokenizer, compute_loss=compute_loss,
         ).to_dict()

diff --git a/jiant/proj/main/modeling/taskmodels.py b/jiant/proj/main/modeling/taskmodels.py
index f7b24451c..342585424 100644
--- a/jiant/proj/main/modeling/taskmodels.py
+++ b/jiant/proj/main/modeling/taskmodels.py
@@ -2,6 +2,7 @@
 from dataclasses import dataclass
 from typing import Any
+from typing import Callable
 
 import torch
 import torch.nn as nn
@@ -17,83 +18,62 @@
 
 
 class TaskModelFactory:
-    @staticmethod
-    def create_taskmodel(task, jiant_transformers_model, head, taskmodel_kwargs):
-        if task.TASK_TYPE == TaskTypes.CLASSIFICATION:
-            taskmodel = ClassificationModel(
-                encoder=jiant_transformers_model, classification_head=head,
-            )
-        elif task.TASK_TYPE == TaskTypes.REGRESSION:
-            taskmodel = RegressionModel(encoder=jiant_transformers_model, regression_head=head)
-        elif task.TASK_TYPE == TaskTypes.MULTIPLE_CHOICE:
-            taskmodel = MultipleChoiceModel(
-                encoder=jiant_transformers_model,
-                num_choices=task.NUM_CHOICES,
-                choice_scoring_head=head,
-            )
-        elif task.TASK_TYPE == TaskTypes.SPAN_PREDICTION:
-            taskmodel = SpanPredictionModel(encoder=encoder, span_prediction_head=head,)
-        elif task.TASK_TYPE == TaskTypes.SPAN_COMPARISON_CLASSIFICATION:
-            taskmodel = SpanComparisonModel(
-                encoder=jiant_transformers_model, span_comparison_head=head,
-            )
-        elif task.TASK_TYPE == TaskTypes.MULTI_LABEL_SPAN_CLASSIFICATION:
-            taskmodel = MultiLabelSpanComparisonModel(
-                encoder=jiant_transformers_model, span_comparison_head=head,
-            )
-        elif task.TASK_TYPE == TaskTypes.TAGGING:
-            taskmodel = TokenClassificationModel(
-                encoder=jiant_transformers_model, token_classification_head=head,
-            )
-        elif task.TASK_TYPE == TaskTypes.SQUAD_STYLE_QA:
-            taskmodel = QAModel(encoder=jiant_transformers_model, qa_head=head)
-        elif task.TASK_TYPE == TaskTypes.MASKED_LANGUAGE_MODELING:
-            taskmodel = MLMModel(encoder=encoder, mlm_head=head)
-        elif task.TASK_TYPE == TaskTypes.EMBEDDING:
-            taskmodel = EmbeddingModel(
-                encoder=encoder, pooler_head=pooler_head, layer=taskmodel_kwargs["layer"],
-            )
-        else:
-            raise KeyError(task.TASK_TYPE)
+
+    # Internal registry for available task models
+    registry = {}
+
+    @classmethod
+    def register(cls, task_type: TaskTypes) -> Callable:
+        def inner_wrapper(wrapped_class: Taskmodel) -> Callable:
+            assert task_type not in cls.registry
+            cls.registry[task_type] = wrapped_class
+            return wrapped_class
+
+        return inner_wrapper
+
+    @classmethod
+    def create_taskmodel(cls, task, jiant_transformers_model, head, taskmodel_kwargs=None):
+        taskmodel_class = cls.registry[task.TASK_TYPE]
+        taskmodel = taskmodel_class(task, jiant_transformers_model, head, taskmodel_kwargs)
         return taskmodel
 
 
 class Taskmodel(nn.Module, metaclass=abc.ABCMeta):
-    def __init__(self, encoder):
+    def __init__(self, task, encoder, head):
         super().__init__()
+        self.task = task
         self.encoder = encoder
+        self.head = head
 
-    def forward(self, batch, task, tokenizer, compute_loss: bool = False):
+    def forward(self, batch, tokenizer, compute_loss: bool = False):
         raise NotImplementedError
 
 
+@TaskModelFactory.register(TaskTypes.CLASSIFICATION)
 class ClassificationModel(Taskmodel):
-    def __init__(self, encoder, classification_head: heads.ClassificationHead):
-        super().__init__(encoder=encoder)
-        self.classification_head = classification_head
+    def __init__(self, task, encoder, head: heads.ClassificationHead, taskmodel_kwargs=None):
+        super().__init__(task=task, encoder=encoder, head=head)
 
-    def forward(self, batch, task, tokenizer, compute_loss: bool = False):
+    def forward(self, batch, tokenizer, compute_loss: bool = False):
         encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
-        logits = self.classification_head(pooled=encoder_output.pooled)
+        logits = self.head(pooled=encoder_output.pooled)
         if compute_loss:
             loss_fct = nn.CrossEntropyLoss()
-            loss = loss_fct(
-                logits.view(-1, self.classification_head.num_labels), batch.label_id.view(-1),
-            )
+            loss = loss_fct(logits.view(-1, self.head.num_labels), batch.label_id.view(-1),)
             return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
         else:
             return LogitsOutput(logits=logits, other=encoder_output.other)
 
 
+@TaskModelFactory.register(TaskTypes.REGRESSION)
 class RegressionModel(Taskmodel):
-    def __init__(self, encoder, regression_head: heads.RegressionHead):
-        super().__init__(encoder=encoder)
-        self.regression_head = regression_head
+    def __init__(self, task, encoder, head: heads.RegressionHead, taskmodel_kwargs=None):
+        super().__init__(task=task, encoder=encoder, head=head)
 
-    def forward(self, batch, task, tokenizer, compute_loss: bool = False):
+    def forward(self, batch, tokenizer, compute_loss: bool = False):
         encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
         # TODO: Abuse of notation - these aren't really logits (issue #1187)
-        logits = self.regression_head(pooled=encoder_output.pooled)
+        logits = self.head(pooled=encoder_output.pooled)
         if compute_loss:
             loss_fct = nn.MSELoss()
             loss = loss_fct(logits.view(-1), batch.label.view(-1))
@@ -102,13 +82,13 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False):
         return LogitsOutput(logits=logits, other=encoder_output.other)
 
 
+@TaskModelFactory.register(TaskTypes.MULTIPLE_CHOICE)
 class MultipleChoiceModel(Taskmodel):
-    def __init__(self, encoder, num_choices: int, choice_scoring_head: heads.RegressionHead):
-        super().__init__(encoder=encoder)
-        self.num_choices = num_choices
-        self.choice_scoring_head = choice_scoring_head
+    def __init__(self, task, encoder, head: heads.RegressionHead, taskmodel_kwargs=None):
+        super().__init__(task=task, encoder=encoder, head=head)
+        self.num_choices = task.NUM_CHOICES
 
-    def forward(self, batch, task, tokenizer, compute_loss: bool = False):
+    def forward(self, batch, tokenizer, compute_loss: bool = False):
         input_ids = batch.input_ids
         segment_ids = batch.segment_ids
         input_mask = batch.input_mask
@@ -122,7 +102,7 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False):
                 segment_ids=segment_ids[:, i],
                 input_mask=input_mask[:, i],
             )
-            choice_score = self.choice_scoring_head(pooled=encoder_output.pooled)
+            choice_score = self.head(pooled=encoder_output.pooled)
             choice_score_list.append(choice_score)
             encoder_output_other_ls.append(encoder_output.other)
@@ -149,36 +129,34 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False):
         return LogitsOutput(logits=logits, other=reshaped_outputs)
 
 
+@TaskModelFactory.register(TaskTypes.SPAN_COMPARISON_CLASSIFICATION)
 class SpanComparisonModel(Taskmodel):
-    def __init__(self, encoder, span_comparison_head: heads.SpanComparisonHead):
-        super().__init__(encoder=encoder)
-        self.span_comparison_head = span_comparison_head
+    def __init__(self, task, encoder, head: heads.SpanComparisonHead, taskmodel_kwargs=None):
+        super().__init__(task=task, encoder=encoder, head=head)
 
-    def forward(self, batch, task, tokenizer, compute_loss: bool = False):
+    def forward(self, batch, tokenizer, compute_loss: bool = False):
         encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
-        logits = self.span_comparison_head(unpooled=encoder_output.unpooled, spans=batch.spans)
+        logits = self.head(unpooled=encoder_output.unpooled, spans=batch.spans)
         if compute_loss:
             loss_fct = nn.CrossEntropyLoss()
-            loss = loss_fct(
-                logits.view(-1, self.span_comparison_head.num_labels), batch.label_id.view(-1),
-            )
+            loss = loss_fct(logits.view(-1, self.head.num_labels), batch.label_id.view(-1),)
             return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
         else:
             return LogitsOutput(logits=logits, other=encoder_output.other)
 
 
+@TaskModelFactory.register(TaskTypes.SPAN_PREDICTION)
 class SpanPredictionModel(Taskmodel):
-    def __init__(self, encoder, span_prediction_head: heads.TokenClassificationHead):
-        super().__init__(encoder=encoder)
+    def __init__(self, task, encoder, head: heads.TokenClassificationHead, taskmodel_kwargs=None):
+        super().__init__(task=task, encoder=encoder, head=head)
         self.offset_margin = 1000
         # 1000 is a big enough number that exp(-1000) will be strict 0 in float32.
         # So that if we add 1000 to the valid dimensions in the input of softmax,
         # we can guarantee the output distribution will only be non-zero at those dimensions.
-        self.span_prediction_head = span_prediction_head
 
-    def forward(self, batch, task, tokenizer, compute_loss: bool = False):
+    def forward(self, batch, tokenizer, compute_loss: bool = False):
         encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
-        logits = self.span_prediction_head(unpooled=encoder_output.unpooled)
+        logits = self.head(unpooled=encoder_output.unpooled)
         # Ensure logits in valid range is at least self.offset_margin higher than others
         logits_offset = logits.max() - logits.min() + self.offset_margin
         logits = logits + logits_offset * batch.selection_token_mask.unsqueeze(dim=2)
@@ -192,38 +170,36 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False):
         return LogitsOutput(logits=logits, other=encoder_output.other)
 
 
+@TaskModelFactory.register(TaskTypes.MULTI_LABEL_SPAN_CLASSIFICATION)
 class MultiLabelSpanComparisonModel(Taskmodel):
-    def __init__(self, encoder, span_comparison_head: heads.SpanComparisonHead):
-        super().__init__(encoder=encoder)
-        self.span_comparison_head = span_comparison_head
+    def __init__(self, task, encoder, head: heads.SpanComparisonHead, taskmodel_kwargs=None):
+        super().__init__(task=task, encoder=encoder, head=head)
 
-    def forward(self, batch, task, tokenizer, compute_loss: bool = False):
+    def forward(self, batch, tokenizer, compute_loss: bool = False):
         encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
-        logits = self.span_comparison_head(unpooled=encoder_output.unpooled, spans=batch.spans)
+        logits = self.head(unpooled=encoder_output.unpooled, spans=batch.spans)
         if compute_loss:
             loss_fct = nn.BCEWithLogitsLoss()
-            loss = loss_fct(
-                logits.view(-1, self.span_comparison_head.num_labels), batch.label_ids.float(),
-            )
+            loss = loss_fct(logits.view(-1, self.head.num_labels), batch.label_ids.float(),)
             return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
         else:
             return LogitsOutput(logits=logits, other=encoder_output.other)
 
 
+@TaskModelFactory.register(TaskTypes.TAGGING)
 class TokenClassificationModel(Taskmodel):
     """From RobertaForTokenClassification"""
 
-    def __init__(self, encoder, token_classification_head: heads.TokenClassificationHead):
-        super().__init__(encoder=encoder)
-        self.token_classification_head = token_classification_head
+    def __init__(self, task, encoder, head: heads.TokenClassificationHead, taskmodel_kwargs=None):
+        super().__init__(task=task, encoder=encoder, head=head)
 
-    def forward(self, batch, task, tokenizer, compute_loss: bool = False):
+    def forward(self, batch, tokenizer, compute_loss: bool = False):
         encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
-        logits = self.token_classification_head(unpooled=encoder_output.unpooled)
+        logits = self.head(unpooled=encoder_output.unpooled)
         if compute_loss:
             loss_fct = nn.CrossEntropyLoss()
             active_loss = batch.label_mask.view(-1) == 1
-            active_logits = logits.view(-1, self.token_classification_head.num_labels)[active_loss]
+            active_logits = logits.view(-1, self.head.num_labels)[active_loss]
             active_labels = batch.label_ids.view(-1)[active_loss]
             loss = loss_fct(active_logits, active_labels)
             return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
@@ -231,14 +207,14 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False):
         return LogitsOutput(logits=logits, other=encoder_output.other)
 
 
+@TaskModelFactory.register(TaskTypes.SQUAD_STYLE_QA)
 class QAModel(Taskmodel):
-    def __init__(self, encoder, qa_head: heads.QAHead):
-        super().__init__(encoder=encoder)
-        self.qa_head = qa_head
+    def __init__(self, task, encoder, head: heads.QAHead, taskmodel_kwargs=None):
+        super().__init__(task=task, encoder=encoder, head=head)
 
-    def forward(self, batch, task, tokenizer, compute_loss: bool = False):
+    def forward(self, batch, tokenizer, compute_loss: bool = False):
         encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
-        logits = self.qa_head(unpooled=encoder_output.unpooled)
+        logits = self.head(unpooled=encoder_output.unpooled)
         if compute_loss:
             loss = compute_qa_loss(
                 logits=logits,
@@ -250,14 +226,16 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False):
         return LogitsOutput(logits=logits, other=encoder_output.other)
 
 
+@TaskModelFactory.register(TaskTypes.MASKED_LANGUAGE_MODELING)
 class MLMModel(Taskmodel):
-    def __init__(self, encoder, mlm_head: heads.BaseMLMHead):
-        super().__init__(encoder=encoder)
-        self.mlm_head = mlm_head
+    def __init__(self, task, encoder, head: heads.BaseMLMHead, taskmodel_kwargs=None):
+        super().__init__(task=task, encoder=encoder, head=head)
 
-    def forward(self, batch, task, tokenizer, compute_loss: bool = False):
+    def forward(self, batch, tokenizer, compute_loss: bool = False):
         masked_batch = batch.get_masked(
-            mlm_probability=task.mlm_probability, tokenizer=tokenizer, do_mask=task.do_mask,
+            mlm_probability=self.task.mlm_probability,
+            tokenizer=tokenizer,
+            do_mask=self.task.do_mask,
         )
         encoder_output = get_output_from_encoder(
             encoder=self.encoder,
@@ -265,7 +243,7 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False):
             segment_ids=masked_batch.segment_ids,
             input_mask=masked_batch.input_mask,
         )
-        logits = self.mlm_head(unpooled=encoder_output.unpooled)
+        logits = self.head(unpooled=encoder_output.unpooled)
         if compute_loss:
             loss = compute_mlm_loss(logits=logits, masked_lm_labels=masked_batch.masked_lm_labels)
             return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
@@ -273,25 +251,25 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False):
         return LogitsOutput(logits=logits, other=encoder_output.other)
 
 
+@TaskModelFactory.register(TaskTypes.EMBEDDING)
 class EmbeddingModel(Taskmodel):
-    def __init__(self, encoder, pooler_head: heads.AbstractPoolerHead, layer):
-        super().__init__(encoder=encoder)
-        self.pooler_head = pooler_head
-        self.layer = layer
+    def __init__(self, task, encoder, head: heads.AbstractPoolerHead, taskmodel_kwargs):
+        super().__init__(task=task, encoder=encoder, head=head)
+        self.layer = taskmodel_kwargs["layer"]
 
-    def forward(self, batch, task, tokenizer, compute_loss: bool = False):
+    def forward(self, batch, tokenizer, compute_loss: bool = False):
         with transformer_utils.output_hidden_states_context(self.encoder):
             encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
         # A tuple of layers of hidden states
         hidden_states = take_one(encoder_output.other)
         layer_hidden_states = hidden_states[self.layer]
 
-        if isinstance(self.pooler_head, heads.MeanPoolerHead):
-            logits = self.pooler_head(unpooled=layer_hidden_states, input_mask=batch.input_mask)
-        elif isinstance(self.pooler_head, heads.FirstPoolerHead):
-            logits = self.pooler_head(layer_hidden_states)
+        if isinstance(self.head, heads.MeanPoolerHead):
+            logits = self.head(unpooled=layer_hidden_states, input_mask=batch.input_mask)
+        elif isinstance(self.head, heads.FirstPoolerHead):
+            logits = self.head(layer_hidden_states)
         else:
-            raise TypeError(type(self.pooler_head))
+            raise TypeError(type(self.head))
 
         # TODO: Abuse of notation - these aren't really logits (issue #1187)
         if compute_loss:
diff --git a/jiant/tasks/core.py b/jiant/tasks/core.py
index 59471f163..85cc0e0e4 100644
--- a/jiant/tasks/core.py
+++ b/jiant/tasks/core.py
@@ -85,18 +85,18 @@ def data_row_collate_fn(batch):
 
 
 class TaskTypes(Enum):
-    CLASSIFICATION = 1
-    REGRESSION = 2
-    SPAN_COMPARISON_CLASSIFICATION = 3
-    MULTIPLE_CHOICE = 4
-    SPAN_CHOICE_PROB_TASK = 5
-    SQUAD_STYLE_QA = 6
-    TAGGING = 7
-    MASKED_LANGUAGE_MODELING = 8
-    EMBEDDING = 9
-    MULTI_LABEL_SPAN_CLASSIFICATION = 10
-    SPAN_PREDICTION = 11
-    UNDEFINED = -1
+    CLASSIFICATION = "classification"
+    REGRESSION = "regression"
+    SPAN_COMPARISON_CLASSIFICATION = "span_comparison_classification"
+    MULTIPLE_CHOICE = "multiple_choice"
+    SPAN_CHOICE_PROB_TASK = "span_choice_prob_task"
+    SQUAD_STYLE_QA = "squad_style_qa"
+    TAGGING = "tagging"
+    MASKED_LANGUAGE_MODELING = "masked_language_modeling"
+    EMBEDDING = "embedding"
+    MULTI_LABEL_SPAN_CLASSIFICATION = "multi_label_span_classification"
+    SPAN_PREDICTION = "span_prediction"
+    UNDEFINED = "undefined"
 
 
 class BatchTuple(NamedTuple):
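---

With this change, dispatch from task type to taskmodel is data-driven: each
Taskmodel subclass binds itself to a TaskTypes member at import time via the
@TaskModelFactory.register(...) decorator, and create_taskmodel() looks the
class up from task.TASK_TYPE instead of walking an if/elif chain (the
string-valued TaskTypes members also make registry keys and errors easier to
read). A minimal sketch of the intended extension path follows; the
ExampleModel class and its binding to TaskTypes.SPAN_CHOICE_PROB_TASK are
hypothetical, chosen only because that member has no registered taskmodel in
this patch:

    from jiant.proj.main.modeling.taskmodels import TaskModelFactory, Taskmodel
    from jiant.tasks.core import TaskTypes

    # Hypothetical new taskmodel: registration is one decorator line, and
    # TaskModelFactory itself needs no edits.
    @TaskModelFactory.register(TaskTypes.SPAN_CHOICE_PROB_TASK)
    class ExampleModel(Taskmodel):
        def __init__(self, task, encoder, head, taskmodel_kwargs=None):
            super().__init__(task=task, encoder=encoder, head=head)

        def forward(self, batch, tokenizer, compute_loss: bool = False):
            raise NotImplementedError

    # The factory resolves the class from the task's TASK_TYPE; registering
    # the same TaskTypes member twice trips the assert in register().
    # taskmodel = TaskModelFactory.create_taskmodel(task, jiant_transformers_model, head)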