Skip to content

Commit

Permalink
[Typing][A-91] Add type annotations for paddle/metric/metrics.py (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Asthestarsfalll authored Jul 2, 2024
1 parent 18dd803 commit 888f213
Showing 1 changed file with 83 additions and 34 deletions.
117 changes: 83 additions & 34 deletions python/paddle/metric/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import abc
from typing import TYPE_CHECKING, Any, Literal, Sequence

import numpy as np

Expand All @@ -24,6 +26,12 @@
from ..base.layer_helper import LayerHelper
from ..framework import in_dynamic_mode

if TYPE_CHECKING:
import numpy.typing as npt

from paddle import Tensor


__all__ = []


Expand Down Expand Up @@ -114,11 +122,11 @@ class Metric(metaclass=abc.ABCMeta):
... return accs
"""

def __init__(self):
def __init__(self) -> None:
pass

@abc.abstractmethod
def reset(self):
def reset(self) -> None:
"""
Reset states and result
"""
Expand All @@ -127,7 +135,7 @@ def reset(self):
)

@abc.abstractmethod
def update(self, *args):
def update(self, *args: Any) -> None:
"""
Update states for metric
Expand All @@ -143,7 +151,7 @@ def update(self, *args):
)

@abc.abstractmethod
def accumulate(self):
def accumulate(self) -> Any:
"""
Accumulates statistics, computes and returns the metric value
"""
Expand All @@ -152,15 +160,15 @@ def accumulate(self):
)

@abc.abstractmethod
def name(self):
def name(self) -> str:
"""
Returns metric name
"""
raise NotImplementedError(
f"function 'name' not implemented in {self.__class__.__name__}."
)

def compute(self, *args):
def compute(self, *args: Any) -> Any:
"""
This API is advanced usage to accelerate metric calculating, calculations
from outputs of model to the states which should be updated by Metric can
Expand Down Expand Up @@ -189,7 +197,7 @@ class Accuracy(Metric):
Args:
topk (list[int]|tuple[int]): Number of top elements to look at
for computing accuracy. Default is (1,).
name (str, optional): String name of the metric instance. Default
name (str|None, optional): String name of the metric instance. Default
is `acc`.
Examples:
Expand All @@ -216,6 +224,7 @@ class Accuracy(Metric):
.. code-block:: python
:name: code-model-api-example
>>> # doctest: +TIMEOUT(80)
>>> import paddle
>>> from paddle.static import InputSpec
>>> import paddle.vision.transforms as T
Expand All @@ -238,14 +247,23 @@ class Accuracy(Metric):
"""

def __init__(self, topk=(1,), name=None, *args, **kwargs):
topk: Sequence[int]
maxk: int

