This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Make most metrics work on GPU #3851

Merged · 8 commits · Feb 27, 2020
Changes from 5 commits
2 changes: 1 addition & 1 deletion allennlp/common/testing/__init__.py
@@ -1,5 +1,5 @@
"""
Utilities and helpers for writing tests.
"""
from allennlp.common.testing.test_case import AllenNlpTestCase
from allennlp.common.testing.test_case import AllenNlpTestCase, multi_device
from allennlp.common.testing.model_test_case import ModelTestCase
34 changes: 34 additions & 0 deletions allennlp/common/testing/test_case.py
@@ -3,8 +3,11 @@
import pathlib
import shutil
import tempfile
from typing import Any, Iterable
from unittest import TestCase

import torch

from allennlp.common.checks import log_pytorch_version_info

TEST_DIR = tempfile.mkdtemp(prefix="allennlp_tests")
@@ -40,3 +43,34 @@ def setUp(self):

def tearDown(self):
shutil.rmtree(self.TEST_DIR)


def parametrize(arg_names: Iterable[str], arg_values: Iterable[Iterable[Any]]):
Contributor:
Very cute, this is really nice!

"""
Decorator to create parameterized tests.

# Parameters

arg_names : `Iterable[str]`, required.
Argument names to pass to the test function.
arg_values : `Iterable[Iterable[Any]]`, required.
Iterable of values to pass to each of the args.
A function call will be made for each inner iterable.
"""

def decorator(func):
def wrapper(*args, **kwargs):
for arg_value in arg_values:
kwargs_extra = {name: value for name, value in zip(arg_names, arg_value)}
func(*args, **kwargs, **kwargs_extra)

return wrapper

return decorator


_available_devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
multi_device = parametrize(("device",), [(device,) for device in _available_devices])
"""
Decorator that runs a test once for each available PyTorch device, passing the device name to the test as a `device` argument of type `str`.
"""
19 changes: 19 additions & 0 deletions allennlp/tests/common/testing.py
@@ -0,0 +1,19 @@
import torch

from allennlp.common.testing import AllenNlpTestCase, multi_device


class TestFromParams(AllenNlpTestCase):
Contributor:
Class name here needs updating.

(I came to see what Mark thought looked cute, noticed a copy-paste bug.)

Contributor Author:
Thanks! Good catch.

Contributor Author:
I put TestTesting after the module name, hope it's fine.

def test_multi_device(self):
actual_devices = set()

@multi_device
def dummy_func(_self, device: str):
# Accept a `self`-like first argument, just as test methods defined in a class do.
nonlocal actual_devices
actual_devices.add(device)

dummy_func(self)

expected_devices = {"cpu", "cuda"} if torch.cuda.is_available() else {"cpu"}
self.assertSetEqual(expected_devices, actual_devices)
29 changes: 24 additions & 5 deletions allennlp/tests/training/metrics/attachment_scores_test.py
@@ -1,6 +1,6 @@
import torch

from allennlp.common.testing import AllenNlpTestCase
from allennlp.common.testing import AllenNlpTestCase, multi_device
from allennlp.training.metrics import AttachmentScores


@@ -19,15 +19,28 @@ def setUp(self):

self.mask = torch.Tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])

def test_perfect_scores(self):
def _send_tensors_to_device(self, device: str):
self.predictions = self.predictions.to(device)
self.gold_indices = self.gold_indices.to(device)
self.label_predictions = self.label_predictions.to(device)
self.gold_labels = self.gold_labels.to(device)
self.mask = self.mask.to(device)

@multi_device
def test_perfect_scores(self, device: str):
self._send_tensors_to_device(device)

self.scorer(
self.predictions, self.label_predictions, self.gold_indices, self.gold_labels, self.mask
)

for value in self.scorer.get_metric().values():
assert value == 1.0

def test_unlabeled_accuracy_ignores_incorrect_labels(self):
@multi_device
def test_unlabeled_accuracy_ignores_incorrect_labels(self, device: str):
self._send_tensors_to_device(device)

label_predictions = self.label_predictions
# Change some values so that 4 of our label predictions are wrong.
label_predictions[0, 3:] = 3
@@ -47,7 +60,10 @@ def test_unlabeled_accuracy_ignores_incorrect_labels(self):
# Neither should have labeled exact match.
assert metrics["LEM"] == 0.0

def test_labeled_accuracy_is_affected_by_incorrect_heads(self):
@multi_device
def test_labeled_accuracy_is_affected_by_incorrect_heads(self, device: str):
self._send_tensors_to_device(device)

