diff --git a/.github/assistant.py b/.github/assistant.py new file mode 100644 index 00000000000..cfc8db85561 --- /dev/null +++ b/.github/assistant.py @@ -0,0 +1,130 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import glob +import json +import logging +import os +import re +import sys +import traceback +from typing import List, Optional, Tuple, Union + +import fire +import requests + +_REQUEST_TIMEOUT = 10 +_PATH_ROOT = os.path.dirname(os.path.dirname(__file__)) +_PKG_WIDE_SUBPACKAGES = ("utilities",) +LUT_PYTHON_TORCH = { + "3.8": "1.4", + "3.9": "1.7.1", +} +REQUIREMENTS_FILES = (os.path.join(_PATH_ROOT, "requirements.txt"),) + tuple( + glob.glob(os.path.join(_PATH_ROOT, "requirements", "*.txt")) +) + + +def request_url(url: str, auth_token: Optional[str] = None) -> Optional[dict]: + """General request with checking if request limit was reached.""" + auth_header = {"Authorization": f"token {auth_token}"} if auth_token else {} + try: + req = requests.get(url, headers=auth_header, timeout=_REQUEST_TIMEOUT) + except requests.exceptions.Timeout: + traceback.print_exc() + return None + if req.status_code == 403: + return None + return json.loads(req.content.decode(req.encoding)) + + +class AssistantCLI: + @staticmethod + def prune_packages(req_file: str, *pkgs: str) -> None: + """Prune packages from requirement file.""" + with open(req_file) as fp: + lines = fp.readlines() + + for pkg in pkgs: + lines = [ln for ln in lines if not ln.startswith(pkg)] + logging.info(lines) + + with open(req_file, "w") as fp: + fp.writelines(lines) + + @staticmethod + def set_min_torch_by_python(fpath: str = "requirements.txt") -> None: + """Set minimal torch version according to Python actual version.""" + py_ver = f"{sys.version_info.major}.{sys.version_info.minor}" + if py_ver not in LUT_PYTHON_TORCH: + return + with open(fpath) as fp: + req = fp.read() + req = re.sub(r"torch>=[\d\.]+", f"torch>={LUT_PYTHON_TORCH[py_ver]}", req) + with open(fpath, "w") as fp: + fp.write(req) + + @staticmethod + def replace_min_requirements(fpath: str) -> None: + """Replace all `>=` by `==` in given file.""" + logging.info(f"processing: {fpath}") + with open(fpath) as fp: + req = fp.read() + req = req.replace(">=", "==") + with open(fpath, "w") as fp: + fp.write(req) + + @staticmethod + def set_oldest_versions(req_files: List[str] = REQUIREMENTS_FILES) -> None: + AssistantCLI.set_min_torch_by_python() + for fpath in req_files: + AssistantCLI.replace_min_requirements(fpath) + + @staticmethod + def changed_domains( + pr: int, + auth_token: Optional[str] = None, + as_list: bool = False, + general_sub_pkgs: Tuple[str] = _PKG_WIDE_SUBPACKAGES, + ) -> Union[str, List[str]]: + """Determine what domains were changed in particular PR.""" + url = f"https://api.github.com/repos/PyTorchLightning/metrics/pulls/{pr}/files" + logging.debug(url) + data = request_url(url, auth_token) + if not data: + logging.error("No data was received.") + return "tests" + files = [d["filename"] for d in data] + # filter only package files and skip inits + _filter = lambda fn: (fn.startswith("torchmetrics") and "__init__.py" not in fn) or fn.startswith("tests") + files = [fn for fn in files if _filter(fn)] + if not files: + return "tests" + # parse domains + files = [fn.replace("torchmetrics/", "").replace("tests/", "").replace("functional/", "") for fn in files] + # filter domain names + tm_modules = [fn.split("/")[0] for fn in files if "/" in fn] + # filter general (used everywhere) sub-packages + tm_modules = [md for md in tm_modules if md not in general_sub_pkgs] + if len(files) > len(tm_modules): + return "tests" + # keep only unique + tm_modules = set(tm_modules) + if as_list: + return list(tm_modules) + return " ".join([f"tests/{md}" for md in tm_modules]) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + fire.Fire(AssistantCLI) diff --git a/.github/prune-packages.py b/.github/prune-packages.py deleted file mode 100644 index fb0e6018d30..00000000000 --- a/.github/prune-packages.py +++ /dev/null @@ -1,18 +0,0 @@ -import sys -from pprint import pprint - - -def main(req_file: str, *pkgs): - with open(req_file) as fp: - lines = fp.readlines() - - for pkg in pkgs: - lines = [ln for ln in lines if not ln.startswith(pkg)] - pprint(lines) - - with open(req_file, "w") as fp: - fp.writelines(lines) - - -if __name__ == "__main__": - main(*sys.argv[1:]) diff --git a/.github/set-oldest-versions.py b/.github/set-oldest-versions.py deleted file mode 100644 index 80bac375bc8..00000000000 --- a/.github/set-oldest-versions.py +++ /dev/null @@ -1,40 +0,0 @@ -import glob -import logging -import os -import re -import sys - -LUT_PYTHON_TORCH = { - "3.8": "1.4", - "3.9": "1.7.1", -} -REQUIREMENTS_FILES = ("requirements.txt",) + tuple(glob.glob(os.path.join("requirements", "*.txt"))) - - -def set_min_torch_by_python(fpath: str = "requirements.txt") -> None: - """set minimal torch version.""" - py_ver = f"{sys.version_info.major}.{sys.version_info.minor}" - if py_ver not in LUT_PYTHON_TORCH: - return - with open(fpath) as fp: - req = fp.read() - req = re.sub(r"torch>=[\d\.]+", f"torch>={LUT_PYTHON_TORCH[py_ver]}", req) - with open(fpath, "w") as fp: - fp.write(req) - - -def replace_min_requirements(fpath: str) -> None: - """replace all `>=` by `==` in given file.""" - logging.info(f"processing: {fpath}") - with open(fpath) as fp: - req = fp.read() - req = req.replace(">=", "==") - with open(fpath, "w") as fp: - fp.write(req) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - set_min_torch_by_python() - for fpath in REQUIREMENTS_FILES: - replace_min_requirements(fpath) diff --git a/.github/workflows/ci_integrate.yml b/.github/workflows/ci_integrate.yml index 33052b9163f..d3de845070a 100644 --- a/.github/workflows/ci_integrate.yml +++ b/.github/workflows/ci_integrate.yml @@ -13,10 +13,10 @@ concurrency: cancel-in-progress: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }} jobs: + pytest: runs-on: ${{ matrix.os }} - if: github.event.pull_request.draft == false strategy: fail-fast: false matrix: @@ -42,10 +42,12 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: PIP install assitant's deps + run: pip install fire requests + - name: Set min. dependencies if: matrix.requires == 'oldest' - run: | - python .github/set-oldest-versions.py + run: python .github/assistant.py set-oldest-versions - run: echo "::set-output name=period::$(python -c 'import time ; days = time.time() / 60 / 60 / 24 ; print(int(days / 7))' 2>&1)" if: matrix.requires == 'latest' diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index 8efbb834acf..d0a50d633f3 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -12,16 +12,40 @@ concurrency: cancel-in-progress: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }} jobs: + # ToDo: consider unifying in a single workflow and distributing outputs to all others depending ones + check-diff: + runs-on: ubuntu-20.04 + if: github.event.pull_request.draft == false + timeout-minutes: 5 + # Map a step output to a job output + outputs: + focus: ${{ steps.diff-domains.outputs.focus }} + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - name: Get PR diff + id: diff-domains + env: + PR_NUMBER: "${{ github.event.pull_request.number }}" + run: | + pip install fire requests + # python actions/assistant.py list_runtimes $PR_NUMBER + echo "::set-output name=focus::$(python .github/assistant.py changed_domains $PR_NUMBER 2>&1)" + - run: echo "${{ steps.diff-domains.outputs.focus }}" + conda: runs-on: ubuntu-20.04 + needs: check-diff strategy: fail-fast: false matrix: python-version: ["3.8"] pytorch-version: ["1.4", "1.5", "1.6", "1.7", "1.8", "1.9", "1.10", "1.11"] include: - - python-version: 3.7 - pytorch-version: '1.3' + - {python-version: 3.7, pytorch-version: '1.3'} env: PYTEST_ARTEFACT: test-conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}.xml TRANSFORMERS_CACHE: .cache/huggingface/ @@ -83,7 +107,9 @@ jobs: conda list pip --version python ./requirements/adjust-versions.py requirements.txt - python ./.github/prune-packages.py requirements/image.txt torchvision + pip install --requirement requirements/test.txt --quiet + python ./.github/assistant.py prune-packages requirements/image.txt torchvision + python ./.github/assistant.py prune-packages requirements/detection.txt torchvision pip install -q "numpy==1.20.0" # try to fix cocotools for PT 1.4 & 1.9 pip install --requirement requirements.txt --quiet pip install --requirement requirements/devel.txt --quiet @@ -92,9 +118,11 @@ jobs: shell: bash -l {0} - name: Testing + env: + TEST_DIRS: "${{ needs.check-diff.outputs.focus }}" run: | # NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003 - python -m pytest torchmetrics tests -v --durations=50 --junitxml=$PYTEST_ARTEFACT + python -m pytest torchmetrics $TEST_DIRS --durations=50 --junitxml=$PYTEST_ARTEFACT shell: bash -l {0} - name: Upload pytest test results diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 22618f2ea2e..84848284515 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -13,10 +13,33 @@ concurrency: cancel-in-progress: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }} jobs: - pytest: + # ToDo: consider unifying in a single workflow and distributing outputs to all others depending ones + check-diff: + runs-on: ubuntu-20.04 + if: github.event.pull_request.draft == false + timeout-minutes: 5 + # Map a step output to a job output + outputs: + focus: ${{ steps.diff-domains.outputs.focus }} + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Get PR diff + id: diff-domains + env: + PR_NUMBER: "${{ github.event.pull_request.number }}" + run: | + pip install fire requests + # python actions/assistant.py list_runtimes $PR_NUMBER + echo "::set-output name=focus::$(python .github/assistant.py changed_domains $PR_NUMBER 2>&1)" + - run: echo "${{ steps.diff-domains.outputs.focus }}" + + pytest: runs-on: ${{ matrix.os }} - if: github.event.pull_request.draft == false + needs: check-diff strategy: fail-fast: false matrix: @@ -57,10 +80,12 @@ jobs: run: | choco install ffmpeg + - name: PIP install assitant's deps + run: pip install fire requests + - name: Set min. dependencies if: matrix.requires == 'oldest' - run: | - python .github/set-oldest-versions.py + run: python .github/assistant.py set-oldest-versions - run: echo "::set-output name=period::$(python -c 'import time ; days = time.time() / 60 / 60 / 24 ; print(int(days / 7))' 2>&1)" if: matrix.requires == 'latest' @@ -97,9 +122,11 @@ jobs: key: cache-transformers - name: Tests + env: + TEST_DIRS: "${{ needs.check-diff.outputs.focus }}" run: | phmdoctest README.md --outfile tests/test_readme.py - python -m pytest torchmetrics tests -v --cov=torchmetrics --junitxml="junit/$PYTEST_ARTEFACT" --durations=50 + python -m pytest torchmetrics $TEST_DIRS --cov=torchmetrics --junitxml="junit/$PYTEST_ARTEFACT" --durations=50 - name: Upload pytest test results uses: actions/upload-artifact@v2 diff --git a/requirements/adjust-versions.py b/requirements/adjust-versions.py index d436a285eee..25f621cb988 100644 --- a/requirements/adjust-versions.py +++ b/requirements/adjust-versions.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging import os import re diff --git a/requirements/test.txt b/requirements/test.txt index 447d85c9072..d3ed7d945d6 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -10,5 +10,8 @@ mypy>=0.790 phmdoctest>=1.1.1 pre-commit>=1.0 +requests +fire + cloudpickle>=1.3 scikit-learn>=0.24