Skip to content

Commit cd9daa1

Browse files
authored
Merge branch 'master' into deprecation-decorators
2 parents a310730 + 80972b5 commit cd9daa1

File tree

16 files changed

+252
-147
lines changed

16 files changed

+252
-147
lines changed

.github/workflows/code-style.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ jobs:
2222
code-style:
2323
runs-on: ubuntu-latest
2424
steps:
25-
- uses: actions/checkout@master
25+
- uses: actions/checkout@v2
2626
- uses: actions/setup-python@v2
2727
with:
28-
python-version: "3.7"
28+
python-version: "3.8"
2929
- run: |
3030
python -m pip install autopep8 "black==19.10b0" "isort==4.3.21"
3131
isort -rc .

CONTRIBUTING.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,17 @@ git merge upstream/master
203203
### Writing documentation
204204

205205
Ignite uses [Google style](http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html)
206-
for formatting docstrings. Length of line inside docstrings block must be limited to 120 characters.
206+
for formatting docstrings and
207+
208+
- [`.. versionadded::`] directive for adding new classes, class methods, functions,
209+
- [`.. versionchanged::`] directive for adding new arguments, changing internal behaviours, fixing bugs and
210+
- [`.. deprecated::`] directive for deprecations.
211+
212+
Length of line inside docstrings block must be limited to 120 characters.
213+
214+
[`.. versionadded::`]: https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded
215+
[`.. versionchanged::`]: https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionchanged
216+
[`.. deprecated::`]: https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-deprecated
207217

208218
#### Local documentation building and deploying
209219

docs/source/conf.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,10 @@
193193
# -- Options for intersphinx extension ---------------------------------------
194194

195195
# Example configuration for intersphinx: refer to the Python standard library.
196-
intersphinx_mapping = {"https://docs.python.org/3/": None}
196+
intersphinx_mapping = {
197+
"python": ("https://docs.python.org/3", None),
198+
"torch": ("https://pytorch.org/docs/stable/", None),
199+
}
197200

198201
# -- Options for todo extension ----------------------------------------------
199202

ignite/handlers/checkpoint.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ class Checkpoint(Serializable):
9898
details.
9999
include_self (bool): Whether to include the `state_dict` of this object in the checkpoint. If `True`, then
100100
there must not be another object in ``to_save`` with key ``checkpointer``.
101+
greater_or_equal (bool): if `True`, the latest equally scored model is stored. Otherwise, the first model.
102+
Default, `False`.
101103
102104
.. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
103105
torch.nn.parallel.DistributedDataParallel.html
@@ -245,6 +247,8 @@ def score_function(engine):
245247
trainer.run(data_loader, max_epochs=10)
246248
> ["best_model_9_val_acc=0.77.pt", "best_model_10_val_acc=0.78.pt", ]
247249
250+
.. versionchanged:: 0.4.3
251+
Added ``greater_or_equal`` parameter.
248252
"""
249253

250254
Item = NamedTuple("Item", [("priority", int), ("filename", str)])
@@ -261,6 +265,7 @@ def __init__(
261265
global_step_transform: Optional[Callable] = None,
262266
filename_pattern: Optional[str] = None,
263267
include_self: bool = False,
268+
greater_or_equal: bool = False,
264269
) -> None:
265270

266271
if to_save is not None: # for compatibility with ModelCheckpoint
@@ -301,6 +306,7 @@ def __init__(
301306
self.filename_pattern = filename_pattern
302307
self._saved = [] # type: List["Checkpoint.Item"]
303308
self.include_self = include_self
309+
self.greater_or_equal = greater_or_equal
304310

305311
def reset(self) -> None:
306312
"""Method to reset saved checkpoint names.
@@ -339,6 +345,12 @@ def _check_lt_n_saved(self, or_equal: bool = False) -> bool:
339345
return True
340346
return len(self._saved) < self.n_saved + int(or_equal)
341347

348+
def _compare_fn(self, new: Union[int, float]) -> bool:
349+
if self.greater_or_equal:
350+
return new >= self._saved[0].priority
351+
else:
352+
return new > self._saved[0].priority
353+
342354
def __call__(self, engine: Engine) -> None:
343355

