Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(doctest): SSIM metric doctest #2241

Merged
merged 11 commits into from
Oct 15, 2021
25 changes: 25 additions & 0 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,28 @@ jobs:
- name: make linkcheck
working-directory: ./docs/
run: make linkcheck --jobs 2 SPHINXOPTS="--color -W"

doctest:
if: github.event_name == 'pull_request' || github.event_name == 'push'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: 3.7

- run: sudo npm install katex@0.13.0 -g
- uses: actions/cache@v2
with:
path: ~/.cache/pip
key: pip-${{ hashFiles('requirements-dev.txt') }}-${{ hashFiles('docs/requirements.txt') }}

- name: Install docs deps
run: bash .github/workflows/install_docs_deps.sh

- name: make doctest
working-directory: ./docs/
run: |
make html SPHINXOPTS="--color -W"
make doctest
make coverage
13 changes: 13 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,16 @@ def run(self):
("py:class", "torch.optim.lr_scheduler._LRScheduler"),
("py:class", "torch.utils.data.dataloader.DataLoader"),
]

# doctest config
doctest_global_setup = """
import torch
from torch import nn, optim

from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
from ignite.utils import *

manual_seed(666)
"""
14 changes: 11 additions & 3 deletions ignite/metrics/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,22 @@ class SSIM(Metric):
``y_pred`` and ``y`` can be un-normalized or normalized image tensors. Depending on that, the user might need
to adjust ``data_range``. ``y_pred`` and ``y`` should have the same shape.

.. code-block:: python
.. testcode::

def process_function(engine, batch):
# ...
y_pred, y = batch
return y_pred, y
engine = Engine(process_function)
metric = SSIM(data_range=1.0)
metric.attach(engine, "ssim")
metric.attach(engine, 'ssim')
preds = torch.rand([4, 3, 16, 16])
target = preds * 0.75
state = engine.run([[preds, target]])
print(state.metrics['ssim'])

.. testoutput::

0.9218971...

.. versionadded:: 0.4.2
"""
Expand Down