Commit

Use jiant transformers model wrapper instead of if-else. Use taskmodel and head factory instead of if-else.
jeswan committed Jan 26, 2021
1 parent 723786a commit ee889fa
Showing 10 changed files with 259 additions and 419 deletions.
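
The message above summarizes the refactor: chained if-else dispatch over task types and model architectures is replaced by registry-backed factories. A minimal sketch of the shift (illustrative only; the "before" builder is hypothetical, not code from this repository):

from jiant.proj.main.modeling.heads import JiantHeadFactory, ClassificationHead, TokenClassificationHead
from jiant.tasks.core import TaskTypes

# Hypothetical "before": one branch per task type, extended by hand.
def build_head_if_else(task, **kwargs):
    if task.TASK_TYPE == TaskTypes.CLASSIFICATION:
        return ClassificationHead(task, **kwargs)
    elif task.TASK_TYPE == TaskTypes.TAGGING:
        return TokenClassificationHead(task, **kwargs)
    # ... one branch for every remaining task type
    else:
        raise KeyError(task.TASK_TYPE)

# "After": each head class registers itself via @JiantHeadFactory.register,
# so construction is a single dictionary lookup (see heads.py below).
def build_head_with_factory(task, **kwargs):
    return JiantHeadFactory()(task, **kwargs)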
88 changes: 75 additions & 13 deletions jiant/proj/main/modeling/heads.py
@@ -4,26 +4,55 @@
import torch.nn as nn

import transformers

from jiant.ext.allennlp import SelfAttentiveSpanExtractor
from jiant.shared.model_resolution import ModelArchitectures
from jiant.tasks.core import TaskTypes
from typing import Callable
from typing import List


"""
In HuggingFace/others, these heads differ slightly across different encoder models.
We're going to abstract away from that and just choose one implementation.
"""


class JiantHeadFactory:
    # Internal registry for available task models
    registry = {}

    @classmethod
    def register(cls, task_type_list: List[TaskTypes]) -> Callable:
        def inner_wrapper(wrapped_class: BaseHead) -> Callable:
            for task_type in task_type_list:
                assert task_type not in cls.registry
                cls.registry[task_type] = wrapped_class
            return wrapped_class

        return inner_wrapper

    def __call__(self, task, **kwargs):
        head_class = self.registry[task.TASK_TYPE]
        head = head_class(task, **kwargs)
        return head
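
A rough usage sketch for the head factory defined above (the dummy task, sizes, and labels are placeholders for illustration; real code would pass a jiant Task that exposes TASK_TYPE, LABELS, and num_labels):

import torch

from jiant.proj.main.modeling.heads import JiantHeadFactory
from jiant.tasks.core import TaskTypes


class _DummyClassificationTask:
    # Stand-in for a jiant Task; only the attributes the heads read are defined.
    TASK_TYPE = TaskTypes.CLASSIFICATION
    LABELS = ["entailment", "not_entailment"]
    num_labels = 2


head = JiantHeadFactory()(_DummyClassificationTask(), hidden_size=768, hidden_dropout_prob=0.1)
pooled = torch.randn(4, 768)  # placeholder pooled encoder output: [batch, hidden_size]
logits = head(pooled)  # expected shape: [batch, num_labels]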


class BaseHead(nn.Module, metaclass=abc.ABCMeta):
    pass
    @abc.abstractmethod
    def __init__(self):
        super().__init__()


@JiantHeadFactory.register([TaskTypes.CLASSIFICATION])
class ClassificationHead(BaseHead):
    def __init__(self, hidden_size, hidden_dropout_prob, num_labels):
    def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs):
        """From RobertaClassificationHead"""
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.out_proj = nn.Linear(hidden_size, num_labels)
        self.num_labels = num_labels
        self.out_proj = nn.Linear(hidden_size, task.num_labels)
        self.num_labels = len(task.LABELS)

    def forward(self, pooled):
        x = self.dropout(pooled)
@@ -34,8 +63,9 @@ def forward(self, pooled):
        return logits


@JiantHeadFactory.register([TaskTypes.REGRESSION, TaskTypes.MULTIPLE_CHOICE])
class RegressionHead(BaseHead):
    def __init__(self, hidden_size, hidden_dropout_prob):
    def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs):
        """From RobertaClassificationHead"""
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
@@ -51,12 +81,13 @@ def forward(self, pooled):
        return scores


@JiantHeadFactory.register([TaskTypes.SPAN_COMPARISON_CLASSIFICATION])
class SpanComparisonHead(BaseHead):
    def __init__(self, hidden_size, hidden_dropout_prob, num_spans, num_labels):
    def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs):
        """From RobertaForSpanComparisonClassification"""
        super().__init__()
        self.num_spans = num_spans
        self.num_labels = num_labels
        self.num_spans = task.num_spans
        self.num_labels = len(task.LABELS)
        self.hidden_size = hidden_size
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.span_attention_extractor = SelfAttentiveSpanExtractor(hidden_size)
@@ -70,22 +101,24 @@ def forward(self, unpooled, spans):
        return logits


@JiantHeadFactory.register([TaskTypes.TAGGING])
class TokenClassificationHead(BaseHead):
    def __init__(self, hidden_size, num_labels, hidden_dropout_prob):
    def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs):
        """From RobertaForTokenClassification"""
        super().__init__()
        self.num_labels = num_labels
        self.num_labels = len(task.LABELS)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.classifier = nn.Linear(hidden_size, self.num_labels)

    def forward(self, unpooled):
        unpooled = self.dropout(unpooled)
        logits = self.classifier(unpooled)
        return logits


@JiantHeadFactory.register([TaskTypes.SQUAD_STYLE_QA])
class QAHead(BaseHead):
    def __init__(self, hidden_size):
    def __init__(self, task, hidden_size, **kwargs):
        """From RobertaForQuestionAnswering"""
        super().__init__()
        self.qa_outputs = nn.Linear(hidden_size, 2)
@@ -98,10 +131,37 @@ def forward(self, unpooled):
        return logits


@JiantHeadFactory.register([TaskTypes.MASKED_LANGUAGE_MODELING])
class JiantMLMHeadFactory:
    # Internal registry for available task models
    registry = {}

    @classmethod
    def register(cls, model_arch_list: List[ModelArchitectures]) -> Callable:
        def inner_wrapper(wrapped_class: BaseMLMHead) -> Callable:
            for model_arch in model_arch_list:
                assert model_arch not in cls.registry
                cls.registry[model_arch] = wrapped_class
            return wrapped_class

        return inner_wrapper

    def __call__(
        self,
        task,
        **kwargs
        # task_type: TaskTypes, model_arch: ModelArchitectures, hidden_size, hidden_dropout_prob
    ):
        mlm_head_class = self.registry[task.TASK_TYPE]
        mlm_head = mlm_head_class(task, **kwargs)
        return mlm_head
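
Importing this module runs the @JiantMLMHeadFactory.register decorators on the MLM head classes below, so the registry ends up keyed by encoder architecture, roughly as follows (a sketch based on the decorated classes in this diff):

from jiant.proj.main.modeling import heads
from jiant.shared.model_resolution import ModelArchitectures

# Expected registry contents once heads.py has been imported (per the decorators below).
assert heads.JiantMLMHeadFactory.registry[ModelArchitectures.BERT] is heads.BertMLMHead
assert heads.JiantMLMHeadFactory.registry[ModelArchitectures.ROBERTA] is heads.RobertaMLMHead
assert heads.JiantMLMHeadFactory.registry[ModelArchitectures.XLM_ROBERTA] is heads.RobertaMLMHead
assert heads.JiantMLMHeadFactory.registry[ModelArchitectures.ALBERT] is heads.AlbertMLMHead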


class BaseMLMHead(BaseHead, metaclass=abc.ABCMeta):
    pass


@JiantMLMHeadFactory.register([ModelArchitectures.BERT])
class BertMLMHead(BaseMLMHead):
    """From BertOnlyMLMHead, BertLMPredictionHead, BertPredictionHeadTransform"""

@@ -126,6 +186,7 @@ def forward(self, unpooled):
        return logits


@JiantMLMHeadFactory.register([ModelArchitectures.ROBERTA, ModelArchitectures.XLM_ROBERTA])
class RobertaMLMHead(BaseMLMHead):
    """From RobertaLMHead"""

@@ -151,7 +212,8 @@ def forward(self, unpooled):
        return logits


class AlbertMLMHead(nn.Module):
@JiantMLMHeadFactory.register([ModelArchitectures.ALBERT])
class AlbertMLMHead(BaseMLMHead):
    """From AlbertMLMHead"""

    def __init__(self, hidden_size, embedding_size, vocab_size, hidden_act="gelu"):