344356
global_step = None
@@ -354,7 +366,7 @@ def __call__(self, engine: Engine) -> None:
354366
global_step = engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED)
355367
priority = global_step
356368

357-
if self._check_lt_n_saved() or self._saved[0].priority < priority:
369+
if self._check_lt_n_saved() or self._compare_fn(priority):
358370

359371
priority_str = f"{priority}" if isinstance(priority, numbers.Integral) else f"{priority:.4f}"
360372

ignite/metrics/accuracy.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,12 @@ def _check_type(self, output: Sequence[torch.Tensor]) -> None:
9292

9393

9494
class Accuracy(_BaseClassification):
95-
"""
96-
Calculates the accuracy for binary, multiclass and multilabel data.
95+
r"""Calculates the accuracy for binary, multiclass and multilabel data.
96+
97+
.. math:: \text{Accuracy} = \frac{ TP + TN }{ TP + TN + FP + FN }
98+
99+
where :math:`\text{TP}` is true positives, :math:`\text{TN}` is true negatives,
100+
:math:`\text{FP}` is false positives and :math:`\text{FN}` is false negatives.
97101
98102
- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
99103
- `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...).

ignite/metrics/confusion_matrix.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,9 @@ def normalize(matrix: torch.Tensor, average: str) -> torch.Tensor:
130130

131131

132132
def IoU(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambda:
133-
"""Calculates Intersection over Union using :class:`~ignite.metrics.ConfusionMatrix` metric.
133+
r"""Calculates Intersection over Union using :class:`~ignite.metrics.ConfusionMatrix` metric.
134+
135+
.. math:: \text{J}(A, B) = \frac{ \lvert A \cap B \rvert }{ \lvert A \cup B \rvert }
134136
135137
Args:
136138
cm (ConfusionMatrix): instance of confusion matrix metric

ignite/metrics/fbeta.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@ def Fbeta(
1717
output_transform: Optional[Callable] = None,
1818
device: Union[str, torch.device] = torch.device("cpu"),
1919
) -> MetricsLambda:
20-
"""Calculates F-beta score
20+
r"""Calculates F-beta score.
21+
22+
.. math::
23+
F_\beta = \left( 1 + \beta^2 \right) * \frac{ \text{precision} * \text{recall} }
24+
{ \left( \beta^2 * \text{precision} \right) + \text{recall} }
25+
26+
where :math:`\beta` is a positive real factor.
2127
2228
Args:
2329
beta (float): weight of precision in harmonic mean

ignite/metrics/mean_absolute_error.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99

1010

1111
class MeanAbsoluteError(Metric):
12-
"""
13-
Calculates the mean absolute error.
12+
r"""Calculates `the mean absolute error <https://en.wikipedia.org/wiki/Mean_absolute_error>`_.
13+
14+
.. math:: \text{MAE} = \frac{1}{N} \sum_{i=1}^N \lvert y_{i} - x_{i} \rvert
15+
16+
where :math:`y_{i}` is the prediction tensor and :math:`x_{i}` is ground true tensor.
1417
1518
- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
1619
"""

ignite/metrics/mean_pairwise_distance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010

1111

1212
class MeanPairwiseDistance(Metric):
13-
"""
14-
Calculates the mean pairwise distance: average of pairwise distances computed on provided batches.
13+
"""Calculates the mean :class:`~torch.nn.PairwiseDistance`.
14+
Average of pairwise distances computed on provided batches.
1515
1616
- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
1717
"""

ignite/metrics/mean_squared_error.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99

1010

1111
class MeanSquaredError(Metric):
12-
"""
13-
Calculates the mean squared error.
12+
r"""Calculates the `mean squared error <https://en.wikipedia.org/wiki/Mean_squared_error>`_.
13+
14+
.. math:: \text{MSE} = \frac{1}{N} \sum_{i=1}^N \left(y_{i} - x_{i} \right)^2
15+
16+
where :math:`y_{i}` is the prediction tensor and :math:`x_{i}` is ground true tensor.
1417
1518
- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
1619
"""

0 commit comments

Comments
 (0)