Skip to content

Commit

Permalink
[Typing][B-64] Add type annotations for `python/paddle/text/datasets/…
Browse files Browse the repository at this point in the history
…uci_housing.py` (PaddlePaddle#66057)


---------

Co-authored-by: SigureMo <sigure.qaq@gmail.com>
  • Loading branch information
2 people authored and lixcli committed Jul 22, 2024
1 parent a47c97b commit ba2378e
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions python/paddle/text/datasets/uci_housing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal

import numpy as np

import paddle
from paddle.dataset.common import _check_exists_and_download
from paddle.io import Dataset

if TYPE_CHECKING:
import numpy.typing as npt

from paddle._typing.dtype_like import _DTypeLiteral

_UciHousingDataSetMode = Literal["train", "test"]
__all__ = []

URL = 'http://paddlemodels.bj.bcebos.com/uci_housing/housing.data'
Expand Down Expand Up @@ -45,11 +54,11 @@ class UCIHousing(Dataset):
dataset
Args:
data_file(str): path to data file, can be set None if
:attr:`download` is True. Default None
data_file(str|None): path to data file, can be set None if
:attr:`download` is True. Default None.
mode(str): 'train' or 'test' mode. Default 'train'.
download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True
:attr:`data_file` is not set. Default True.
Returns:
Dataset: instance of UCI housing dataset.
Expand Down Expand Up @@ -93,7 +102,16 @@ class UCIHousing(Dataset):
"""

def __init__(self, data_file=None, mode='train', download=True):
mode: _UciHousingDataSetMode
data_file: str | None
dtype: _DTypeLiteral

def __init__(
self,
data_file: str | None = None,
mode: _UciHousingDataSetMode = 'train',
download: bool = True,
) -> None:
assert mode.lower() in [
'train',
'test',
Expand All @@ -114,7 +132,7 @@ def __init__(self, data_file=None, mode='train', download=True):

self.dtype = paddle.get_default_dtype()

def _load_data(self, feature_num=14, ratio=0.8):
def _load_data(self, feature_num: int = 14, ratio: float = 0.8) -> None:
data = np.fromfile(self.data_file, sep=' ')
data = data.reshape(data.shape[0] // feature_num, feature_num)
maximums, minimums, avgs = (
Expand All @@ -130,11 +148,13 @@ def _load_data(self, feature_num=14, ratio=0.8):
elif self.mode == 'test':
self.data = data[offset:]

def __getitem__(self, idx):
def __getitem__(
self, idx: int
) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
data = self.data[idx]
return np.array(data[:-1]).astype(self.dtype), np.array(
data[-1:]
).astype(self.dtype)

def __len__(self):
def __len__(self) -> int:
return len(self.data)

0 comments on commit ba2378e

Please sign in to comment.