-
Notifications
You must be signed in to change notification settings - Fork 402
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* implementation and test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * changelog * add example * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review * update to torch fidelity 0.3.0 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 35min Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Jirka <jirka.borovec@seznam.cz>
- Loading branch information
1 parent
1841cad
commit f54ccca
Showing
9 changed files
with
335 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import pickle | ||
|
||
import pytest | ||
import torch | ||
from torch.utils.data import Dataset | ||
|
||
from torchmetrics.image.inception import IS | ||
from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE | ||
|
||
torch.manual_seed(42) | ||
|
||
|
||
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") | ||
def test_no_train(): | ||
""" Assert that metric never leaves evaluation mode """ | ||
|
||
class MyModel(torch.nn.Module): | ||
|
||
def __init__(self): | ||
super().__init__() | ||
self.metric = IS() | ||
|
||
def forward(self, x): | ||
return x | ||
|
||
model = MyModel() | ||
model.train() | ||
assert model.training | ||
assert not model.metric.inception.training, 'IS metric was changed to training mode which should not happen' | ||
|
||
|
||
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity') | ||
def test_is_pickle(): | ||
""" Assert that we can initialize the metric and pickle it""" | ||
metric = IS() | ||
assert metric | ||
|
||
# verify metrics work after being loaded from pickled state | ||
pickled_metric = pickle.dumps(metric) | ||
metric = pickle.loads(pickled_metric) | ||
|
||
|
||
def test_is_raises_errors_and_warnings(): | ||
""" Test that expected warnings and errors are raised """ | ||
with pytest.warns( | ||
UserWarning, | ||
match='Metric `IS` will save all extracted features in buffer.' | ||
' For large datasets this may lead to large memory footprint.' | ||
): | ||
IS() | ||
|
||
if _TORCH_FIDELITY_AVAILABLE: | ||
with pytest.raises(ValueError, match='Integer input to argument `feature` must be one of .*'): | ||
_ = IS(feature=2) | ||
else: | ||
with pytest.raises( | ||
ValueError, | ||
match='IS metric requires that Torch-fidelity is installed.' | ||
'Either install as `pip install torchmetrics[image-quality]`' | ||
' or `pip install torch-fidelity`' | ||
): | ||
IS() | ||
|
||
with pytest.raises(TypeError, match='Got unknown input to argument `feature`'): | ||
IS(feature=[1, 2]) | ||
|
||
|
||
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity') | ||
def test_is_update_compute(): | ||
metric = IS() | ||
|
||
for _ in range(2): | ||
img = torch.randint(0, 255, (10, 3, 299, 299), dtype=torch.uint8) | ||
metric.update(img) | ||
|
||
mean, std = metric.compute() | ||
assert mean != 0.0 | ||
assert std != 0.0 | ||
|
||
|
||
class _ImgDataset(Dataset): | ||
|
||
def __init__(self, imgs): | ||
self.imgs = imgs | ||
|
||
def __getitem__(self, idx): | ||
return self.imgs[idx] | ||
|
||
def __len__(self): | ||
return self.imgs.shape[0] | ||
|
||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason='test is too slow without gpu') | ||
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity') | ||
def test_compare_is(tmpdir): | ||
""" check that the hole pipeline give the same result as torch-fidelity """ | ||
from torch_fidelity import calculate_metrics | ||
|
||
metric = IS(splits=1).cuda() | ||
|
||
# Generate some synthetic data | ||
img1 = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8) | ||
|
||
batch_size = 10 | ||
for i in range(img1.shape[0] // batch_size): | ||
metric.update(img1[batch_size * i:batch_size * (i + 1)].cuda()) | ||
|
||
torch_fid = calculate_metrics(input1=_ImgDataset(img1), isc=True, isc_splits=1, batch_size=batch_size) | ||
|
||
tm_mean, tm_std = metric.compute() | ||
|
||
assert torch.allclose(tm_mean.cpu(), torch.tensor([torch_fid['inception_score_mean']]), atol=1e-3) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.