forked from facebookresearch/pytext
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
MultiLabel-MultiClass Model for Joint Sequence Tagging (facebookresea…
…rch#1335) Summary: Pull Request resolved: facebookresearch#1335 We need to support multi-class as well as multi-label prediction for joint models in pytext. This diff implements a 1. Joint Multi Label Decoder 2. MultiLabelClassification Output Layer 3. Loss computation for multi-label-multi-class scenarios 4. Label weights per label and per class 5. Softmax options for output layers 6. Custom Metric Reporter, Metric Class and Output for flow Reviewed By: seayoung1112 Differential Revision: D20210880 fbshipit-source-id: 701ca0a32302f923f13efe012618bba693b2d4db
- Loading branch information
1 parent
80677f3
commit f74a7ba
Showing
10 changed files
with
372 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
|
||
from typing import Dict, List | ||
|
||
import torch | ||
import torch.nn as nn | ||
from pytext.utils.usage import log_class_usage | ||
|
||
from .decoder_base import DecoderBase | ||
|
||
|
||
class MultiLabelDecoder(DecoderBase):
    """
    Implements an 'n-tower' MLP: one tower (MLP head) per label set, all fed
    the same concatenated input representation.
    Used in USM/EA: the user satisfaction modeling, pTSR prediction and
    Error Attribution are all 3 label sets that need predicting.
    """

    class Config(DecoderBase.Config):
        # Intermediate hidden dimensions shared by every tower; an empty
        # list yields a single Linear layer per label set.
        hidden_dims: List[int] = []

    def __init__(
        self,
        config: Config,
        in_dim: int,
        output_dim: Dict[str, int],
        label_names: List[str],
    ) -> None:
        """
        Args:
            config: decoder configuration (hidden layer sizes).
            in_dim: dimensionality of the shared input representation.
            output_dim: mapping from label-set name to its number of classes.
            label_names: ordered label-set names; fixes the order in which
                ``forward`` emits logits tensors.
        """
        super().__init__(config)
        self.label_mlps = nn.ModuleDict({})
        # Store the ordered list to preserve the ordering of the labels
        # when generating the output layer
        self.label_names = label_names
        aggregate_out_dim = 0
        # Use the dict value directly instead of re-indexing output_dim[label_].
        for label_, dim_ in output_dim.items():
            self.label_mlps[label_] = MultiLabelDecoder.get_mlp(
                in_dim, dim_, config.hidden_dims
            )
            aggregate_out_dim += dim_
        # (1, total class count across all label sets) — consumed by the
        # multi-label output layer.
        self.out_dim = (1, aggregate_out_dim)
        log_class_usage(__class__)

    @staticmethod
    def get_mlp(in_dim: int, out_dim: int, hidden_dims: List[int]) -> nn.Sequential:
        """Build an MLP ``in_dim -> hidden_dims... -> out_dim`` with ReLU
        between hidden layers and no activation after the final Linear."""
        layers = []
        current_dim = in_dim
        for dim in hidden_dims or []:
            layers.append(nn.Linear(current_dim, dim))
            layers.append(nn.ReLU())
            current_dim = dim
        layers.append(nn.Linear(current_dim, out_dim))
        return nn.Sequential(*layers)

    def forward(self, *input: torch.Tensor):
        """Concatenate the inputs along dim 1 and run each label tower,
        returning one logits tensor per label set in ``label_names`` order."""
        # Hoist the concatenation out of the per-label loop: the original
        # recomputed torch.cat(input, 1) once per label set.
        encoded = torch.cat(input, 1)
        return tuple(self.label_mlps[name](encoded) for name in self.label_names)

    def get_decoder(self) -> nn.ModuleDict:
        # Fixed annotation: this returns the ModuleDict of label towers,
        # not a List[nn.Module] as previously annotated.
        return self.label_mlps
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.