add dynamic registry for taskmodels
jeswan committed Jan 21, 2021
1 parent b3699e5 commit 07d3fdb
Showing 3 changed files with 98 additions and 120 deletions.
2 changes: 1 addition & 1 deletion jiant/proj/main/modeling/primary.py
@@ -53,7 +53,7 @@ def forward(self, batch: tasks.BatchMixin, task: tasks.Task, compute_loss: bool
taskmodel_key = self.task_to_taskmodel_map[task_name]
taskmodel = self.taskmodels_dict[taskmodel_key]
return taskmodel(
batch=batch, task=task, tokenizer=self.tokenizer, compute_loss=compute_loss,
batch=batch, tokenizer=self.tokenizer, compute_loss=compute_loss,
).to_dict()


192 changes: 85 additions & 107 deletions jiant/proj/main/modeling/taskmodels.py
@@ -2,6 +2,7 @@

from dataclasses import dataclass
from typing import Any
from typing import Callable

import torch
import torch.nn as nn
@@ -17,83 +18,62 @@


class TaskModelFactory:
@staticmethod
def create_taskmodel(task, jiant_transformers_model, head, taskmodel_kwargs):
if task.TASK_TYPE == TaskTypes.CLASSIFICATION:
taskmodel = ClassificationModel(
encoder=jiant_transformers_model, classification_head=head,
)
elif task.TASK_TYPE == TaskTypes.REGRESSION:
taskmodel = RegressionModel(encoder=jiant_transformers_model, regression_head=head)
elif task.TASK_TYPE == TaskTypes.MULTIPLE_CHOICE:
taskmodel = MultipleChoiceModel(
encoder=jiant_transformers_model,
num_choices=task.NUM_CHOICES,
choice_scoring_head=head,
)
elif task.TASK_TYPE == TaskTypes.SPAN_PREDICTION:
taskmodel = SpanPredictionModel(encoder=jiant_transformers_model, span_prediction_head=head,)
elif task.TASK_TYPE == TaskTypes.SPAN_COMPARISON_CLASSIFICATION:
taskmodel = SpanComparisonModel(
encoder=jiant_transformers_model, span_comparison_head=head,
)
elif task.TASK_TYPE == TaskTypes.MULTI_LABEL_SPAN_CLASSIFICATION:
taskmodel = MultiLabelSpanComparisonModel(
encoder=jiant_transformers_model, span_comparison_head=head,
)
elif task.TASK_TYPE == TaskTypes.TAGGING:
taskmodel = TokenClassificationModel(
encoder=jiant_transformers_model, token_classification_head=head,
)
elif task.TASK_TYPE == TaskTypes.SQUAD_STYLE_QA:
taskmodel = QAModel(encoder=jiant_transformers_model, qa_head=head)
elif task.TASK_TYPE == TaskTypes.MASKED_LANGUAGE_MODELING:
taskmodel = MLMModel(encoder=jiant_transformers_model, mlm_head=head)
elif task.TASK_TYPE == TaskTypes.EMBEDDING:
taskmodel = EmbeddingModel(
encoder=jiant_transformers_model, pooler_head=head, layer=taskmodel_kwargs["layer"],
)
else:
raise KeyError(task.TASK_TYPE)

# Internal registry for available task models
registry = {}

@classmethod
def register(cls, task_type: TaskTypes) -> Callable:
def inner_wrapper(wrapped_class: Taskmodel) -> Callable:
assert task_type not in cls.registry
cls.registry[task_type] = wrapped_class
return wrapped_class

return inner_wrapper

@classmethod
def create_taskmodel(cls, task, jiant_transformers_model, head, taskmodel_kwargs=None):
taskmodel_class = cls.registry[task.TASK_TYPE]
taskmodel = taskmodel_class(task, jiant_transformers_model, head, taskmodel_kwargs)
return taskmodel
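
For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the decorator-based registry introduced above (toy names, not jiant's actual classes):

from enum import Enum
from typing import Callable, Dict, Type


class ToyTaskTypes(Enum):
    CLASSIFICATION = "classification"


class ToyFactory:
    # Maps a task type to the class that implements it
    registry: Dict[ToyTaskTypes, Type] = {}

    @classmethod
    def register(cls, task_type: ToyTaskTypes) -> Callable:
        def inner_wrapper(wrapped_class: Type) -> Type:
            assert task_type not in cls.registry  # fail loudly on double registration
            cls.registry[task_type] = wrapped_class
            return wrapped_class

        return inner_wrapper

    @classmethod
    def create(cls, task_type: ToyTaskTypes, *args, **kwargs):
        return cls.registry[task_type](*args, **kwargs)


@ToyFactory.register(ToyTaskTypes.CLASSIFICATION)
class ToyClassificationModel:
    def __init__(self, num_labels: int):
        self.num_labels = num_labels


model = ToyFactory.create(ToyTaskTypes.CLASSIFICATION, num_labels=3)
assert isinstance(model, ToyClassificationModel)

Registering through a decorator keeps each taskmodel's mapping next to its definition, so adding a new task type no longer requires touching a central if/elif chain.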


class Taskmodel(nn.Module, metaclass=abc.ABCMeta):
def __init__(self, encoder):
def __init__(self, task, encoder, head):
super().__init__()
self.task = task
self.encoder = encoder
self.head = head

def forward(self, batch, task, tokenizer, compute_loss: bool = False):
def forward(self, batch, tokenizer, compute_loss: bool = False):
raise NotImplementedError


@TaskModelFactory.register(TaskTypes.CLASSIFICATION)
class ClassificationModel(Taskmodel):
def __init__(self, encoder, classification_head: heads.ClassificationHead):
super().__init__(encoder=encoder)
self.classification_head = classification_head
def __init__(self, task, encoder, head: heads.ClassificationHead, taskmodel_kwargs=None):
super().__init__(task=task, encoder=encoder, head=head)

def forward(self, batch, task, tokenizer, compute_loss: bool = False):
def forward(self, batch, tokenizer, compute_loss: bool = False):
encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
logits = self.classification_head(pooled=encoder_output.pooled)
logits = self.head(pooled=encoder_output.pooled)
if compute_loss:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
logits.view(-1, self.classification_head.num_labels), batch.label_id.view(-1),
)
loss = loss_fct(logits.view(-1, self.head.num_labels), batch.label_id.view(-1),)
return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
else:
return LogitsOutput(logits=logits, other=encoder_output.other)


@TaskModelFactory.register(TaskTypes.REGRESSION)
class RegressionModel(Taskmodel):
def __init__(self, encoder, regression_head: heads.RegressionHead):
super().__init__(encoder=encoder)
self.regression_head = regression_head
def __init__(self, task, encoder, head: heads.RegressionHead, taskmodel_kwargs=None):
super().__init__(task=task, encoder=encoder, head=head)

def forward(self, batch, task, tokenizer, compute_loss: bool = False):
def forward(self, batch, tokenizer, compute_loss: bool = False):
encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
# TODO: Abuse of notation - these aren't really logits (issue #1187)
logits = self.regression_head(pooled=encoder_output.pooled)
logits = self.head(pooled=encoder_output.pooled)
if compute_loss:
loss_fct = nn.MSELoss()
loss = loss_fct(logits.view(-1), batch.label.view(-1))
@@ -102,13 +82,13 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False):
return LogitsOutput(logits=logits, other=encoder_output.other)


@TaskModelFactory.register(TaskTypes.MULTIPLE_CHOICE)
class MultipleChoiceModel(Taskmodel):
def __init__(self, encoder, num_choices: int, choice_scoring_head: heads.RegressionHead):
super().__init__(encoder=encoder)
self.num_choices = num_choices
self.choice_scoring_head = choice_scoring_head
def __init__(self, task, encoder, head: heads.RegressionHead, taskmodel_kwargs=None):
super().__init__(task=task, encoder=encoder, head=head)
self.num_choices = task.NUM_CHOICES

def forward(self, batch, task, tokenizer, compute_loss: bool = False):
def forward(self, batch, tokenizer, compute_loss: bool = False):
input_ids = batch.input_ids
segment_ids = batch.segment_ids
input_mask = batch.input_mask
@@ -122,7 +102,7 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False):
segment_ids=segment_ids[:, i],
input_mask=input_mask[:, i],
)
choice_score = self.choice_scoring_head(pooled=encoder_output.pooled)
choice_score = self.head(pooled=encoder_output.pooled)
choice_score_list.append(choice_score)
encoder_output_other_ls.append(encoder_output.other)

@@ -149,36 +129,34 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False):
return LogitsOutput(logits=logits, other=reshaped_outputs)


@TaskModelFactory.register(TaskTypes.SPAN_COMPARISON_CLASSIFICATION)
class SpanComparisonModel(Taskmodel):
def __init__(self, encoder, span_comparison_head: heads.SpanComparisonHead):
super().__init__(encoder=encoder)
self.span_comparison_head = span_comparison_head
def __init__(self, task, encoder, head: heads.SpanComparisonHead, taskmodel_kwargs=None):
super().__init__(task=task, encoder=encoder, head=head)

def forward(self, batch, task, tokenizer, compute_loss: bool = False):
def forward(self, batch, tokenizer, compute_loss: bool = False):
encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
logits = self.span_comparison_head(unpooled=encoder_output.unpooled, spans=batch.spans)
logits = self.head(unpooled=encoder_output.unpooled, spans=batch.spans)
if compute_loss:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
logits.view(-1, self.span_comparison_head.num_labels), batch.label_id.view(-1),
)
loss = loss_fct(logits.view(-1, self.head.num_labels), batch.label_id.view(-1),)
return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
else:
return LogitsOutput(logits=logits, other=encoder_output.other)


@TaskModelFactory.register(TaskTypes.SPAN_PREDICTION)
class SpanPredictionModel(Taskmodel):
def __init__(self, encoder, span_prediction_head: heads.TokenClassificationHead):
super().__init__(encoder=encoder)
def __init__(self, task, encoder, head: heads.TokenClassificationHead, taskmodel_kwargs=None):
super().__init__(task=task, encoder=encoder, head=head)
self.offset_margin = 1000
# 1000 is a big enough number that exp(-1000) will be strict 0 in float32.
# So that if we add 1000 to the valid dimensions in the input of softmax,
# we can guarantee the output distribution will only be non-zero at those dimensions.
self.span_prediction_head = span_prediction_head

def forward(self, batch, task, tokenizer, compute_loss: bool = False):
def forward(self, batch, tokenizer, compute_loss: bool = False):
encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
logits = self.span_prediction_head(unpooled=encoder_output.unpooled)
logits = self.head(unpooled=encoder_output.unpooled)
# Ensure logits in valid range is at least self.offset_margin higher than others
logits_offset = logits.max() - logits.min() + self.offset_margin
logits = logits + logits_offset * batch.selection_token_mask.unsqueeze(dim=2)
@@ -192,53 +170,51 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False):
return LogitsOutput(logits=logits, other=encoder_output.other)


@TaskModelFactory.register(TaskTypes.MULTI_LABEL_SPAN_CLASSIFICATION)
class MultiLabelSpanComparisonModel(Taskmodel):
def __init__(self, encoder, span_comparison_head: heads.SpanComparisonHead):
super().__init__(encoder=encoder)
self.span_comparison_head = span_comparison_head
def __init__(self, task, encoder, head: heads.SpanComparisonHead, taskmodel_kwargs=None):
super().__init__(task=task, encoder=encoder, head=head)

def forward(self, batch, task, tokenizer, compute_loss: bool = False):
def forward(self, batch, tokenizer, compute_loss: bool = False):
encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
logits = self.span_comparison_head(unpooled=encoder_output.unpooled, spans=batch.spans)
logits = self.head(unpooled=encoder_output.unpooled, spans=batch.spans)
if compute_loss:
loss_fct = nn.BCEWithLogitsLoss()
loss = loss_fct(
logits.view(-1, self.span_comparison_head.num_labels), batch.label_ids.float(),
)
loss = loss_fct(logits.view(-1, self.head.num_labels), batch.label_ids.float(),)
return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
else:
return LogitsOutput(logits=logits, other=encoder_output.other)


@TaskModelFactory.register(TaskTypes.TAGGING)
class TokenClassificationModel(Taskmodel):
"""From RobertaForTokenClassification"""

def __init__(self, encoder, token_classification_head: heads.TokenClassificationHead):
super().__init__(encoder=encoder)
self.token_classification_head = token_classification_head
def __init__(self, task, encoder, head: heads.TokenClassificationHead, taskmodel_kwargs=None):
super().__init__(task=task, encoder=encoder, head=head)

def forward(self, batch, task, tokenizer, compute_loss: bool = False):
def forward(self, batch, tokenizer, compute_loss: bool = False):
encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
logits = self.token_classification_head(unpooled=encoder_output.unpooled)
logits = self.head(unpooled=encoder_output.unpooled)
if compute_loss:
loss_fct = nn.CrossEntropyLoss()
active_loss = batch.label_mask.view(-1) == 1
active_logits = logits.view(-1, self.token_classification_head.num_labels)[active_loss]
active_logits = logits.view(-1, self.head.num_labels)[active_loss]
active_labels = batch.label_ids.view(-1)[active_loss]
loss = loss_fct(active_logits, active_labels)
return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
else:
return LogitsOutput(logits=logits, other=encoder_output.other)


@TaskModelFactory.register(TaskTypes.SQUAD_STYLE_QA)
class QAModel(Taskmodel):
def __init__(self, encoder, qa_head: heads.QAHead):
super().__init__(encoder=encoder)
self.qa_head = qa_head
def __init__(self, task, encoder, head: heads.QAHead, taskmodel_kwargs=None):
super().__init__(task=task, encoder=encoder, head=head)

def forward(self, batch, task, tokenizer, compute_loss: bool = False):
def forward(self, batch, tokenizer, compute_loss: bool = False):
encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
logits = self.qa_head(unpooled=encoder_output.unpooled)
logits = self.head(unpooled=encoder_output.unpooled)
if compute_loss:
loss = compute_qa_loss(
logits=logits,
@@ -250,48 +226,50 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False):
return LogitsOutput(logits=logits, other=encoder_output.other)


@TaskModelFactory.register(TaskTypes.MASKED_LANGUAGE_MODELING)
class MLMModel(Taskmodel):
def __init__(self, encoder, mlm_head: heads.BaseMLMHead):
super().__init__(encoder=encoder)
self.mlm_head = mlm_head
def __init__(self, task, encoder, head: heads.BaseMLMHead, taskmodel_kwargs=None):
super().__init__(task=task, encoder=encoder, head=head)

def forward(self, batch, task, tokenizer, compute_loss: bool = False):
def forward(self, batch, tokenizer, compute_loss: bool = False):
masked_batch = batch.get_masked(
mlm_probability=task.mlm_probability, tokenizer=tokenizer, do_mask=task.do_mask,
mlm_probability=self.task.mlm_probability,
tokenizer=tokenizer,
do_mask=self.task.do_mask,
)
encoder_output = get_output_from_encoder(
encoder=self.encoder,
input_ids=masked_batch.masked_input_ids,
segment_ids=masked_batch.segment_ids,
input_mask=masked_batch.input_mask,
)
logits = self.mlm_head(unpooled=encoder_output.unpooled)
logits = self.head(unpooled=encoder_output.unpooled)
if compute_loss:
loss = compute_mlm_loss(logits=logits, masked_lm_labels=masked_batch.masked_lm_labels)
return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
else:
return LogitsOutput(logits=logits, other=encoder_output.other)


@TaskModelFactory.register(TaskTypes.EMBEDDING)
class EmbeddingModel(Taskmodel):
def __init__(self, encoder, pooler_head: heads.AbstractPoolerHead, layer):
super().__init__(encoder=encoder)
self.pooler_head = pooler_head
self.layer = layer
def __init__(self, task, encoder, head: heads.AbstractPoolerHead, taskmodel_kwargs):
super().__init__(task=task, encoder=encoder, head=head)
self.layer = taskmodel_kwargs["layer"]

def forward(self, batch, task, tokenizer, compute_loss: bool = False):
def forward(self, batch, tokenizer, compute_loss: bool = False):
with transformer_utils.output_hidden_states_context(self.encoder):
encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
# A tuple of layers of hidden states
hidden_states = take_one(encoder_output.other)
layer_hidden_states = hidden_states[self.layer]

if isinstance(self.pooler_head, heads.MeanPoolerHead):
logits = self.pooler_head(unpooled=layer_hidden_states, input_mask=batch.input_mask)
elif isinstance(self.pooler_head, heads.FirstPoolerHead):
logits = self.pooler_head(layer_hidden_states)
if isinstance(self.head, heads.MeanPoolerHead):
logits = self.head(unpooled=layer_hidden_states, input_mask=batch.input_mask)
elif isinstance(self.head, heads.FirstPoolerHead):
logits = self.head(layer_hidden_states)
else:
raise TypeError(type(self.pooler_head))
raise TypeError(type(self.head))

# TODO: Abuse of notation - these aren't really logits (issue #1187)
if compute_loss:
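With the registry in place, every taskmodel is constructed through a single entry point keyed on task.TASK_TYPE. A hedged usage sketch of the new factory (my_task, encoder, and head are placeholders assumed to be prepared elsewhere; only EmbeddingModel reads taskmodel_kwargs):

taskmodel = TaskModelFactory.create_taskmodel(
    task=my_task,                      # my_task.TASK_TYPE selects the registered class
    jiant_transformers_model=encoder,  # shared transformer encoder
    head=head,                         # task-specific head, e.g. heads.ClassificationHead
    taskmodel_kwargs=None,             # e.g. {"layer": -1} for EmbeddingModel
)
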
24 changes: 12 additions & 12 deletions jiant/tasks/core.py
@@ -85,18 +85,18 @@ def data_row_collate_fn(batch):


class TaskTypes(Enum):
CLASSIFICATION = 1
REGRESSION = 2
SPAN_COMPARISON_CLASSIFICATION = 3
MULTIPLE_CHOICE = 4
SPAN_CHOICE_PROB_TASK = 5
SQUAD_STYLE_QA = 6
TAGGING = 7
MASKED_LANGUAGE_MODELING = 8
EMBEDDING = 9
MULTI_LABEL_SPAN_CLASSIFICATION = 10
SPAN_PREDICTION = 11
UNDEFINED = -1
CLASSIFICATION = "classification"
REGRESSION = "regression"
SPAN_COMPARISON_CLASSIFICATION = "span_comparison_classification"
MULTIPLE_CHOICE = "multiple_choice"
SPAN_CHOICE_PROB_TASK = "span_choice_prob_task"
SQUAD_STYLE_QA = "squad_style_qa"
TAGGING = "tagging"
MASKED_LANGUAGE_MODELING = "masked_language_modeling"
EMBEDDING = "embedding"
MULTI_LABEL_SPAN_CLASSIFICATION = "multi_label_span_classification"
SPAN_PREDICTION = "span_prediction"
UNDEFINED = "undefined"


class BatchTuple(NamedTuple):
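One practical upshot of moving TaskTypes from integer to string values: registry keys and serialized configs become self-describing, and values round-trip cleanly through Python's Enum lookup. A small sketch:

from jiant.tasks.core import TaskTypes

assert TaskTypes.CLASSIFICATION.value == "classification"
assert TaskTypes("classification") is TaskTypes.CLASSIFICATION  # Enum lookup by value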