From ba2378e7d681aa1085d2fbec909afb4483e7b8be Mon Sep 17 00:00:00 2001 From: cyberslack_lee Date: Sat, 20 Jul 2024 11:52:20 +0800 Subject: [PATCH] [Typing][B-64] Add type annotations for `python/paddle/text/datasets/uci_housing.py` (#66057) --------- Co-authored-by: SigureMo --- python/paddle/text/datasets/uci_housing.py | 34 +++++++++++++++++----- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/python/paddle/text/datasets/uci_housing.py b/python/paddle/text/datasets/uci_housing.py index 42854bc81902c3..acebf28d33047c 100644 --- a/python/paddle/text/datasets/uci_housing.py +++ b/python/paddle/text/datasets/uci_housing.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal import numpy as np @@ -18,6 +21,12 @@ from paddle.dataset.common import _check_exists_and_download from paddle.io import Dataset +if TYPE_CHECKING: + import numpy.typing as npt + + from paddle._typing.dtype_like import _DTypeLiteral + + _UciHousingDataSetMode = Literal["train", "test"] __all__ = [] URL = 'http://paddlemodels.bj.bcebos.com/uci_housing/housing.data' @@ -45,11 +54,11 @@ class UCIHousing(Dataset): dataset Args: - data_file(str): path to data file, can be set None if - :attr:`download` is True. Default None + data_file(str|None): path to data file, can be set None if + :attr:`download` is True. Default None. mode(str): 'train' or 'test' mode. Default 'train'. download(bool): whether to download dataset automatically if - :attr:`data_file` is not set. Default True + :attr:`data_file` is not set. Default True. Returns: Dataset: instance of UCI housing dataset. @@ -93,7 +102,16 @@ class UCIHousing(Dataset): """ - def __init__(self, data_file=None, mode='train', download=True): + mode: _UciHousingDataSetMode + data_file: str | None + dtype: _DTypeLiteral + + def __init__( + self, + data_file: str | None = None, + mode: _UciHousingDataSetMode = 'train', + download: bool = True, + ) -> None: assert mode.lower() in [ 'train', 'test', @@ -114,7 +132,7 @@ def __init__(self, data_file=None, mode='train', download=True): self.dtype = paddle.get_default_dtype() - def _load_data(self, feature_num=14, ratio=0.8): + def _load_data(self, feature_num: int = 14, ratio: float = 0.8) -> None: data = np.fromfile(self.data_file, sep=' ') data = data.reshape(data.shape[0] // feature_num, feature_num) maximums, minimums, avgs = ( @@ -130,11 +148,13 @@ def _load_data(self, feature_num=14, ratio=0.8): elif self.mode == 'test': self.data = data[offset:] - def __getitem__(self, idx): + def __getitem__( + self, idx: int + ) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]: data = self.data[idx] return np.array(data[:-1]).astype(self.dtype), np.array( data[-1:] ).astype(self.dtype) - def __len__(self): + def __len__(self) -> int: return len(self.data)