Add testing against each feat PT version #127

Merged
merged 8 commits into from Mar 25, 2021
Changes from 7 commits
76 changes: 76 additions & 0 deletions .github/workflows/ci_test-conda.yml
@@ -0,0 +1,76 @@
name: PyTorch & Conda

# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows
on:  # Trigger the workflow on push or pull request, but only for the master and release branches
  push:
    branches: [master, "release/*"]
  pull_request:
    branches: [master, "release/*"]

jobs:
  conda:
    runs-on: ubuntu-20.04
    strategy:
      fail-fast: false
      matrix:
        python-version: [3.7]
        pytorch-version: [1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9]

    # Timeout: https://stackoverflow.com/a/59076067/4521646
    timeout-minutes: 35
    steps:
    - uses: actions/checkout@v2

    - name: Cache conda
      uses: actions/cache@v2
      with:
        path: ~/conda_pkgs_dir
        key: conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}-${{ hashFiles('environment.yml') }}
        restore-keys: conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}-

    # Add another cache for pip, as not all packages live in the Conda env
    - name: Cache pip
      uses: actions/cache@v2
      with:
        path: ~/.cache/pip
        key: pip-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}-${{ hashFiles('requirements/base.txt') }}
        restore-keys: pip-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}-

    # https://docs.conda.io/projects/conda/en/4.6.0/_downloads/52a95608c49671267e40c689e0bc00ca/conda-cheatsheet.pdf
    # https://gist.github.com/mwouts/9842452d020c08faf9e84a3bba38a66f
    - name: Setup Miniconda
      uses: conda-incubator/setup-miniconda@v2
      with:
        miniconda-version: "4.7.12"
        python-version: ${{ matrix.python-version }}
        channels: conda-forge,pytorch,pytorch-test,pytorch-nightly
        channel-priority: true
        auto-activate-base: true
        # environment-file: ./environment.yml
        use-only-tar-bz2: true  # IMPORTANT: This needs to be set for caching to work properly!

    - name: Update Environment
      run: |
        conda info
        conda install pytorch=${{ matrix.pytorch-version }} cpuonly
        conda list
        pip --version
        pip install --requirement requirements.txt --upgrade-strategy only-if-needed --quiet
        pip install --requirement requirements/test.txt --upgrade-strategy only-if-needed --quiet
        pip list
        python -c "import torch; assert torch.__version__[:3] == '${{ matrix.pytorch-version }}', torch.__version__"
      shell: bash -l {0}

    - name: Testing
      run: |
        # NOTE: running coverage on the tests does not propagate the failure status on Windows, https://github.com/nedbat/coveragepy/issues/1003
        python -m pytest torchmetrics tests -v --durations=35 --junitxml=junit/test-conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}.xml
      shell: bash -l {0}

    - name: Upload pytest test results
      uses: actions/upload-artifact@master
      with:
        name: test-conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}.xml
        path: junit/test-conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}.xml
      # upload the test results only when there are test failures
      if: failure()
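
The heart of this workflow is the version guard at the end of the "Update Environment" step: after `conda install pytorch=...`, the one-liner asserts that the installed build really matches the matrix entry, so a silently substituted version fails the job immediately. A minimal standalone sketch of that check, where `expected` stands in for the `${{ matrix.pytorch-version }}` value and is only an assumed example:

import torch

expected = "1.8"  # hypothetical matrix entry, e.g. one of 1.3 ... 1.9

# Only the MAJOR.MINOR prefix is compared, so an installed "1.8.1" still satisfies "1.8".
assert torch.__version__[:3] == expected, torch.__version__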
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -38,3 +38,4 @@ prune notebook*
prune temp*
prune test*
prune benchmark*
prune integration*
2 changes: 1 addition & 1 deletion azure-pipelines.yml
@@ -84,5 +84,5 @@ jobs:
condition: succeededOrFailed()

- bash: |
python -m pytest integrations --durations=25
python -m pytest integrations -v --durations=25
displayName: 'Integrations'
3 changes: 3 additions & 0 deletions integrations/__init__.py
@@ -0,0 +1,3 @@
from torchmetrics.utilities.imports import _module_available

_PL_AVAILABLE = _module_available('pytorch_lightning')
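
Exposing _PL_AVAILABLE at the package level lets the integration suite skip cleanly when pytorch_lightning is not installed instead of crashing at import time. One possible way to consume the flag (a sketch under that assumption; the skip marker below is not part of this diff):

import pytest

from integrations import _PL_AVAILABLE

# Skip every test in the module when pytorch_lightning is missing.
pytestmark = pytest.mark.skipif(not _PL_AVAILABLE, reason="test requires pytorch_lightning")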
1 change: 1 addition & 0 deletions integrations/lightning_models.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from pytorch_lightning import LightningModule
from torch.utils.data import Dataset
5 changes: 4 additions & 1 deletion setup.cfg
@@ -21,7 +21,10 @@ exclude_lines =

[flake8]
max-line-length = 120
exclude = .tox,*.egg,build,temp
exclude =
    *.egg
    build
    temp
select = E,W,F
doctests = True
verbose = 2
10 changes: 5 additions & 5 deletions tests/classification/test_stat_scores.py
@@ -21,7 +21,7 @@
from torch import Tensor, tensor

from tests.classification.inputs import _input_binary, _input_binary_prob, _input_multiclass
from tests.classification.inputs import _input_multiclass_prob as _input_mccls_prob
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel as _input_mcls
@@ -104,8 +104,8 @@ def _sk_stat_scores_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, i
["macro", None, None, _input_binary, None],
["micro", None, None, _input_mdmc_prob, None],
["micro", None, None, _input_binary_prob, 0],
["micro", None, None, _input_mccls_prob, NUM_CLASSES],
["micro", None, NUM_CLASSES, _input_mccls_prob, NUM_CLASSES],
["micro", None, None, _input_mcls_prob, NUM_CLASSES],
["micro", None, NUM_CLASSES, _input_mcls_prob, NUM_CLASSES],
],
)
def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index):
@@ -141,8 +141,8 @@ def test_wrong_threshold():
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
(_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None),
(_input_mccls_prob.preds, _input_mccls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mccls_prob.preds, _input_mccls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
(_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None, None),
(
2 changes: 1 addition & 1 deletion tests/retrieval/test_map.py
@@ -86,7 +86,7 @@ def test_input_data(torch_class_metric: Metric) -> None:
"""Check PL metrics inputs are controlled correctly. """

device = 'cuda' if torch.cuda.is_available() else 'cpu'
length = random.randint(0, 20)
length = random.randint(1, 20)

# check error when `query_without_relevant_docs='error'` is raised correctly
indexes = torch.tensor([0] * length, device=device, dtype=torch.int64)
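
Raising the lower bound from random.randint(0, 20) to random.randint(1, 20) presumably rules out zero-length inputs: with length == 0 every tensor built from it is empty, so the error path the test exercises (`query_without_relevant_docs='error'`) may never be triggered and the test can pass or fail spuriously. A standalone illustration of the degenerate case (not part of the test file):

import torch

length = 0  # the value the old randint(0, 20) could produce
indexes = torch.tensor([0] * length, dtype=torch.int64)

# An empty index tensor means there is no query at all, so the
# "query without relevant documents" error cannot be raised.
print(indexes.numel())  # 0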
58 changes: 58 additions & 0 deletions torchmetrics/utilities/imports.py
@@ -1,6 +1,64 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distutils.version import LooseVersion
from importlib import import_module
from importlib.util import find_spec

import torch
from pkg_resources import DistributionNotFound


def _module_available(module_path: str) -> bool:
    """
    Check if a path is available in your environment

    >>> _module_available('os')
    True
    >>> _module_available('bla.bla')
    False
    """
    try:
        return find_spec(module_path) is not None
    except AttributeError:
        # Python 3.6
        return False
    except ModuleNotFoundError:
        # Python 3.7+
        return False


def _compare_version(package: str, op, version) -> bool:
    """
    Compare package version with some requirements

    >>> import operator
    >>> _compare_version("torch", operator.ge, "0.1")
    True
    """
    try:
        pkg = import_module(package)
    except (ModuleNotFoundError, DistributionNotFound):
        return False
    try:
        pkg_version = LooseVersion(pkg.__version__)
    except AttributeError:
        return False
    if not (hasattr(pkg_version, "vstring") and hasattr(pkg_version, "version")):
        # this is mocked by Sphinx, so it shall return True to generate all summaries
        return True
    return op(pkg_version, LooseVersion(version))


_TORCH_LOWER_1_4 = LooseVersion(torch.__version__) < LooseVersion("1.4.0")
_TORCH_LOWER_1_5 = LooseVersion(torch.__version__) < LooseVersion("1.5.0")
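
Together these helpers give the tests a uniform way to gate on optional packages and on the installed PyTorch release, which is what allows one test suite to run unchanged across every entry of the new Conda matrix. A hedged sketch of how they can be combined; the skip markers below are illustrative and not taken from this diff:

import operator

import pytest

from torchmetrics.utilities.imports import _TORCH_LOWER_1_4, _compare_version, _module_available

# Gate a test on an optional dependency being importable.
needs_sklearn = pytest.mark.skipif(not _module_available("sklearn"), reason="requires scikit-learn")

# Gate a test on a minimum PyTorch version using the operator-based comparison.
needs_torch_1_6 = pytest.mark.skipif(
    not _compare_version("torch", operator.ge, "1.6.0"), reason="requires torch>=1.6"
)

# Or use the precomputed module-level flags directly.
old_torch_only = pytest.mark.skipif(not _TORCH_LOWER_1_4, reason="only relevant for torch<1.4")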