From 036051538e7476b86908c12611dbf7c296528e7c Mon Sep 17 00:00:00 2001 From: cyberslack_lee Date: Fri, 19 Jul 2024 21:54:21 +0800 Subject: [PATCH] [Typing][B-63] Add type annotations for `python/paddle/text/datasets/movielens.py` (#66054) --- python/paddle/text/datasets/movielens.py | 65 ++++++++++++++++-------- 1 file changed, 45 insertions(+), 20 deletions(-) diff --git a/python/paddle/text/datasets/movielens.py b/python/paddle/text/datasets/movielens.py index 55572d33f8387..405a31aca83e2 100644 --- a/python/paddle/text/datasets/movielens.py +++ b/python/paddle/text/datasets/movielens.py @@ -11,15 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import re import zipfile +from typing import TYPE_CHECKING, Any, Literal import numpy as np from paddle.dataset.common import _check_exists_and_download from paddle.io import Dataset +if TYPE_CHECKING: + import numpy.typing as npt + + _MovieLensDataSetMode = Literal["train", "test"] __all__ = [] age_table = [1, 18, 25, 35, 45, 50, 56] @@ -33,7 +39,11 @@ class MovieInfo: Movie id, title and categories information are stored in MovieInfo. """ - def __init__(self, index, categories, title): + index: int + categories: list[str] + title: str + + def __init__(self, index: str, categories: list[str], title: str) -> None: self.index = int(index) self.categories = categories self.title = title @@ -48,14 +58,14 @@ def value(self, categories_dict, movie_title_dict): [movie_title_dict[w.lower()] for w in self.title.split()], ] - def __str__(self): + def __str__(self) -> str: return "" % ( self.index, self.title, self.categories, ) - def __repr__(self): + def __repr__(self) -> str: return self.__str__() @@ -64,7 +74,12 @@ class UserInfo: User id, gender, age, and job information are stored in UserInfo. """ - def __init__(self, index, gender, age, job_id): + index: int + is_male: bool + age: int + job_id: int + + def __init__(self, index: str, gender: str, age: str, job_id: str) -> None: self.index = int(index) self.is_male = gender == 'M' self.age = age_table.index(int(age)) @@ -81,7 +96,7 @@ def value(self): [self.job_id], ] - def __str__(self): + def __str__(self) -> str: return "" % ( self.index, "M" if self.is_male else "F", @@ -89,7 +104,7 @@ def __str__(self): self.job_id, ) - def __repr__(self): + def __repr__(self) -> str: return str(self) @@ -98,16 +113,16 @@ class Movielens(Dataset): Implementation of `Movielens 1-M `_ dataset. Args: - data_file(str): path to data tar file, can be set None if - :attr:`download` is True. Default None + data_file(str|None): path to data tar file, can be set None if + :attr:`download` is True. Default None. mode(str): 'train' or 'test' mode. Default 'train'. test_ratio(float): split ratio for test sample. Default 0.1. rand_seed(int): random seed. Default 0. download(bool): whether to download dataset automatically if - :attr:`data_file` is not set. Default True + :attr:`data_file` is not set. Default True. Returns: - Dataset: instance of Movielens 1-M dataset + Dataset: instance of Movielens 1-M dataset. Examples: @@ -149,14 +164,24 @@ class Movielens(Dataset): """ + mode: _MovieLensDataSetMode + data_file: str | None + test_ratio: float + rand_seed: int + movie_info: dict[int, MovieInfo] + movie_title_dict: dict[str, int] + categories_dict: dict[str, int] + user_info: dict[int, UserInfo] + data: list[list[float]] + def __init__( self, - data_file=None, - mode='train', - test_ratio=0.1, - rand_seed=0, - download=True, - ): + data_file: str | None = None, + mode: _MovieLensDataSetMode = 'train', + test_ratio: float = 0.1, + rand_seed: int = 0, + download: bool = True, + ) -> None: assert mode.lower() in [ 'train', 'test', @@ -179,7 +204,7 @@ def __init__( self._load_meta_info() self._load_data() - def _load_meta_info(self): + def _load_meta_info(self) -> None: pattern = re.compile(r'^(.*)\((\d+)\)$') self.movie_info = {} self.movie_title_dict = {} @@ -218,7 +243,7 @@ def _load_meta_info(self): index=uid, gender=gender, age=age, job_id=job ) - def _load_data(self): + def _load_data(self) -> None: self.data = [] is_test = self.mode == 'test' with zipfile.ZipFile(self.data_file) as package: @@ -241,9 +266,9 @@ def _load_data(self): + [[rating]] ) - def __getitem__(self, idx): + def __getitem__(self, idx: int) -> tuple[npt.NDArray[Any], ...]: data = self.data[idx] return tuple([np.array(d) for d in data]) - def __len__(self): + def __len__(self) -> int: return len(self.data)