From e268c925828e229a198ceca504f025aa0deca856 Mon Sep 17 00:00:00 2001 From: enkilee Date: Mon, 15 Jul 2024 16:07:09 +0800 Subject: [PATCH 1/2] fix --- python/paddle/text/datasets/wmt14.py | 48 +++++++++++++++++------ python/paddle/text/datasets/wmt16.py | 58 +++++++++++++++++++++------- 2 files changed, 81 insertions(+), 25 deletions(-) diff --git a/python/paddle/text/datasets/wmt14.py b/python/paddle/text/datasets/wmt14.py index a2d7c9ebe5871..37bf657c717f4 100644 --- a/python/paddle/text/datasets/wmt14.py +++ b/python/paddle/text/datasets/wmt14.py @@ -11,14 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import tarfile +from typing import TYPE_CHECKING, Literal import numpy as np from paddle.dataset.common import _check_exists_and_download from paddle.io import Dataset +if TYPE_CHECKING: + import numpy.typing as npt + + _Wmt14DataSetMode = Literal["train", "test", "gen"] __all__ = [] URL_DEV_TEST = ( @@ -45,12 +51,12 @@ class WMT14(Dataset): http://paddlemodels.bj.bcebos.com/wmt/wmt14.tgz . Args: - data_file(str): path to data tar file, can be set None if - :attr:`download` is True. Default None - mode(str): 'train', 'test' or 'gen'. Default 'train' + data_file(str|None): path to data tar file, can be set None if + :attr:`download` is True. Default None. + mode(str): 'train', 'test' or 'gen'. Default 'train'. dict_size(int): word dictionary size. Default -1. download(bool): whether to download dataset automatically if - :attr:`data_file` is not set. Default True + :attr:`data_file` is not set. Default True. Returns: Dataset: Instance of WMT14 dataset @@ -95,9 +101,21 @@ class WMT14(Dataset): """ + mode: _Wmt14DataSetMode + data_file: str | None + dict_size: int + out_dict: dict[str, int] + src_ids: list[npt.NDArray[np.int_]] + trg_ids: list[npt.NDArray[np.int_]] + trg_ids_next: list[npt.NDArray[np.int_]] + def __init__( - self, data_file=None, mode='train', dict_size=-1, download=True - ): + self, + data_file: str | None = None, + mode: _Wmt14DataSetMode = 'train', + dict_size: int = -1, + download: bool = True, + ) -> None: assert mode.lower() in [ 'train', 'test', @@ -119,8 +137,8 @@ def __init__( self.dict_size = dict_size self._load_data() - def _load_data(self): - def __to_dict(fd, size): + def _load_data(self) -> None: + def __to_dict(fd, size: int) -> dict: out_dict = {} for line_count, line in enumerate(fd): if line_count < size: @@ -181,17 +199,25 @@ def __to_dict(fd, size): self.trg_ids.append(trg_ids) self.trg_ids_next.append(trg_ids_next) - def __getitem__(self, idx): + def __getitem__( + self, idx: int + ) -> tuple[ + npt.NDArray[np.int_], + npt.NDArray[np.int_], + npt.NDArray[np.int_], + ]: return ( np.array(self.src_ids[idx]), np.array(self.trg_ids[idx]), np.array(self.trg_ids_next[idx]), ) - def __len__(self): + def __len__(self) -> int: return len(self.src_ids) - def get_dict(self, reverse=False): + def get_dict( + self, reverse: bool = False + ) -> tuple[dict[str, int], dict[int, str]]: """ Get the source and target dictionary. diff --git a/python/paddle/text/datasets/wmt16.py b/python/paddle/text/datasets/wmt16.py index b8bcb98bef13a..fd84bdcbbb91f 100644 --- a/python/paddle/text/datasets/wmt16.py +++ b/python/paddle/text/datasets/wmt16.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import os import tarfile from collections import defaultdict +from typing import TYPE_CHECKING, Literal import numpy as np @@ -23,6 +25,11 @@ from paddle.dataset.common import _check_exists_and_download from paddle.io import Dataset +if TYPE_CHECKING: + import numpy.typing as npt + + _Wmt16DataSetMode = Literal["train", "test", "val"] + _Wmt16Language = Literal["en", "de"] __all__ = [] DATA_URL = "http://paddlemodels.bj.bcebos.com/wmt/wmt16.tar.gz" @@ -57,7 +64,7 @@ class WMT16(Dataset): } Args: - data_file(str): path to data tar file, can be set None if + data_file(str|None): path to data tar file, can be set None if :attr:`download` is True. Default None. mode(str): 'train', 'test' or 'val'. Default 'train'. src_dict_size(int): word dictionary size for source language word. Default -1. @@ -109,15 +116,26 @@ class WMT16(Dataset): 55 24 25 """ + mode: _Wmt16DataSetMode + data_file: str | None + lang: _Wmt16Language + src_dict_size: int + trg_dict_size: int + src_dict: dict[str, int] + trg_dict: dict[str, int] + src_ids: list[npt.NDArray[np.int_]] + trg_ids: list[npt.NDArray[np.int_]] + trg_ids_next: list[npt.NDArray[np.int_]] + def __init__( self, - data_file=None, - mode='train', - src_dict_size=-1, - trg_dict_size=-1, - lang='en', - download=True, - ): + data_file: str | None = None, + mode: _Wmt16DataSetMode = 'train', + src_dict_size: int = -1, + trg_dict_size: int = -1, + lang: _Wmt16Language = 'en', + download: bool = True, + ) -> None: assert mode.lower() in [ 'train', 'test', @@ -153,7 +171,9 @@ def __init__( # load data self.data = self._load_data() - def _load_dict(self, lang, dict_size, reverse=False): + def _load_dict( + self, lang: _Wmt16Language, dict_size: int, reverse: bool = False + ) -> dict[str, int] | dict[int, str]: dict_path = os.path.join( paddle.dataset.common.DATA_HOME, "wmt16/%s_%d.dict" % (lang, dict_size), @@ -174,7 +194,9 @@ def _load_dict(self, lang, dict_size, reverse=False): word_dict[line.strip().decode()] = idx return word_dict - def _build_dict(self, dict_path, dict_size, lang): + def _build_dict( + self, dict_path: str, dict_size: int, lang: _Wmt16Language + ) -> None: word_dict = defaultdict(int) with tarfile.open(self.data_file, mode="r") as f: for line in f.extractfile("wmt16/train"): @@ -196,7 +218,7 @@ def _build_dict(self, dict_path, dict_size, lang): fout.write(word[0].encode()) fout.write(b'\n') - def _load_data(self): + def _load_data(self) -> None: # the index for start mark, end mark, and unk are the same in source # language and target language. Here uses the source language # dictionary to determine their indices. @@ -233,17 +255,25 @@ def _load_data(self): self.trg_ids.append(trg_ids) self.trg_ids_next.append(trg_ids_next) - def __getitem__(self, idx): + def __getitem__( + self, idx: int + ) -> tuple[ + npt.NDArray[np.int_], + npt.NDArray[np.int_], + npt.NDArray[np.int_], + ]: return ( np.array(self.src_ids[idx]), np.array(self.trg_ids[idx]), np.array(self.trg_ids_next[idx]), ) - def __len__(self): + def __len__(self) -> int: return len(self.src_ids) - def get_dict(self, lang, reverse=False): + def get_dict( + self, lang: _Wmt16Language, reverse: bool = False + ) -> dict[str, int] | dict[int, str]: """ return the word dictionary for the specified language. From b45effff3a0ec4461ca9a5538f67346e7b8c6939 Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 16 Jul 2024 14:52:35 +0800 Subject: [PATCH 2/2] fix --- python/paddle/text/datasets/wmt14.py | 36 +++++++++++++++----- python/paddle/text/datasets/wmt16.py | 51 +++++++++++++++++++++++----- 2 files changed, 71 insertions(+), 16 deletions(-) diff --git a/python/paddle/text/datasets/wmt14.py b/python/paddle/text/datasets/wmt14.py index 37bf657c717f4..c499aeeb9e8c0 100644 --- a/python/paddle/text/datasets/wmt14.py +++ b/python/paddle/text/datasets/wmt14.py @@ -14,7 +14,7 @@ from __future__ import annotations import tarfile -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, overload import numpy as np @@ -104,10 +104,11 @@ class WMT14(Dataset): mode: _Wmt14DataSetMode data_file: str | None dict_size: int - out_dict: dict[str, int] - src_ids: list[npt.NDArray[np.int_]] - trg_ids: list[npt.NDArray[np.int_]] - trg_ids_next: list[npt.NDArray[np.int_]] + src_ids: list[list[int]] + trg_ids: list[list[int]] + trg_ids_next: list[list[int]] + src_dict: dict[str, int] + trg_dict: dict[str, int] def __init__( self, @@ -138,7 +139,7 @@ def __init__( self._load_data() def _load_data(self) -> None: - def __to_dict(fd, size: int) -> dict: + def __to_dict(fd, size: int) -> dict[str, int]: out_dict = {} for line_count, line in enumerate(fd): if line_count < size: @@ -215,9 +216,28 @@ def __getitem__( def __len__(self) -> int: return len(self.src_ids) + @overload def get_dict( - self, reverse: bool = False - ) -> tuple[dict[str, int], dict[int, str]]: + self, reverse: Literal[True] = ... + ) -> tuple[dict[int, str], dict[int, str]]: + ... + + @overload + def get_dict( + self, reverse: Literal[False] = ... + ) -> tuple[dict[str, int], dict[str, int]]: + ... + + @overload + def get_dict( + self, reverse: bool = ... + ) -> ( + tuple[dict[str, int], dict[str, int]] + | tuple[dict[int, str], dict[int, str]] + ): + ... + + def get_dict(self, reverse=False): """ Get the source and target dictionary. diff --git a/python/paddle/text/datasets/wmt16.py b/python/paddle/text/datasets/wmt16.py index fd84bdcbbb91f..7df0788c10a56 100644 --- a/python/paddle/text/datasets/wmt16.py +++ b/python/paddle/text/datasets/wmt16.py @@ -17,7 +17,7 @@ import os import tarfile from collections import defaultdict -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, overload import numpy as np @@ -123,9 +123,9 @@ class WMT16(Dataset): trg_dict_size: int src_dict: dict[str, int] trg_dict: dict[str, int] - src_ids: list[npt.NDArray[np.int_]] - trg_ids: list[npt.NDArray[np.int_]] - trg_ids_next: list[npt.NDArray[np.int_]] + src_ids: list[list[int]] + trg_ids: list[list[int]] + trg_ids_next: list[list[int]] def __init__( self, @@ -171,9 +171,28 @@ def __init__( # load data self.data = self._load_data() + @overload def _load_dict( - self, lang: _Wmt16Language, dict_size: int, reverse: bool = False - ) -> dict[str, int] | dict[int, str]: + self, lang: _Wmt16Language, dict_size: int, reverse: Literal[True] = ... + ) -> dict[int, str]: + ... + + @overload + def _load_dict( + self, + lang: _Wmt16Language, + dict_size: int, + reverse: Literal[False] = ..., + ) -> dict[str, int]: + ... + + @overload + def _load_dict( + self, lang: _Wmt16Language, dict_size: int, reverse: bool = ... + ) -> dict[int, str] | dict[str, int]: + ... + + def _load_dict(self, lang, dict_size, reverse=False): dict_path = os.path.join( paddle.dataset.common.DATA_HOME, "wmt16/%s_%d.dict" % (lang, dict_size), @@ -271,9 +290,25 @@ def __getitem__( def __len__(self) -> int: return len(self.src_ids) + @overload def get_dict( - self, lang: _Wmt16Language, reverse: bool = False - ) -> dict[str, int] | dict[int, str]: + self, lang: _Wmt16Language, reverse: Literal[True] = ... + ) -> dict[int, str]: + ... + + @overload + def get_dict( + self, lang: _Wmt16Language, reverse: Literal[False] = ... + ) -> dict[str, int]: + ... + + @overload + def get_dict( + self, lang: _Wmt16Language, reverse: bool = ... + ) -> dict[int, str] | dict[str, int]: + ... + + def get_dict(self, lang, reverse=False): """ return the word dictionary for the specified language.