Skip to content

Commit

Permalink
Fix doctests
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Aug 26, 2022
1 parent 893143d commit 6875845
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 18 deletions.
30 changes: 18 additions & 12 deletions src/torchmetrics/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ class MulticlassStatScores(_AbstractStatScores):
>>> from torchmetrics import MulticlassStatScores
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([2, 1, 0, 1])
>>> metric = MulticlassStatScores(num_classes=3)
>>> metric = MulticlassStatScores(num_classes=3, average='micro')
>>> metric(preds, target)
tensor([3, 1, 7, 1, 4])
>>> metric = MulticlassStatScores(num_classes=3, average=None)
Expand All @@ -251,7 +251,7 @@ class MulticlassStatScores(_AbstractStatScores):
... [0.71, 0.09, 0.20],
... [0.05, 0.82, 0.13],
... ])
>>> metric = MulticlassStatScores(num_classes=3)
>>> metric = MulticlassStatScores(num_classes=3, average='micro')
>>> metric(preds, target)
tensor([3, 1, 7, 1, 4])
>>> metric = MulticlassStatScores(num_classes=3, average=None)
Expand All @@ -264,7 +264,7 @@ class MulticlassStatScores(_AbstractStatScores):
>>> from torchmetrics import MulticlassStatScores
>>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassStatScores(num_classes=3, multidim_average="samplewise")
>>> metric = MulticlassStatScores(num_classes=3, multidim_average='samplewise', average='micro')
>>> metric(preds, target)
tensor([[3, 3, 9, 3, 6],
[2, 4, 8, 4, 6]])
Expand Down Expand Up @@ -379,40 +379,46 @@ class MultilabelStatScores(_AbstractStatScores):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example (preds is int tensor):
>>> from torchmetrics.functional import multilabel_stat_scores
>>> from torchmetrics import MultilabelStatScores
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_stat_scores(preds, target, num_labels=3)
>>> metric = MultilabelStatScores(num_labels=3, average='micro')
>>> metric(preds, target)
tensor([2, 1, 2, 1, 3])
>>> multilabel_stat_scores(preds, target, num_labels=3, average=None)
>>> metric = MultilabelStatScores(num_labels=3, average=None)
>>> metric(preds, target)
tensor([[1, 0, 1, 0, 1],
[0, 0, 1, 1, 1],
[1, 1, 0, 0, 1]])
Example (preds is float tensor):
>>> from torchmetrics.functional import multilabel_stat_scores
>>> from torchmetrics import MultilabelStatScores
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_stat_scores(preds, target, num_labels=3)
>>> metric = MultilabelStatScores(num_labels=3, average='micro')
>>> metric(preds, target)
tensor([2, 1, 2, 1, 3])
>>> multilabel_stat_scores(preds, target, num_labels=3, average=None)
>>> metric = MultilabelStatScores(num_labels=3, average=None)
>>> metric(preds, target)
tensor([[1, 0, 1, 0, 1],
[0, 0, 1, 1, 1],
[1, 1, 0, 0, 1]])
Example (multidim tensors):
>>> from torchmetrics.functional import multilabel_stat_scores
>>> from torchmetrics import MultilabelStatScores
>>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = torch.tensor(
... [
... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]],
... ]
... )
>>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise')
>>> metric = MultilabelStatScores(num_labels=3, multidim_average='samplewise', average='micro')
>>> metric(preds, target)
tensor([[2, 3, 0, 1, 3],
[0, 2, 1, 3, 3]])
>>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise', average=None)
>>> metric = MultilabelStatScores(num_labels=3, multidim_average='samplewise', average=None)
>>> metric(preds, target)
tensor([[[1, 1, 0, 0, 1],
[1, 1, 0, 0, 1],
[0, 1, 0, 1, 1]],
Expand Down
12 changes: 6 additions & 6 deletions src/torchmetrics/functional/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def multiclass_stat_scores(
>>> from torchmetrics.functional import multiclass_stat_scores
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([2, 1, 0, 1])
>>> multiclass_stat_scores(preds, target, num_classes=3)
>>> multiclass_stat_scores(preds, target, num_classes=3, average='micro')
tensor([3, 1, 7, 1, 4])
>>> multiclass_stat_scores(preds, target, num_classes=3, average=None)
tensor([[1, 0, 2, 1, 2],
Expand All @@ -525,7 +525,7 @@ def multiclass_stat_scores(
... [0.71, 0.09, 0.20],
... [0.05, 0.82, 0.13],
... ])
>>> multiclass_stat_scores(preds, target, num_classes=3)
>>> multiclass_stat_scores(preds, target, num_classes=3, average='micro')
tensor([3, 1, 7, 1, 4])
>>> multiclass_stat_scores(preds, target, num_classes=3, average=None)
tensor([[1, 0, 2, 1, 2],
Expand All @@ -536,7 +536,7 @@ def multiclass_stat_scores(
>>> from torchmetrics.functional import multiclass_stat_scores
>>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise')
>>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise', average='micro')
tensor([[3, 3, 9, 3, 6],
[2, 4, 8, 4, 6]])
>>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise', average=None)
Expand Down Expand Up @@ -764,7 +764,7 @@ def multilabel_stat_scores(
>>> from torchmetrics.functional import multilabel_stat_scores
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_stat_scores(preds, target, num_labels=3)
>>> multilabel_stat_scores(preds, target, num_labels=3, average='micro')
tensor([2, 1, 2, 1, 3])
>>> multilabel_stat_scores(preds, target, num_labels=3, average=None)
tensor([[1, 0, 1, 0, 1],
Expand All @@ -775,7 +775,7 @@ def multilabel_stat_scores(
>>> from torchmetrics.functional import multilabel_stat_scores
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_stat_scores(preds, target, num_labels=3)
>>> multilabel_stat_scores(preds, target, num_labels=3, average='micro')
tensor([2, 1, 2, 1, 3])
>>> multilabel_stat_scores(preds, target, num_labels=3, average=None)
tensor([[1, 0, 1, 0, 1],
Expand All @@ -791,7 +791,7 @@ def multilabel_stat_scores(
... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]],
... ]
... )
>>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise')
>>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise', average='micro')
tensor([[2, 3, 0, 1, 3],
[0, 2, 1, 3, 3]])
>>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise', average=None)
Expand Down

0 comments on commit 6875845

Please sign in to comment.