Metric docs fix #2209

Merged
merged 34 commits on Jun 17, 2020
34 commits
2d5b8e9 - fix docs (Jun 16, 2020)
5400081 - Update docs/source/metrics.rst (williamFalcon, Jun 16, 2020)
de5083f - Update docs/source/metrics.rst (williamFalcon, Jun 16, 2020)
a902b30 - Update docs/source/metrics.rst (williamFalcon, Jun 16, 2020)
c2145fa - Update docs/source/metrics.rst (williamFalcon, Jun 16, 2020)
c8d3a91 - Update metrics.rst (williamFalcon, Jun 16, 2020)
4556f6e - title (Borda, Jun 16, 2020)
254b023 - fix (Jun 16, 2020)
6b11067 - fix for num_classes (Jun 16, 2020)
4cadc58 - chlog (Borda, Jun 16, 2020)
63690bf - nb classes (Borda, Jun 16, 2020)
0ecc521 - hints (Borda, Jun 16, 2020)
9f1cbd1 - zero division (Borda, Jun 16, 2020)
46f59d9 - add tests (Borda, Jun 16, 2020)
1660e42 - Update metrics.rst (edenlightning, Jun 16, 2020)
2a407a9 - Update classification.py (edenlightning, Jun 16, 2020)
81095a7 - Update classification.py (edenlightning, Jun 16, 2020)
764b52a - prune doctests (Borda, Jun 16, 2020)
d5e1b31 - Merge branch 'metric_docs' of https://github.com/SkafteNicki/pytorch-… (Borda, Jun 16, 2020)
770787e - docs (Borda, Jun 16, 2020)
8cd4a52 - Apply suggestions from code review (Borda, Jun 16, 2020)
4de28c3 - Apply suggestions from code review (Borda, Jun 16, 2020)
70f9162 - flake8 (Borda, Jun 16, 2020)
f6144ee - doctests (Borda, Jun 16, 2020)
7f201ee - formatting (Borda, Jun 16, 2020)
7701bd4 - cleaning (Borda, Jun 16, 2020)
fe6699a - formatting (Borda, Jun 16, 2020)
1810dfe - formatting (Borda, Jun 16, 2020)
59e49a4 - doctests (Borda, Jun 16, 2020)
f3a0fb3 - flake8 (Borda, Jun 16, 2020)
caaacf9 - docs (Borda, Jun 16, 2020)
af872fc - rename (Borda, Jun 17, 2020)
f9ee46b - rename (Borda, Jun 17, 2020)
ae0b3c6 - typo (Borda, Jun 17, 2020)
8 changes: 5 additions & 3 deletions CHANGELOG.md
@@ -21,9 +21,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added metric Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
- Added Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
- Added Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488))
- Added metrics
* Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
* Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
* Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488))
* docs for all Metrics ([#2184](https://github.com/PyTorchLightning/pytorch-lightning/pull/2184), [#2209](https://github.com/PyTorchLightning/pytorch-lightning/pull/2209))
- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723))
- Allow dataloaders without sampler field present ([#1907](https://github.com/PyTorchLightning/pytorch-lightning/pull/1907))
- Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` [(#1908)](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908)
109 changes: 105 additions & 4 deletions docs/source/metrics.rst
@@ -12,7 +12,8 @@ Metrics are used to monitor model performance.
In this package we provide three major pieces of functionality.

1. A Metric class you can use to implement metrics with built-in distributed (ddp) support, which are device agnostic.
2. A collection of popular metrics already implemented for you.
2. A collection of ready-to-use popular metrics. There are two types of metrics: Class metrics and Functional metrics.
3. An interface to call `sklearn's metrics <https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics>`_

Example::

@@ -28,12 +29,17 @@ Out::

tensor(0.7500)

.. warning::
The metrics package is still in development! If we're missing a metric or you find a mistake,
please feel free to create an issue/PR with the proposed metric or the bug fix.

--------------

Implement a metric
------------------
You can implement metrics as either a PyTorch metric or a Numpy metric. Numpy metrics
will slow down training, use PyTorch metrics when possible.
You can implement metrics as either a PyTorch metric or a Numpy metric (it is recommended to use PyTorch metrics when possible,
since Numpy metrics slow down training).

Use :class:`TensorMetric` to implement native PyTorch metrics. This class
handles automated DDP syncing and converts all inputs and outputs to tensors.
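
For illustration, a minimal sketch of a native PyTorch metric (assuming the :class:`TensorMetric` base class takes a metric name in its constructor and that subclasses implement ``forward``; the ``RMSE`` class here is an example, not something added by this PR):

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import TensorMetric

    class RMSE(TensorMetric):
        """Root-mean-squared error as a native PyTorch metric."""

        def __init__(self):
            # the metric name is assumed to be passed to the base class
            super().__init__(name='rmse')

        def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            # inputs arrive as tensors; DDP syncing is handled by the base class
            return torch.sqrt(torch.mean(torch.pow(pred - target, 2.0)))

Calling an instance of such a metric like a function then returns a tensor that can be logged like any other value.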
@@ -76,7 +82,7 @@ Here's an example showing how to implement a NumpyMetric

Class Metrics
-------------
The following are metrics which can be instantiated as part of a module definition (even with just
Class metrics can be instantiated as part of a module definition (even with just
plain PyTorch).

.. testcode::
@@ -316,3 +322,98 @@ to_onehot (F)

.. autofunction:: pytorch_lightning.metrics.functional.to_onehot
:noindex:
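
As a small, hedged usage sketch of the functional interface (assuming ``to_onehot`` infers the number of classes from the input when it is not given explicitly):

.. code-block:: python

    import torch
    from pytorch_lightning.metrics.functional import to_onehot

    labels = torch.tensor([0, 1, 2, 3])
    # each class index becomes a one-hot row vector
    print(to_onehot(labels))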

----------------

Sklearn interface
-----------------

Lightning supports `sklearn's metrics module <https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics>`_
as a backend for calculating metrics. Sklearn's metrics are well tested and robust,
but they require conversion between PyTorch and numpy, which may slow down your computations.

To use the sklearn backend of metrics, simply import it as

.. code-block:: python

import pytorch_lightning.metrics.sklearns as plm
metric = plm.Accuracy(normalize=True)
val = metric(pred, target)

Each converted sklearn metric has the same interface as its original
counterpart (e.g. accuracy takes the additional `normalize` keyword).
Like the native Lightning metrics, these converted sklearn metrics also come
with built-in distributed (ddp) support.
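
For example, a hedged sketch comparing the wrapped metric with the raw sklearn function (assuming the wrapper accepts tensors and handles the tensor/numpy conversion internally, as described above):

.. code-block:: python

    import torch
    from sklearn.metrics import accuracy_score

    import pytorch_lightning.metrics.sklearns as plm

    pred = torch.tensor([0, 1, 2, 2])
    target = torch.tensor([0, 1, 1, 2])

    # the wrapped metric takes tensors and forwards `normalize` to sklearn
    metric = plm.Accuracy(normalize=True)
    print(metric(pred, target))

    # the raw sklearn function works on numpy arrays
    print(accuracy_score(target.numpy(), pred.numpy()))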

SklearnMetric (sk)
^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.sklearns.SklearnMetric
:noindex:

Accuracy (sk)
^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.sklearns.Accuracy
:noindex:

AUC (sk)
^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.sklearns.AUC
:noindex:

AveragePrecision (sk)
^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.sklearns.AveragePrecision
:noindex:


ConfusionMatrix (sk)
^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.sklearns.ConfusionMatrix
:noindex:

F1 (sk)
^^^^^^^

.. autofunction:: pytorch_lightning.metrics.sklearns.F1
:noindex:

FBeta (sk)
^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.sklearns.FBeta
:noindex:

Precision (sk)
^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.sklearns.Precision
:noindex:

Recall (sk)
^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.sklearns.Recall
:noindex:

PrecisionRecallCurve (sk)
^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.sklearns.PrecisionRecallCurve
:noindex:

ROC (sk)
^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.sklearns.ROC
:noindex:

AUROC (sk)
^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.sklearns.AUROC
:noindex:
38 changes: 32 additions & 6 deletions pytorch_lightning/metrics/__init__.py
@@ -1,15 +1,41 @@
from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
from pytorch_lightning.metrics.sklearn import (
SklearnMetric,
from pytorch_lightning.metrics.classification import (
Accuracy,
AveragePrecision,
AUC,
ConfusionMatrix,
F1,
FBeta,
Precision,
Recall,
PrecisionRecallCurve,
ROC,
AUROC)
AUROC,
DiceCoefficient,
MulticlassPrecisionRecall,
MulticlassROC,
Precision,
PrecisionRecall,
)
from pytorch_lightning.metrics.sklearns import (
AUC,
PrecisionRecallCurve,
SklearnMetric,
)

__all__ = [
'AUC',
'AUROC',
'Accuracy',
'AveragePrecision',
'ConfusionMatrix',
'DiceCoefficient',
'F1',
'FBeta',
'MulticlassPrecisionRecall',
'MulticlassROC',
'Precision',
'PrecisionRecall',
'PrecisionRecallCurve',
'ROC',
'Recall',
'SklearnMetric',
]
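
For context, a hedged usage sketch of the re-exported class metrics (assuming ``Accuracy`` accepts a ``num_classes`` argument, which the "fix for num_classes" commit above suggests but this diff does not show):

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import Accuracy

    pred = torch.tensor([0, 1, 2, 3])
    target = torch.tensor([0, 1, 2, 2])

    # after this change, class metrics are importable directly from pytorch_lightning.metrics
    metric = Accuracy(num_classes=4)
    print(metric(pred, target))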