def __init__(
self,
topk: Sequence[int] = (1,),
name: str | None = None,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.topk = topk
self.maxk = max(topk)
self._init_name(name)
self.reset()

def compute(self, pred, label, *args):
def compute(self, pred: Tensor, label: Tensor, *args: Any) -> Tensor:
"""
Compute the top-k (maximum value in `topk`) indices.
Expand Down Expand Up @@ -276,7 +294,7 @@ def compute(self, pred, label, *args):
correct = pred == label.astype(pred.dtype)
return paddle.cast(correct, dtype='float32')

def update(self, correct, *args):
def update(self, correct: Tensor, *args: Any) -> Tensor:
"""
Update the metrics states (correct count and total count), in order to
calculate cumulative accuracy of all instances. This function also
Expand All @@ -300,14 +318,14 @@ def update(self, correct, *args):
accs = accs[0] if len(self.topk) == 1 else accs
return accs

def reset(self):
def reset(self) -> None:
"""
Resets all of the metric state.
"""
self.total = [0.0] * len(self.topk)
self.count = [0] * len(self.topk)

def accumulate(self):
def accumulate(self) -> list[float]:
"""
Computes and returns the accumulated metric.
"""
Expand All @@ -318,14 +336,14 @@ def accumulate(self):
res = res[0] if len(self.topk) == 1 else res
return res

def _init_name(self, name):
def _init_name(self, name: str | None) -> None:
name = name or 'acc'
if self.maxk != 1:
self._name = [f'{name}_top{k}' for k in self.topk]
else:
self._name = [name]

def name(self):
def name(self) -> list[str]:
"""
Return name of metric instance.
"""
Expand Down Expand Up @@ -397,13 +415,22 @@ class Precision(Metric):
>>> model.fit(data, batch_size=16)
"""

def __init__(self, name='precision', *args, **kwargs):
tp: int
fp: int

def __init__(
self, name: str = 'precision', *args: Any, **kwargs: Any
) -> None:
super().__init__(*args, **kwargs)
self.tp = 0 # true positive
self.fp = 0 # false positive
self._name = name

def update(self, preds, labels):
def update(
self,
preds: npt.NDArray[np.float32 | np.float64] | Tensor,
labels: npt.NDArray[np.int32 | np.int64] | Tensor,
) -> None:
"""
Update the states based on the current mini-batch prediction results.
Expand Down Expand Up @@ -437,14 +464,14 @@ def update(self, preds, labels):
else:
self.fp += 1

def reset(self):
def reset(self) -> None:
"""
Resets all of the metric state.
"""
self.tp = 0
self.fp = 0

def accumulate(self):
def accumulate(self) -> float:
"""
Calculate the final precision.
Expand All @@ -454,7 +481,7 @@ def accumulate(self):
ap = self.tp + self.fp
return float(self.tp) / ap if ap != 0 else 0.0

def name(self):
def name(self) -> str:
"""
Returns metric name
"""
Expand Down Expand Up @@ -529,13 +556,20 @@ class Recall(Metric):
>>> model.fit(data, batch_size=16)
"""

def __init__(self, name='recall', *args, **kwargs):
tp: int
fn: int

def __init__(self, name: str = 'recall', *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.tp = 0 # true positive
self.fn = 0 # false negative
self._name = name

def update(self, preds, labels):
def update(
self,
preds: npt.NDArray[np.float32 | np.float64] | Tensor,
labels: npt.NDArray[np.int32 | np.int64] | Tensor,
) -> None:
"""
Update the states based on the current mini-batch prediction results.
Expand Down Expand Up @@ -569,7 +603,7 @@ def update(self, preds, labels):
else:
self.fn += 1

def accumulate(self):
def accumulate(self) -> float:
"""
Calculate the final recall.
Expand All @@ -579,14 +613,14 @@ def accumulate(self):
recall = self.tp + self.fn
return float(self.tp) / recall if recall != 0 else 0.0

def reset(self):
def reset(self) -> None:
"""
Resets all of the metric state.
"""
self.tp = 0
self.fn = 0

def name(self):
def name(self) -> str:
"""
Returns metric name
"""
Expand All @@ -612,7 +646,6 @@ class Auc(Metric):
'ROC' or 'PR' for the Precision-Recall-curve. Default is 'ROC'.
num_thresholds (int): The number of thresholds to use when
discretizing the roc curve. Default is 4095.
'ROC' or 'PR' for the Precision-Recall-curve. Default is 'ROC'.
name (str, optional): String name of the metric instance. Default
is `auc`.
Expand Down Expand Up @@ -675,8 +708,13 @@ class Auc(Metric):
"""

def __init__(
self, curve='ROC', num_thresholds=4095, name='auc', *args, **kwargs
):
self,
curve: Literal['ROC', 'PR'] = 'ROC',
num_thresholds: int = 4095,
name: str = 'auc',
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self._curve = curve
self._num_thresholds = num_thresholds
Expand All @@ -686,7 +724,11 @@ def __init__(
self._stat_neg = np.zeros(_num_pred_buckets)
self._name = name

def update(self, preds, labels):
def update(
self,
preds: npt.NDArray[np.float32 | np.float64] | Tensor,
labels: npt.NDArray[np.int32 | np.int64] | Tensor,
) -> None:
"""
Update the auc curve with the given predictions and labels.
Expand Down Expand Up @@ -718,10 +760,10 @@ def update(self, preds, labels):
self._stat_neg[bin_idx] += 1.0

@staticmethod
def trapezoid_area(x1, x2, y1, y2):
def trapezoid_area(x1: float, x2: float, y1: float, y2: float) -> float:
return abs(x1 - x2) * (y1 + y2) / 2.0

def accumulate(self):
def accumulate(self) -> float:
"""
Return the area (a float score) under auc curve
Expand All @@ -747,22 +789,29 @@ def accumulate(self):
auc / tot_pos / tot_neg if tot_pos > 0.0 and tot_neg > 0.0 else 0.0
)

def reset(self):
def reset(self) -> None:
"""
Reset states and result
"""
_num_pred_buckets = self._num_thresholds + 1
self._stat_pos = np.zeros(_num_pred_buckets)
self._stat_neg = np.zeros(_num_pred_buckets)

def name(self):
def name(self) -> str:
"""
Returns metric name
"""
return self._name


def accuracy(input, label, k=1, correct=None, total=None, name=None):
def accuracy(
input: Tensor,
label: Tensor,
k: int = 1,
correct: Tensor | None = None,
total: Tensor | None = None,
name: str | None = None,
) -> Tensor:
"""
accuracy layer.
Refer to the https://en.wikipedia.org/wiki/Precision_and_recall
Expand All @@ -778,7 +827,7 @@ def accuracy(input, label, k=1, correct=None, total=None, name=None):
k(int, optional): The top k predictions for each class will be checked. Data type is int64 or int32.
correct(Tensor, optional): The correct predictions count. A Tensor with type int64 or int32.
total(Tensor, optional): The total entries count. A tensor with type int64 or int32.
name(str, optional): The default value is None. Normally there is no need for
name(str|None, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`
Returns:
Expand Down

0 comments on commit 888f213

Please sign in to comment.