predictions = self.predictions
# Change some values so that 4 of our predictions are wrong.
predictions[0, 3:] = 3
@@ -71,7 +87,10 @@ def test_labeled_accuracy_is_affected_by_incorrect_heads(self):
assert metrics["LEM"] == 0.0
assert metrics["UEM"] == 0.0

def test_attachment_scores_can_ignore_labels(self):
@multi_device
def test_attachment_scores_can_ignore_labels(self, device: str):
self._send_tensors_to_device(device)

scorer = AttachmentScores(ignore_classes=[1])

label_predictions = self.label_predictions
57 changes: 31 additions & 26 deletions allennlp/tests/training/metrics/auc_test.py
@@ -1,21 +1,22 @@
import pytest
import torch
from sklearn import metrics
from numpy.testing import assert_almost_equal
import pytest
from torch.testing import assert_allclose

from allennlp.common.testing import AllenNlpTestCase
from allennlp.training.metrics import Auc
from allennlp.common.checks import ConfigurationError
from allennlp.common.testing import AllenNlpTestCase, multi_device
from allennlp.training.metrics import Auc


class AucTest(AllenNlpTestCase):
def test_auc_computation(self):
@multi_device
def test_auc_computation(self, device: str):
auc = Auc()
all_predictions = []
all_labels = []
for _ in range(5):
predictions = torch.randn(8).float()
labels = torch.randint(0, 2, (8,)).long()
predictions = torch.randn(8, device=device)
labels = torch.randint(0, 2, (8,), dtype=torch.long, device=device)

auc(predictions, labels)

@@ -25,62 +26,66 @@ def test_auc_computation(self):
computed_auc_value = auc.get_metric(reset=True)

false_positive_rates, true_positive_rates, _ = metrics.roc_curve(
torch.cat(all_labels, dim=0).numpy(), torch.cat(all_predictions, dim=0).numpy()
torch.cat(all_labels, dim=0).cpu().numpy(),
torch.cat(all_predictions, dim=0).cpu().numpy(),
)
real_auc_value = metrics.auc(false_positive_rates, true_positive_rates)
assert_almost_equal(real_auc_value, computed_auc_value)
assert_allclose(real_auc_value, computed_auc_value)

# One more computation to assure reset works.
predictions = torch.randn(8).float()
labels = torch.randint(0, 2, (8,)).long()
predictions = torch.randn(8, device=device)
labels = torch.randint(0, 2, (8,), dtype=torch.long, device=device)

auc(predictions, labels)
computed_auc_value = auc.get_metric(reset=True)

false_positive_rates, true_positive_rates, _ = metrics.roc_curve(
labels.numpy(), predictions.numpy()
labels.cpu().numpy(), predictions.cpu().numpy()
)
real_auc_value = metrics.auc(false_positive_rates, true_positive_rates)
assert_almost_equal(real_auc_value, computed_auc_value)
assert_allclose(real_auc_value, computed_auc_value)

def test_auc_gold_labels_behaviour(self):
@multi_device
def test_auc_gold_labels_behaviour(self, device: str):
# Check that it works with different pos_label
auc = Auc(positive_label=4)

predictions = torch.randn(8).float()
labels = torch.randint(3, 5, (8,)).long()
predictions = torch.randn(8, device=device)
labels = torch.randint(3, 5, (8,), dtype=torch.long, device=device)
# We make sure that the positive label is always present.
labels[0] = 4
auc(predictions, labels)
computed_auc_value = auc.get_metric(reset=True)

false_positive_rates, true_positive_rates, _ = metrics.roc_curve(
labels.numpy(), predictions.numpy(), pos_label=4
labels.cpu().numpy(), predictions.cpu().numpy(), pos_label=4
)
real_auc_value = metrics.auc(false_positive_rates, true_positive_rates)
assert_almost_equal(real_auc_value, computed_auc_value)
assert_allclose(real_auc_value, computed_auc_value)

# Check that it errs on getting more than 2 labels.
with pytest.raises(ConfigurationError) as _:
labels = torch.LongTensor([3, 4, 5, 6, 7, 8, 9, 10])
labels = torch.tensor([3, 4, 5, 6, 7, 8, 9, 10], device=device)
auc(predictions, labels)

def test_auc_with_mask(self):
@multi_device
def test_auc_with_mask(self, device: str):
auc = Auc()

predictions = torch.randn(8).float()
labels = torch.randint(0, 2, (8,)).long()
mask = torch.ByteTensor([1, 1, 1, 1, 0, 0, 0, 0])
predictions = torch.randn(8, device=device)
labels = torch.randint(0, 2, (8,), dtype=torch.long, device=device)
mask = torch.tensor([1, 1, 1, 1, 0, 0, 0, 0], dtype=torch.uint8, device=device)

