[Typing][B-65,B-66] Add type annotations for python/paddle/text/datasets/{wmt14,wmt16}.py #66058

Merged 2 commits on Jul 19, 2024
66 changes: 56 additions & 10 deletions python/paddle/text/datasets/wmt14.py
@@ -11,14 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import tarfile
from typing import TYPE_CHECKING, Literal, overload

import numpy as np

from paddle.dataset.common import _check_exists_and_download
from paddle.io import Dataset

if TYPE_CHECKING:
import numpy.typing as npt

_Wmt14DataSetMode = Literal["train", "test", "gen"]
__all__ = []

URL_DEV_TEST = (
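The new imports above combine three pieces: `from __future__ import annotations` postpones annotation evaluation, the `TYPE_CHECKING` guard keeps `numpy.typing` out of runtime imports, and a `Literal` alias pins the accepted mode strings. A minimal standalone sketch of the same pattern (the `load` function and `_Mode` alias are illustrative only, not part of this PR):

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Literal

import numpy as np

if TYPE_CHECKING:
    # Only evaluated by static type checkers; numpy.typing is never
    # imported at runtime.
    import numpy.typing as npt

# The alias names the exact strings the parameter accepts, so a
# checker rejects e.g. mode="validate" at the call site.
_Mode = Literal["train", "test", "gen"]


def load(mode: _Mode = "train") -> npt.NDArray[np.int_]:
    # The future import defers evaluation of the signature above,
    # so referring to `npt` there is safe at runtime.
    return np.array([0, 1, 2])
```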
@@ -45,12 +51,12 @@ class WMT14(Dataset):
http://paddlemodels.bj.bcebos.com/wmt/wmt14.tgz .

Args:
data_file(str): path to data tar file, can be set None if
:attr:`download` is True. Default None
mode(str): 'train', 'test' or 'gen'. Default 'train'
data_file(str|None): path to data tar file, can be set None if
:attr:`download` is True. Default None.
mode(str): 'train', 'test' or 'gen'. Default 'train'.
dict_size(int): word dictionary size. Default -1.
download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True
:attr:`data_file` is not set. Default True.

Returns:
Dataset: Instance of WMT14 dataset
@@ -95,9 +101,22 @@ class WMT14(Dataset):

"""

mode: _Wmt14DataSetMode
data_file: str | None
dict_size: int
src_ids: list[list[int]]
trg_ids: list[list[int]]
trg_ids_next: list[list[int]]
src_dict: dict[str, int]
trg_dict: dict[str, int]

def __init__(
self, data_file=None, mode='train', dict_size=-1, download=True
):
self,
data_file: str | None = None,
mode: _Wmt14DataSetMode = 'train',
dict_size: int = -1,
download: bool = True,
) -> None:
assert mode.lower() in [
'train',
'test',
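With the annotated signature, call sites can be validated statically. A hypothetical usage sketch (constructing the dataset downloads it when `data_file` is None):

```python
from paddle.text.datasets import WMT14

# Accepted: "gen" is a member of Literal["train", "test", "gen"].
gen_set = WMT14(mode="gen", dict_size=30000)

# Flagged by a type checker before runtime: "valid" is a str but
# not a member of _Wmt14DataSetMode.
# bad_set = WMT14(mode="valid")
```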
@@ -119,8 +138,8 @@ def __init__(
self.dict_size = dict_size
self._load_data()

def _load_data(self):
def __to_dict(fd, size):
def _load_data(self) -> None:
def __to_dict(fd, size: int) -> dict[str, int]:
out_dict = {}
for line_count, line in enumerate(fd):
if line_count < size:
@@ -181,16 +200,43 @@ def __to_dict(fd, size):
self.trg_ids.append(trg_ids)
self.trg_ids_next.append(trg_ids_next)

def __getitem__(self, idx):
def __getitem__(
self, idx: int
) -> tuple[
npt.NDArray[np.int_],
npt.NDArray[np.int_],
npt.NDArray[np.int_],
]:
return (
np.array(self.src_ids[idx]),
np.array(self.trg_ids[idx]),
np.array(self.trg_ids_next[idx]),
)

def __len__(self):
def __len__(self) -> int:
return len(self.src_ids)

@overload
def get_dict(
self, reverse: Literal[True] = ...
) -> tuple[dict[int, str], dict[int, str]]:
...

@overload
def get_dict(
self, reverse: Literal[False] = ...
) -> tuple[dict[str, int], dict[str, int]]:
...

@overload
def get_dict(
self, reverse: bool = ...
) -> (
tuple[dict[str, int], dict[str, int]]
| tuple[dict[int, str], dict[int, str]]
):
...

def get_dict(self, reverse=False):
"""
Get the source and target dictionary.
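The three `get_dict` overloads let a checker derive the return type from the literal value of `reverse`, while the final `bool` overload covers flags only known as plain bools. An illustrative call-site sketch (inferred types in comments; constructing the dataset downloads it):

```python
from paddle.text.datasets import WMT14

ds = WMT14(mode="train", dict_size=30000)

# Literal[False] overload: forward word-to-id dictionaries.
src, trg = ds.get_dict(reverse=False)    # dict[str, int], dict[str, int]

# Literal[True] overload: reversed id-to-word dictionaries.
rsrc, rtrg = ds.get_dict(reverse=True)   # dict[int, str], dict[int, str]

# A plain bool matches neither Literal overload, so the fallback
# applies and the checker infers the union of both tuple types.
flag = len(src) > 20000
pair = ds.get_dict(reverse=flag)
```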
89 changes: 77 additions & 12 deletions python/paddle/text/datasets/wmt16.py
@@ -12,17 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import os
import tarfile
from collections import defaultdict
from typing import TYPE_CHECKING, Literal, overload

import numpy as np

import paddle
from paddle.dataset.common import _check_exists_and_download
from paddle.io import Dataset

if TYPE_CHECKING:
import numpy.typing as npt

_Wmt16DataSetMode = Literal["train", "test", "val"]
_Wmt16Language = Literal["en", "de"]
__all__ = []

DATA_URL = "http://paddlemodels.bj.bcebos.com/wmt/wmt16.tar.gz"
@@ -57,7 +64,7 @@ class WMT16(Dataset):
}

Args:
data_file(str): path to data tar file, can be set None if
data_file(str|None): path to data tar file, can be set None if
:attr:`download` is True. Default None.
mode(str): 'train', 'test' or 'val'. Default 'train'.
src_dict_size(int): word dictionary size for source language word. Default -1.
@@ -109,15 +116,26 @@ class WMT16(Dataset):
55 24 25
"""

mode: _Wmt16DataSetMode
data_file: str | None
lang: _Wmt16Language
src_dict_size: int
trg_dict_size: int
src_dict: dict[str, int]
trg_dict: dict[str, int]
src_ids: list[list[int]]
trg_ids: list[list[int]]
trg_ids_next: list[list[int]]

def __init__(
self,
data_file=None,
mode='train',
src_dict_size=-1,
trg_dict_size=-1,
lang='en',
download=True,
):
data_file: str | None = None,
mode: _Wmt16DataSetMode = 'train',
src_dict_size: int = -1,
trg_dict_size: int = -1,
lang: _Wmt16Language = 'en',
download: bool = True,
) -> None:
assert mode.lower() in [
'train',
'test',
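As with WMT14, the `Literal` aliases make invalid `mode` and `lang` arguments a static error. A hypothetical construction sketch (downloads the dataset when `data_file` is None):

```python
from paddle.text.datasets import WMT16

# "val" and "de" are checked against the Literal aliases above.
de_en = WMT16(
    mode="val",
    src_dict_size=30000,
    trg_dict_size=30000,
    lang="de",
)

# Rejected by a checker: "fr" is not in Literal["en", "de"].
# bad = WMT16(lang="fr")
```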
@@ -153,6 +171,27 @@ def __init__(
# load data
self.data = self._load_data()

@overload
def _load_dict(
self, lang: _Wmt16Language, dict_size: int, reverse: Literal[True] = ...
) -> dict[int, str]:
...

@overload
def _load_dict(
self,
lang: _Wmt16Language,
dict_size: int,
reverse: Literal[False] = ...,
) -> dict[str, int]:
...

@overload
def _load_dict(
self, lang: _Wmt16Language, dict_size: int, reverse: bool = ...
) -> dict[int, str] | dict[str, int]:
...

def _load_dict(self, lang, dict_size, reverse=False):
dict_path = os.path.join(
paddle.dataset.common.DATA_HOME,
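`_load_dict` uses the same three-step overload ladder as `get_dict`: a precise signature for each literal, plus a `bool` fallback returning the union. A self-contained sketch of the idiom with an invented `decode` helper (not part of this PR):

```python
from __future__ import annotations

from typing import Literal, overload


@overload
def decode(data: bytes, as_text: Literal[True]) -> str: ...
@overload
def decode(data: bytes, as_text: Literal[False] = ...) -> bytes: ...
@overload
def decode(data: bytes, as_text: bool = ...) -> str | bytes: ...


def decode(data: bytes, as_text: bool = False) -> str | bytes:
    # Checkers try overloads top to bottom: a literal True or False
    # picks a precise signature; a plain bool falls through to the
    # union-returning fallback.
    return data.decode() if as_text else data
```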
@@ -174,7 +213,9 @@ def _load_dict(self, lang, dict_size, reverse=False):
word_dict[line.strip().decode()] = idx
return word_dict

def _build_dict(self, dict_path, dict_size, lang):
def _build_dict(
self, dict_path: str, dict_size: int, lang: _Wmt16Language
) -> None:
word_dict = defaultdict(int)
with tarfile.open(self.data_file, mode="r") as f:
for line in f.extractfile("wmt16/train"):
@@ -196,7 +237,7 @@ def _build_dict(self, dict_path, dict_size, lang):
fout.write(word[0].encode())
fout.write(b'\n')

def _load_data(self):
def _load_data(self) -> None:
# the index for start mark, end mark, and unk are the same in source
# language and target language. Here uses the source language
# dictionary to determine their indices.
@@ -233,16 +274,40 @@ def _load_data(self):
self.trg_ids.append(trg_ids)
self.trg_ids_next.append(trg_ids_next)

def __getitem__(self, idx):
def __getitem__(
self, idx: int
) -> tuple[
npt.NDArray[np.int_],
npt.NDArray[np.int_],
npt.NDArray[np.int_],
]:
return (
np.array(self.src_ids[idx]),
np.array(self.trg_ids[idx]),
np.array(self.trg_ids_next[idx]),
)

def __len__(self):
def __len__(self) -> int:
return len(self.src_ids)

@overload
def get_dict(
self, lang: _Wmt16Language, reverse: Literal[True] = ...
) -> dict[int, str]:
...

@overload
def get_dict(
self, lang: _Wmt16Language, reverse: Literal[False] = ...
) -> dict[str, int]:
...

@overload
def get_dict(
self, lang: _Wmt16Language, reverse: bool = ...
) -> dict[int, str] | dict[str, int]:
...

def get_dict(self, lang, reverse=False):
"""
return the word dictionary for the specified language.
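WMT16's `get_dict` adds the `lang` selector on top of the `reverse` overloads, so the checker infers both which dictionary is returned and its direction. A hypothetical call-site sketch (inferred types in comments; constructing the dataset downloads it):

```python
from paddle.text.datasets import WMT16

wmt16 = WMT16(mode="train", src_dict_size=30000, trg_dict_size=30000)

# lang picks which dictionary; reverse picks its direction.
en_ids = wmt16.get_dict("en", reverse=False)   # dict[str, int]
de_words = wmt16.get_dict("de", reverse=True)  # dict[int, str]

unk_id = en_ids.get("neural", -1)
```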