Feature pre commit yaml (#145)
* added pre-commit yaml from PL

* added yapf and flake8 to pre-commit config

* manifest

* names

* formatting

Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
jspaezp and Borda authored Mar 30, 2021
1 parent d1af80f commit dfc4c38
Showing 6 changed files with 64 additions and 17 deletions.
50 changes: 50 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,50 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

default_language_version:
  python: python3.8

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.3.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer

  - repo: local
    hooks:
      - id: isort
        name: imports
        entry: isort
        args: [--settings-path, ./pyproject.toml]
        language: system
        types: [python]

      - id: yapf
        name: formatting
        entry: yapf
        args: [ --parallel ]
        language: system
        types: [python]

      - id: flake8
        name: PEP8
        entry: flake8
        language: system
        types: [python]

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v0.790
    hooks:
      - id: mypy
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -29,6 +29,7 @@ recursive-exclude requirements *.txt

# Exclude build configs
exclude *.yml
exclude *.yaml
exclude Makefile

prune .git
16 changes: 7 additions & 9 deletions tests/functional/test_retrieval.py
@@ -73,11 +73,11 @@ def test_metrics_output_values_with_k(sklearn_metric, torch_metric, size, k):
    assert torch.allclose(sk.float(), tm.float())


-@pytest.mark.parametrize(['torch_metric'], [
+@pytest.mark.parametrize(['torch_metric'], (
    [retrieval_average_precision],
    [retrieval_reciprocal_rank],
-    [retrieval_precision]
-])
+    [retrieval_precision],
+))
def test_input_dtypes(torch_metric) -> None:
    """ Check wrong input dtypes are managed correctly. """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -104,11 +104,11 @@ def test_input_dtypes(torch_metric) -> None:
    assert torch.allclose(torch_metric(preds=preds, target=target), torch.tensor(0.0))


-@pytest.mark.parametrize(['torch_metric'], [
+@pytest.mark.parametrize(['torch_metric'], (
    [retrieval_average_precision],
    [retrieval_reciprocal_rank],
-    [retrieval_precision]
-])
+    [retrieval_precision],
+))
def test_input_shapes(torch_metric) -> None:
    """ Check wrong input shapes are managed correctly. """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -129,9 +129,7 @@ def test_input_shapes(torch_metric) -> None:


# test metrics using top K parameter
-@pytest.mark.parametrize(['torch_metric'], [
-    [retrieval_precision]
-])
+@pytest.mark.parametrize(['torch_metric'], ([retrieval_precision], ))
@pytest.mark.parametrize('k', [-1, 1.0])
def test_input_params(torch_metric, k) -> None:
    """ Check wrong input shapes are managed correctly. """
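The three hunks above only change the container passed as `argvalues` to `pytest.mark.parametrize` from a list of single-element lists to a tuple (with trailing commas); pytest accepts any sequence here, so the parametrized cases themselves are unchanged. A minimal, self-contained sketch of the tuple form (a hypothetical test, not part of this commit):

import pytest


def double(x: int) -> int:
    return 2 * x


# Each one-element row supplies the single 'value' argument for one test case.
@pytest.mark.parametrize(['value'], (
    [1],
    [2],
    [3],
))
def test_double_is_even(value) -> None:
    assert double(value) % 2 == 0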
2 changes: 1 addition & 1 deletion tests/retrieval/helpers.py
@@ -40,7 +40,7 @@ def _test_retrieval_against_sklearn(
    size: int,
    n_documents: int,
    query_without_relevant_docs_options: str,
-    **kwargs
+    **kwargs,
) -> None:
    """ Compare PL metrics to standard version. """
    metric = torch_metric(query_without_relevant_docs=query_without_relevant_docs_options, **kwargs)
7 changes: 1 addition & 6 deletions tests/retrieval/test_precision.py
@@ -31,12 +31,7 @@ def _precision_at_k(target: np.array, preds: np.array, k: int = None):
def test_results(size, n_documents, query_without_relevant_docs_options, k):
    """ Test metrics are computed correctly. """
    _test_retrieval_against_sklearn(
-        _precision_at_k,
-        RetrievalPrecision,
-        size,
-        n_documents,
-        query_without_relevant_docs_options,
-        k=k
+        _precision_at_k, RetrievalPrecision, size, n_documents, query_without_relevant_docs_options, k=k
    )


5 changes: 4 additions & 1 deletion torchmetrics/wrappers/bootstrapping.py
@@ -22,7 +22,10 @@
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_7


-def _bootstrap_sampler(size: int, sampling_strategy: str = 'poisson') -> Tensor:
+def _bootstrap_sampler(
+    size: int,
+    sampling_strategy: str = 'poisson',
+) -> Tensor:
    """ Resample a tensor along its first dimension with replacement
    Args:
        size: number of samples
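The hunk above only re-wraps the signature of `_bootstrap_sampler`; its body is not part of this diff. As a rough illustration of what a Poisson-style bootstrap sampler can look like (an assumption for illustration, not the torchmetrics implementation), one common approach draws a Poisson(1) count per sample and repeats each index that many times, which approximates resampling with replacement:

import torch
from torch import Tensor


def poisson_bootstrap_indices(size: int) -> Tensor:
    """Return indices for one Poisson-style bootstrap draw (illustrative sketch)."""
    counts = torch.poisson(torch.ones(size)).long()  # how many times to keep each index
    return torch.arange(size).repeat_interleave(counts)


# Usage: index a batch of predictions along dim 0 to form one bootstrap replicate.
preds = torch.randn(10)
replicate = preds[poisson_bootstrap_indices(preds.shape[0])]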
