
Dimensions in FBetaMultiLabelMeasure (#5501)
* Makes multilabel FBeta work in multiple dimensions

* Changelog

* Invalidate package cache for CI
dirkgr authored Dec 10, 2021
1 parent d77ba3d commit e0ee7f4
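
For context, here is a minimal usage sketch of what this change enables (a sketch only; the shapes, numbers, and variable names below are illustrative and not taken from the commit): `FBetaMultiLabelMeasure` now accepts predictions and gold labels with extra leading dimensions, e.g. `(batch, num_spans, num_classes)`, rather than only `(batch, num_classes)`.

```python
import torch
from allennlp.training.metrics import FBetaMultiLabelMeasure

# Illustrative shapes: batch of 2, 3 spans per instance, 4 labels (hypothetical numbers).
predictions = torch.rand(2, 3, 4)                      # per-label probabilities
gold_labels = torch.randint(0, 2, (2, 3, 4)).float()   # multi-hot targets

metric = FBetaMultiLabelMeasure(threshold=0.5)
metric(predictions, gold_labels)
print(metric.get_metric()["fscore"])  # one F-score per label class
```

Before this change, gold labels with more than two dimensions would not line up with the internally computed class-index tensor (see the `class_indices` diff below).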
Showing 4 changed files with 26 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -22,7 +22,7 @@ env:
TORCH_CPU_INSTALL: conda install pytorch torchvision torchaudio cpuonly -c pytorch
TORCH_GPU_INSTALL: conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
# Change this to invalidate existing cache.
-CACHE_PREFIX: v4
+CACHE_PREFIX: v5
# Disable tokenizers parallelism because this doesn't help, and can cause issues in distributed tests.
TOKENIZERS_PARALLELISM: 'false'
# Disable multithreading with OMP because this can lead to dead-locks in distributed tests.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed the docstring information for the `FBetaMultiLabelMeasure` metric.
- Various fixes for Python 3.9
- Fixed the name that the `push-to-hf` command uses to store weights.
- `FBetaMultiLabelMeasure` now works with multiple dimensions
- Support for inferior operating systems when making hardlinks

### Removed
6 changes: 2 additions & 4 deletions allennlp/training/metrics/fbeta_multi_label_measure.py
@@ -115,10 +115,8 @@ def __call__(
pred_mask = (predictions.sum(dim=-1) != 0).unsqueeze(-1)
threshold_predictions = (predictions >= self._threshold).float()

-class_indices = (
-    torch.arange(num_classes, device=predictions.device)
-    .unsqueeze(0)
-    .repeat(gold_labels.size(0), 1)
+class_indices = torch.arange(num_classes, device=predictions.device).repeat(
+    gold_labels.shape[:-1] + (1,)
 )
true_positives = (gold_labels * threshold_predictions).bool() & mask & pred_mask
true_positives_bins = class_indices[true_positives]
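
For intuition about the diff above, here is a standalone sketch (the shapes and variable names are illustrative, not the library's internal code): the old `repeat` assumed 2-D gold labels, while the new call repeats the class indices over every leading dimension of `gold_labels`.

```python
import torch

num_classes = 4
gold_labels = torch.zeros(2, 3, num_classes)  # e.g. (batch, num_spans, num_classes)

# Old construction: fixed to two dimensions, (batch, num_classes).
old = torch.arange(num_classes).unsqueeze(0).repeat(gold_labels.size(0), 1)
print(old.shape)  # torch.Size([2, 4]) -- rank mismatch with 3-D gold_labels

# New construction: repeats over all leading dimensions of gold_labels.
new = torch.arange(num_classes).repeat(gold_labels.shape[:-1] + (1,))
print(new.shape)  # torch.Size([2, 3, 4]) -- matches gold_labels
```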
22 changes: 22 additions & 0 deletions tests/training/metrics/fbeta_multi_label_measure_test.py
@@ -112,6 +112,28 @@ def test_fbeta_multilabel_metric(self, device: str):
assert isinstance(recalls, List)
assert isinstance(fscores, List)

@multi_device
def test_fbeta_multilabel_with_extra_dimensions(self, device: str):
self.predictions = self.predictions.to(device)
self.targets = self.targets.to(device)

fbeta = FBetaMultiLabelMeasure()
fbeta(self.predictions.unsqueeze(1), self.targets.unsqueeze(1))
metric = fbeta.get_metric()
precisions = metric["precision"]
recalls = metric["recall"]
fscores = metric["fscore"]

# check value
assert_allclose(precisions, self.desired_precisions)
assert_allclose(recalls, self.desired_recalls)
assert_allclose(fscores, self.desired_fscores)

# check type
assert isinstance(precisions, List)
assert isinstance(recalls, List)
assert isinstance(fscores, List)

@multi_device
def test_fbeta_multilabel_with_mask(self, device: str):
self.predictions = self.predictions.to(device)
