Skip to content

Commit

Permalink
[Typing][B-63] Add type annotations for `python/paddle/text/datasets/…
Browse files Browse the repository at this point in the history
…movielens.py` (#66054)
  • Loading branch information
enkilee authored Jul 19, 2024
1 parent 18e3c83 commit 0360515
Showing 1 changed file with 45 additions and 20 deletions.
65 changes: 45 additions & 20 deletions python/paddle/text/datasets/movielens.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import re
import zipfile
from typing import TYPE_CHECKING, Any, Literal

import numpy as np

from paddle.dataset.common import _check_exists_and_download
from paddle.io import Dataset

if TYPE_CHECKING:
import numpy.typing as npt

_MovieLensDataSetMode = Literal["train", "test"]
__all__ = []

age_table = [1, 18, 25, 35, 45, 50, 56]
Expand All @@ -33,7 +39,11 @@ class MovieInfo:
Movie id, title and categories information are stored in MovieInfo.
"""

def __init__(self, index, categories, title):
index: int
categories: list[str]
title: str

def __init__(self, index: str, categories: list[str], title: str) -> None:
self.index = int(index)
self.categories = categories
self.title = title
Expand All @@ -48,14 +58,14 @@ def value(self, categories_dict, movie_title_dict):
[movie_title_dict[w.lower()] for w in self.title.split()],
]

def __str__(self):
def __str__(self) -> str:
return "<MovieInfo id(%d), title(%s), categories(%s)>" % (
self.index,
self.title,
self.categories,
)

def __repr__(self):
def __repr__(self) -> str:
return self.__str__()


Expand All @@ -64,7 +74,12 @@ class UserInfo:
User id, gender, age, and job information are stored in UserInfo.
"""

def __init__(self, index, gender, age, job_id):
index: int
is_male: bool
age: int
job_id: int

def __init__(self, index: str, gender: str, age: str, job_id: str) -> None:
self.index = int(index)
self.is_male = gender == 'M'
self.age = age_table.index(int(age))
Expand All @@ -81,15 +96,15 @@ def value(self):
[self.job_id],
]

def __str__(self):
def __str__(self) -> str:
return "<UserInfo id(%d), gender(%s), age(%d), job(%d)>" % (
self.index,
"M" if self.is_male else "F",
age_table[self.age],
self.job_id,
)

def __repr__(self):
def __repr__(self) -> str:
return str(self)


Expand All @@ -98,16 +113,16 @@ class Movielens(Dataset):
Implementation of `Movielens 1-M <https://grouplens.org/datasets/movielens/1m/>`_ dataset.
Args:
data_file(str): path to data tar file, can be set None if
:attr:`download` is True. Default None
data_file(str|None): path to data tar file, can be set None if
:attr:`download` is True. Default None.
mode(str): 'train' or 'test' mode. Default 'train'.
test_ratio(float): split ratio for test sample. Default 0.1.
rand_seed(int): random seed. Default 0.
download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True
:attr:`data_file` is not set. Default True.
Returns:
Dataset: instance of Movielens 1-M dataset
Dataset: instance of Movielens 1-M dataset.
Examples:
Expand Down Expand Up @@ -149,14 +164,24 @@ class Movielens(Dataset):
"""

mode: _MovieLensDataSetMode
data_file: str | None
test_ratio: float
rand_seed: int
movie_info: dict[int, MovieInfo]
movie_title_dict: dict[str, int]
categories_dict: dict[str, int]
user_info: dict[int, UserInfo]
data: list[list[float]]

def __init__(
self,
data_file=None,
mode='train',
test_ratio=0.1,
rand_seed=0,
download=True,
):
data_file: str | None = None,
mode: _MovieLensDataSetMode = 'train',
test_ratio: float = 0.1,
rand_seed: int = 0,
download: bool = True,
) -> None:
assert mode.lower() in [
'train',
'test',
Expand All @@ -179,7 +204,7 @@ def __init__(
self._load_meta_info()
self._load_data()

def _load_meta_info(self):
def _load_meta_info(self) -> None:
pattern = re.compile(r'^(.*)\((\d+)\)$')
self.movie_info = {}
self.movie_title_dict = {}
Expand Down Expand Up @@ -218,7 +243,7 @@ def _load_meta_info(self):
index=uid, gender=gender, age=age, job_id=job
)

def _load_data(self):
def _load_data(self) -> None:
self.data = []
is_test = self.mode == 'test'
with zipfile.ZipFile(self.data_file) as package:
Expand All @@ -241,9 +266,9 @@ def _load_data(self):
+ [[rating]]
)

def __getitem__(self, idx):
def __getitem__(self, idx: int) -> tuple[npt.NDArray[Any], ...]:
data = self.data[idx]
return tuple([np.array(d) for d in data])

def __len__(self):
def __len__(self) -> int:
return len(self.data)

0 comments on commit 0360515

Please sign in to comment.