Add AverageMeter implementation #138

Merged: 22 commits, Apr 14, 2021
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -20,14 +20,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added other metrics:
* Added `CohenKappa` ([#69](https://github.com/PyTorchLightning/metrics/pull/69))
* Added `MatthewsCorrcoef` ([#98](https://github.com/PyTorchLightning/metrics/pull/98))
* Added `PearsonCorrcoef` ([#157](https://github.com/PyTorchLightning/metrics/pull/157))
* Added `SpearmanCorrcoef` ([#158](https://github.com/PyTorchLightning/metrics/pull/158))
* Added `Hinge` ([#120](https://github.com/PyTorchLightning/metrics/pull/120))
- Added Binned metrics ([#128](https://github.com/PyTorchLightning/metrics/pull/128))
- Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110))
- Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114))
- Added testing for `half` precision ([#77](https://github.com/PyTorchLightning/metrics/pull/77),
[#135](https://github.com/PyTorchLightning/metrics/pull/135)
)
- Added `AverageMeter` for ad-hoc averages of values ([#138](https://github.com/PyTorchLightning/metrics/pull/138))
- Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
- Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142))
- Added property `is_differentiable` to metrics and test for differentiability ([#154](https://github.com/PyTorchLightning/metrics/pull/154))


### Changed

24 changes: 13 additions & 11 deletions docs/source/conf.py
@@ -30,26 +30,28 @@
SPHINX_MOCK_REQUIREMENTS = int(os.environ.get("SPHINX_MOCK_REQUIREMENTS", True))

try:
    from torchmetrics import info
    from torchmetrics import __about__ as about
except ImportError:
    # alternative https://stackoverflow.com/a/67692/4521646
    spec = spec_from_file_location("torchmetrics/info.py", os.path.join(_PATH_ROOT, "torchmetrics", "info.py"))
    info = module_from_spec(spec)
    spec.loader.exec_module(info)
    spec = spec_from_file_location(
        "torchmetrics/__about__.py", os.path.join(_PATH_ROOT, "torchmetrics", "__about__.py")
    )
    about = module_from_spec(spec)
    spec.loader.exec_module(about)

html_favicon = '_static/images/icon.svg'

# -- Project information -----------------------------------------------------

# this name shall match the project name in Github as it is used for linking to code
project = "PyTorch-Metrics"
copyright = info.__copyright__
author = info.__author__
copyright = about.__copyright__
author = about.__author__

# The short X.Y version
version = info.__version__
version = about.__version__
# The full version, including alpha/beta/rc tags
release = info.__version__
release = about.__version__

# Options for the linkcode extension
# ----------------------------------
@@ -171,7 +173,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None:

html_theme_options = {
    'pytorch_project': 'https://pytorchlightning.ai',
    'canonical_url': info.__docs_url__,
    'canonical_url': about.__docs_url__,
    "collapse_navigation": False,
    "display_version": True,
    "logo_only": False,
@@ -237,7 +239,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
project + " Documentation",
author,
project,
info.__docs__,
about.__docs__,
"Miscellaneous",
),
]
@@ -284,7 +286,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None:

# packages for which sphinx-apidoc should generate the docs (.rst files)
PACKAGES = [
    info.__name__,
    about.__name__,
]

# def run_apidoc(_):
12 changes: 0 additions & 12 deletions docs/source/pages/implement.rst
@@ -39,18 +39,6 @@ Example implementation:
    def compute(self):
        return self.correct.float() / self.total

Metrics support backpropagation, if all computations involved in the metric calculation
are differentiable. However, note that the cached state is detached from the computational
graph and cannot be backpropagated. Not doing this would mean storing the computational
graph for each update call, which can lead to out-of-memory errors.
In practise this means that:

.. code-block:: python

metric = MyMetric()
val = metric(pred, target) # this value can be backpropagated
val = metric.compute() # this value cannot be backpropagated


Internal implementation details
-------------------------------
28 changes: 28 additions & 0 deletions docs/source/pages/overview.rst
@@ -281,3 +281,31 @@ They simply compute the metric value based on the given inputs.
Also, the integration within other parts of PyTorch Lightning will never be as tight as with the Module-based interface.
If you are just looking to compute the values, the functional metrics are the way to go.
However, if you are looking for the best integration and user experience, please consider also using the Module interface.
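
As a quick illustration of the two interfaces (a sketch using ``Accuracy``; the same pattern
applies to the other metrics):

.. code-block:: python

    import torch
    from torchmetrics import Accuracy
    from torchmetrics.functional import accuracy

    preds = torch.tensor([0, 1, 1, 0])
    target = torch.tensor([0, 1, 0, 0])

    accuracy(preds, target)   # functional: stateless, returns the value for this input only

    metric = Accuracy()       # module: keeps state and accumulates across batches
    metric.update(preds, target)
    metric.compute()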


*****************************
Metrics and differentiability
*****************************

Metrics support backpropagation if all computations involved in the metric calculation
are differentiable. All modular metrics have a property that determines whether a metric
is differentiable or not.

.. code-block:: python

    @property
    def is_differentiable(self) -> bool:
        return True/False
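
For instance, the property can be queried on any instantiated metric (the concrete values
below are illustrative assumptions, not an exhaustive list):

.. code-block:: python

    from torchmetrics import Accuracy, MeanSquaredError

    Accuracy().is_differentiable          # False - relies on non-differentiable ops such as argmax
    MeanSquaredError().is_differentiable  # True - a plain mean of squared differences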

However, note that the cached state is detached from the computational
graph and cannot be backpropagated. Not doing this would mean storing the computational
graph for each update call, which can lead to out-of-memory errors.
In practice this means that:

.. code-block:: python

    metric = MyMetric()
    val = metric(pred, target)  # this value can be backpropagated
    val = metric.compute()  # this value cannot be backpropagated

A functional metric is differentiable if its corresponding modular metric is differentiable.
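
To make this concrete, here is a small end-to-end sketch (``MeanSquaredError`` is used purely
as an example of a metric whose computation is differentiable):

.. code-block:: python

    import torch
    from torchmetrics import MeanSquaredError

    preds = torch.randn(10, requires_grad=True)
    target = torch.randn(10)

    metric = MeanSquaredError()
    val = metric(preds, target)     # forward value is attached to the graph
    val.backward()                  # gradients flow back into preds

    metric.compute().requires_grad  # False - the accumulated state was detached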
25 changes: 20 additions & 5 deletions docs/source/references/functional.rst
@@ -196,17 +196,17 @@ mean_squared_log_error [func]
:noindex:


psnr [func]
~~~~~~~~~~~
pearson_corrcoef [func]
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.psnr
.. autofunction:: torchmetrics.functional.pearson_corrcoef
:noindex:


ssim [func]
psnr [func]
~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.ssim
.. autofunction:: torchmetrics.functional.psnr
:noindex:


@@ -216,6 +216,21 @@ r2score [func]
.. autofunction:: torchmetrics.functional.r2score
:noindex:


spearman_corrcoef [func]
~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.spearman_corrcoef
:noindex:


ssim [func]
~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.ssim
:noindex:


***
NLP
***
50 changes: 45 additions & 5 deletions docs/source/references/modules.rst
@@ -12,6 +12,12 @@ metrics.
.. autoclass:: torchmetrics.Metric
:noindex:

We also have an ``AverageMeter`` class that is helpful for defining ad-hoc metrics when
creating your own metric class would be too burdensome.

.. autoclass:: torchmetrics.AverageMeter
:noindex:
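
A minimal usage sketch (assuming ``AverageMeter`` follows the ``update``/``compute`` interface
shared by the other modular metrics; exact argument handling may differ):

.. code-block:: python

    import torch
    from torchmetrics import AverageMeter

    meter = AverageMeter()
    meter.update(torch.tensor([1.0, 2.0, 3.0]))  # accumulate a batch of raw values
    meter.update(torch.tensor(5.0))              # scalar updates work as well
    meter.compute()                              # average of everything seen so far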

**********************
Classification Metrics
**********************
@@ -138,6 +144,24 @@ AUROC
.. autoclass:: torchmetrics.AUROC
:noindex:

BinnedAveragePrecision
~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.BinnedAveragePrecision
:noindex:

BinnedPrecisionRecallCurve
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.BinnedPrecisionRecallCurve
:noindex:

BinnedRecallAtFixedPrecision
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.BinnedRecallAtFixedPrecision
:noindex:

CohenKappa
~~~~~~~~~~

@@ -204,6 +228,7 @@ Recall
.. autoclass:: torchmetrics.Recall
:noindex:


ROC
~~~

@@ -250,17 +275,17 @@ MeanSquaredLogError
:noindex:


PSNR
~~~~
PearsonCorrcoef
~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.PSNR
.. autoclass:: torchmetrics.PearsonCorrcoef
:noindex:


SSIM
PSNR
~~~~

.. autoclass:: torchmetrics.SSIM
.. autoclass:: torchmetrics.PSNR
:noindex:


@@ -271,6 +296,21 @@ R2Score
:noindex:


SpearmanCorrcoef
~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.SpearmanCorrcoef
:noindex:


SSIM
~~~~

.. autoclass:: torchmetrics.SSIM
:noindex:



*********
Retrieval
*********
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,2 +1,3 @@
numpy
torch>=1.3.1
torch>=1.3.1
packaging
27 changes: 14 additions & 13 deletions setup.py
@@ -7,17 +7,18 @@

_PATH_ROOT = os.path.realpath(os.path.dirname(__file__))
try:
    from torchmetrics import info, setup_tools
    from torchmetrics import __about__ as about
    from torchmetrics import setup_tools
except ImportError:
    # alternative https://stackoverflow.com/a/67692/4521646
    sys.path.append("torchmetrics")
    import info
    import __about__ as about
    import setup_tools

long_description = setup_tools._load_readme_description(
    _PATH_ROOT,
    homepage=info.__homepage__,
    version=f'v{info.__version__}',
    homepage=about.__homepage__,
    version=f'v{about.__version__}',
)

# https://packaging.python.org/discussions/install-requires-vs-requirements /
@@ -27,13 +28,13 @@
# engineer specific practices
setup(
    name='torchmetrics',
    version=info.__version__,
    description=info.__docs__,
    author=info.__author__,
    author_email=info.__author_email__,
    url=info.__homepage__,
    download_url=os.path.join(info.__homepage__, 'archive', 'master.zip'),
    license=info.__license__,
    version=about.__version__,
    description=about.__docs__,
    author=about.__author__,
    author_email=about.__author_email__,
    url=about.__homepage__,
    download_url=os.path.join(about.__homepage__, 'archive', 'master.zip'),
    license=about.__license__,
    packages=find_packages(exclude=['tests', 'docs']),
    long_description=long_description,
    long_description_content_type='text/markdown',
@@ -44,9 +45,9 @@
    setup_requires=[],
    install_requires=setup_tools._load_requirements(_PATH_ROOT),
    project_urls={
        "Bug Tracker": os.path.join(info.__homepage__, 'issues'),
        "Bug Tracker": os.path.join(about.__homepage__, 'issues'),
        "Documentation": "https://torchmetrics.rtfd.io/en/latest/",
        "Source Code": info.__homepage__,
        "Source Code": about.__homepage__,
    },
    classifiers=[
        'Environment :: Console',