Skip to content

Commit

Permalink
MultiLabel-MultiClass Model for Joint Sequence Tagging (facebookresea…
Browse files Browse the repository at this point in the history
…rch#1335)

Summary:
Pull Request resolved: facebookresearch#1335

We need to support multi-class as well as multi-label prediction for joint models in PyText. This diff implements:

1. Joint Multi Label Decoder
2. MultiLabelClassification Output Layer
3. Loss computation for multi-label-multi-class scenarios
4. Label weights per label and per class
5. Softmax options for output layers
6. Custom Metric Reporter, Metric Class and Output for flow

Reviewed By: seayoung1112

Differential Revision: D20210880

fbshipit-source-id: 701ca0a32302f923f13efe012618bba693b2d4db
  • Loading branch information
shivanipods authored and facebook-github-bot committed Apr 27, 2020
1 parent 80677f3 commit f74a7ba
Show file tree
Hide file tree
Showing 10 changed files with 372 additions and 7 deletions.
4 changes: 2 additions & 2 deletions pytext/data/tensorizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,8 +853,8 @@ def numberize(self, row):
label_idx_list.append(self.pad_idx)
else:
raise Exception(
"Found none or empty value in the list,"
+ " while pad_missing is disabled"
"Found none or empty value in the list, \
while pad_missing is disabled"
)
else:
label_idx_list.append(self.vocab.lookup_all(label))
Expand Down
3 changes: 2 additions & 1 deletion pytext/metric_reporters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .regression_metric_reporter import RegressionMetricReporter
from .squad_metric_reporter import SquadMetricReporter
from .word_tagging_metric_reporter import (
MultiLabelSequenceTaggingMetricReporter,
NERMetricReporter,
SequenceTaggingMetricReporter,
WordTaggingMetricReporter,
Expand All @@ -34,7 +35,7 @@
"CompositionalMetricReporter",
"PairwiseRankingMetricReporter",
"SequenceTaggingMetricReporter",
"PureLossMetricReporter",
"MultiLabelSequenceTaggingMetricReporter",
"NERMetricReporter",
"DenseRetrievalMetricReporter",
]
68 changes: 68 additions & 0 deletions pytext/metric_reporters/word_tagging_metric_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
LabelPrediction,
PRF1Metrics,
compute_classification_metrics,
compute_multi_label_multi_class_soft_metrics,
)
from pytext.metrics.intent_slot_metrics import (
Node,
Expand Down Expand Up @@ -92,6 +93,73 @@ def get_model_select_metric(self, metrics):
return metrics.micro_scores.f1


class MultiLabelSequenceTaggingMetricReporter(MetricReporter):
    """Metric reporter for sequence tagging with several label sets at once.

    ``label_names``, ``pad_idx`` and ``label_vocabs`` are parallel collections
    with one entry per label set. ``calculate_metric`` indexes
    ``self.all_scores``, ``self.all_preds`` and ``self.all_targets`` by the
    position of the label set (presumably these are accumulated by the
    ``MetricReporter`` base class -- confirm against the base implementation).
    """

    def __init__(self, label_names, pad_idx, channels, label_vocabs=None):
        """Store the per-label-set configuration.

        Args:
            label_names: names of the label sets, in output order.
            pad_idx: padding index of each label set, parallel to
                ``label_names``.
            channels: reporting channels forwarded to the base class.
            label_vocabs: vocabulary of each label set, parallel to
                ``label_names``.
        """
        super().__init__(channels)
        self.label_names = label_names
        self.pad_idx = pad_idx
        self.label_vocabs = label_vocabs

    @classmethod
    def from_config(cls, config, tensorizers):
        """Build a reporter with one label set per entry of ``tensorizers``.

        Uses ``cls`` so that subclasses get an instance of their own type.
        """
        label_tensorizers = list(tensorizers.values())
        return cls(
            channels=[ConsoleChannel(), FileChannel((Stage.TEST,), config.output_path)],
            # Snapshot the keys: a live ``dict.keys()`` view would silently
            # change if ``tensorizers`` were mutated after construction.
            label_names=list(tensorizers),
            pad_idx=[t.pad_idx for t in label_tensorizers],
            label_vocabs=[t.vocab._vocab for t in label_tensorizers],
        )

    def calculate_metric(self):
        """Compute soft metrics for every label set, ignoring padded targets.

        Returns:
            An empty dict when no scores were collected, otherwise the result
            of ``compute_multi_label_multi_class_soft_metrics``.
        """
        if len(self.all_scores) == 0:
            return {}
        list_score_pred_expect = []
        for label_idx in range(len(self.label_names)):
            per_batch = zip(
                self.all_scores[label_idx],
                self.all_preds[label_idx],
                self.all_targets[label_idx],
            )
            # Flatten all batches into one list of predictions for this label
            # set, dropping positions whose target is that set's pad index.
            list_score_pred_expect.append(
                [
                    LabelPrediction(s, p, e)
                    for scores, pred, expect in per_batch
                    for s, p, e in zip(scores, pred, expect)
                    if e != self.pad_idx[label_idx]
                ]
            )
        # NOTE(review): the fourth positional parameter of
        # compute_multi_label_multi_class_soft_metrics is
        # ``recall_at_precision_thresholds``, so the loss lands in that slot.
        # It is currently harmless only because the callee never reads the
        # thresholds; confirm whether a dedicated ``loss`` parameter is
        # intended.
        metrics = compute_multi_label_multi_class_soft_metrics(
            list_score_pred_expect,
            self.label_names,
            self.label_vocabs,
            self.calculate_loss(),
        )
        return metrics

    def batch_context(self, raw_batch, batch):
        """Return an empty context; this reporter keeps no per-batch context."""
        return {}

    @staticmethod
    def get_model_select_metric(metrics):
        """Return the scalar used for model selection.

        For a dict of per-label-set metrics this is the mean, over label sets,
        of each set's average precision. Within a set, only strictly positive
        ``average_precision`` values are summed, but the sum is divided by the
        total number of labels in the set. An empty dict (which
        ``calculate_metric`` returns when nothing was scored) yields ``0.0``
        instead of raising ``ZeroDivisionError``. Any non-dict input is
        expected to expose ``accuracy``.
        """
        if not isinstance(metrics, dict):
            return metrics.accuracy
        if not metrics:
            return 0.0
        total_precision = 0.0
        for metric in metrics.values():
            # An empty or missing per-set result contributes zero but still
            # counts in the outer average.
            if metric:
                total_precision += sum(
                    v.average_precision
                    for v in metric.values()
                    if v.average_precision > 0
                ) / len(metric)
        return total_precision / len(metrics)


class SequenceTaggingMetricReporter(MetricReporter):
def __init__(self, label_names, pad_idx, channels):
super().__init__(channels)
Expand Down
31 changes: 31 additions & 0 deletions pytext/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,37 @@ def compute_multi_label_soft_metrics(
return soft_metrics


def compute_multi_label_multi_class_soft_metrics(
    predictions: Sequence[Sequence[LabelListPrediction]],
    label_names: Sequence[str],
    label_vocabs: Sequence[Sequence[str]],
    recall_at_precision_thresholds: Sequence[float] = RECALL_AT_PRECISION_THRESHOLDS,
    precision_at_recall_thresholds: Sequence[float] = PRECISION_AT_RECALL_THRESHOLDS,
) -> Dict[str, Dict[str, SoftClassificationMetrics]]:
    """
    Computes multi-label soft classification metrics with multi-class accommodation

    Args:
        predictions: one sequence of predictions per label set, including the
            confidence score for each label; indexed in the same order as
            `label_vocabs`.
        label_names: name of each label set, in the same order as
            `label_vocabs`. Any iterable is accepted; it is materialized into
            a list before indexing.
        label_vocabs: vocabulary (class names) of each label set.
        recall_at_precision_thresholds: precision thresholds at which to calculate
            recall. Currently accepted but not forwarded to
            `compute_soft_metrics`.
        precision_at_recall_thresholds: recall thresholds at which to calculate
            precision. Currently accepted but not forwarded to
            `compute_soft_metrics`.

    Returns:
        Dict from label-set name to whatever `compute_soft_metrics` returns for
        that set (presumably a dict from class name to its soft metrics).

    Raises:
        IndexError: if `label_names` or `predictions` has fewer entries than
            `label_vocabs`.
    """
    # Materialize once instead of rebuilding the list on every iteration;
    # this also keeps non-indexable inputs such as dict key views working.
    names = list(label_names)
    return {
        names[label_idx]: compute_soft_metrics(predictions[label_idx], label_vocab)
        for label_idx, label_vocab in enumerate(label_vocabs)
    }


def compute_matthews_correlation_coefficients(
TP: int, FP: int, FN: int, TN: int
) -> float:
Expand Down
64 changes: 64 additions & 0 deletions pytext/models/decoders/multilabel_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Dict, List

import torch
import torch.nn as nn
from pytext.utils.usage import log_class_usage

from .decoder_base import DecoderBase


class MultiLabelDecoder(DecoderBase):
    """
    Implements a 'n-tower' MLP: one for each of the multi labels
    Used in USM/EA: the user satisfaction modeling, pTSR prediction and
    Error Attribution are all 3 label sets that need predicting.

    Every tower receives the same input (all inputs concatenated along
    dimension 1) and produces the logits of one label set.
    """

    class Config(DecoderBase.Config):
        # Intermediate hidden dimensions
        hidden_dims: List[int] = []

    def __init__(
        self,
        config: Config,
        in_dim: int,
        output_dim: Dict[str, int],
        label_names: List[str],
    ) -> None:
        """Build one MLP per entry of ``output_dim``.

        Args:
            config: decoder configuration; ``hidden_dims`` is shared by all
                towers.
            in_dim: size of the concatenated input along dimension 1.
            output_dim: number of output units for each label set, keyed by
                label-set name.
            label_names: label-set names in the order ``forward`` must emit
                their logits. Each name must be a key of ``output_dim``,
                otherwise ``forward`` fails with a missing-key error.
        """
        super().__init__(config)
        self.label_mlps = nn.ModuleDict()
        # Store the ordered list to preserve the ordering of the labels
        # when generating the output layer
        self.label_names = label_names
        aggregate_out_dim = 0
        for label, label_out_dim in output_dim.items():
            self.label_mlps[label] = MultiLabelDecoder.get_mlp(
                in_dim, label_out_dim, config.hidden_dims
            )
            aggregate_out_dim += label_out_dim
        self.out_dim = (1, aggregate_out_dim)
        log_class_usage(__class__)

    @staticmethod
    def get_mlp(in_dim: int, out_dim: int, hidden_dims: List[int]):
        """Return a ``Linear``/``ReLU`` stack ending in a plain ``Linear``.

        Args:
            in_dim: input feature size.
            out_dim: output feature size of the final linear layer.
            hidden_dims: sizes of the intermediate layers; ``None`` or an
                empty list yields a single linear layer.
        """
        layers = []
        current_dim = in_dim
        for dim in hidden_dims or []:
            layers.append(nn.Linear(current_dim, dim))
            layers.append(nn.ReLU())
            current_dim = dim
        layers.append(nn.Linear(current_dim, out_dim))
        return nn.Sequential(*layers)

    def forward(self, *input: torch.Tensor):
        """Return a tuple of logits, one tensor per name in ``label_names``.

        The inputs are concatenated along dimension 1 once and the result is
        shared by every tower, rather than re-concatenating per label set.
        """
        features = torch.cat(input, 1)
        return tuple(self.label_mlps[name](features) for name in self.label_names)

    def get_decoder(self) -> List[nn.Module]:
        """Return the container holding the per-label-set MLPs."""
        # NOTE(review): the annotation promises a list, but this returns the
        # nn.ModuleDict itself, and iterating a ModuleDict yields its keys
        # rather than the modules. Confirm what callers expect before changing
        # either the annotation or the return value.
        return self.label_mlps
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def forward(self, logits: torch.Tensor):
class BinaryClassificationOutputLayer(ClassificationOutputLayer):
def get_pred(self, logit, *args, **kwargs):
"""See `OutputLayerBase.get_pred()`."""
preds = torch.max(logit, 1)[1]
preds = torch.max(logit, -1)[1]
scores = F.logsigmoid(logit)
return preds, scores

Expand All @@ -153,7 +153,7 @@ def export_to_caffe2(
class MulticlassOutputLayer(ClassificationOutputLayer):
def get_pred(self, logit, *args, **kwargs):
"""See `OutputLayerBase.get_pred()`."""
preds = torch.max(logit, 1)[1]
preds = torch.max(logit, -1)[1]
scores = F.log_softmax(logit, 1)
return preds, scores

Expand Down
Loading

0 comments on commit f74a7ba

Please sign in to comment.