auc(predictions, labels, mask)
computed_auc_value = auc.get_metric(reset=True)

false_positive_rates, true_positive_rates, _ = metrics.roc_curve(
labels[:4].numpy(), predictions[:4].numpy()
labels[:4].cpu().numpy(), predictions[:4].cpu().numpy()
)
real_auc_value = metrics.auc(false_positive_rates, true_positive_rates)
assert_almost_equal(real_auc_value, computed_auc_value)
assert_allclose(real_auc_value, computed_auc_value)

def test_auc_works_without_calling_metric_at_all(self):
@multi_device
def test_auc_works_without_calling_metric_at_all(self, device: str):
auc = Auc()
auc.get_metric()
39 changes: 21 additions & 18 deletions allennlp/tests/training/metrics/bleu_test.py
@@ -1,10 +1,10 @@
from collections import Counter
import math
from collections import Counter

import numpy as np
import torch
from torch.testing import assert_allclose

from allennlp.common.testing import AllenNlpTestCase
from allennlp.common.testing import AllenNlpTestCase, multi_device
from allennlp.training.metrics import BLEU


@@ -13,15 +13,16 @@ def setUp(self):
super().setUp()
self.metric = BLEU(ngram_weights=(0.5, 0.5), exclude_indices={0})

def test_get_valid_tokens_mask(self):
tensor = torch.tensor([[1, 2, 3, 0], [0, 1, 1, 0]])
result = self.metric._get_valid_tokens_mask(tensor)
result = result.long().numpy()
check = np.array([[1, 1, 1, 0], [0, 1, 1, 0]])
np.testing.assert_array_equal(result, check)
@multi_device
def test_get_valid_tokens_mask(self, device: str):
tensor = torch.tensor([[1, 2, 3, 0], [0, 1, 1, 0]], device=device)
result = self.metric._get_valid_tokens_mask(tensor).long()
check = torch.tensor([[1, 1, 1, 0], [0, 1, 1, 0]], device=device)
assert_allclose(result, check)

def test_ngrams(self):
tensor = torch.tensor([1, 2, 3, 1, 2, 0])
@multi_device
def test_ngrams(self, device: str):
tensor = torch.tensor([1, 2, 3, 1, 2, 0], device=device)

# Unigrams.
counts = Counter(self.metric._ngrams(tensor, 1))
@@ -42,22 +43,23 @@ def test_ngrams(self):
counts = Counter(self.metric._ngrams(tensor, 7))
assert counts == {}

def test_bleu_computed_correctly(self):
@multi_device
def test_bleu_computed_correctly(self, device: str):
self.metric.reset()

# shape: (batch_size, max_sequence_length)
predictions = torch.tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]])
predictions = torch.tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]], device=device)

# shape: (batch_size, max_gold_sequence_length)
gold_targets = torch.tensor([[2, 0, 0], [1, 0, 0], [1, 1, 2]])
gold_targets = torch.tensor([[2, 0, 0], [1, 0, 0], [1, 1, 2]], device=device)

self.metric(predictions, gold_targets)

assert self.metric._prediction_lengths == 6
assert self.metric._reference_lengths == 5

# Number of unigrams in predicted sentences that match gold sentences
# (but not more than maximum occurence of gold unigram within batch).
# (but not more than maximum occurrence of gold unigram within batch).
assert self.metric._precision_matches[1] == (
0
+ 1 # no matches in first sentence.
@@ -68,7 +70,7 @@ def test_bleu_computed_correctly(self):
assert self.metric._precision_totals[1] == (1 + 2 + 3)

# Number of bigrams in predicted sentences that match gold sentences
# (but not more than maximum occurence of gold bigram within batch).
# (but not more than maximum occurrence of gold bigram within batch).
assert self.metric._precision_matches[2] == (0 + 0 + 1)

# Total number of predicted bigrams.
@@ -79,8 +81,9 @@ def test_bleu_computed_correctly(self):

bleu = self.metric.get_metric(reset=True)["BLEU"]
check = math.exp(0.5 * (math.log(3) - math.log(6)) + 0.5 * (math.log(1) - math.log(3)))
np.testing.assert_approx_equal(bleu, check)
assert_allclose(bleu, check)

def test_bleu_computed_with_zero_counts(self):
@multi_device
def test_bleu_computed_with_zero_counts(self, device: str):
self.metric.reset()
assert self.metric.get_metric()["BLEU"] == 0