Skip to content

Commit

Permalink
[Typing][A-88] Add type annotations for `python/paddle/vision/dataset…
Browse files Browse the repository at this point in the history
…s/folder.py` (#65532)
  • Loading branch information
enkilee authored Jul 6, 2024
1 parent f3d8738 commit 9d25d4a
Showing 1 changed file with 64 additions and 26 deletions.
90 changes: 64 additions & 26 deletions python/paddle/vision/datasets/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence

from typing_extensions import TypeAlias

if TYPE_CHECKING:
from paddle._typing import _DTypeLiteral
from paddle.vision.transforms.transforms import _Transform

from ..image import _ImageDataType

_AllowedExtensions: TypeAlias = Literal[
'.jpg',
'.jpeg',
'.png',
'.ppm',
'.bmp',
'.pgm',
'.tif',
'.tiff',
'.webp',
]

import os

from PIL import Image
Expand All @@ -23,7 +47,7 @@
__all__ = []


def has_valid_extension(filename, extensions):
def has_valid_extension(filename: str, extensions: Sequence[str]) -> bool:
"""Checks if a file is a valid extension.
Args:
Expand Down Expand Up @@ -78,14 +102,14 @@ class DatasetFolder(Dataset):
Args:
root (str): Root directory path.
loader (Callable, optional): A function to load a sample given its path. Default: None.
extensions (list[str]|tuple[str], optional): A list of allowed extensions.
loader (Callable|None, optional): A function to load a sample given its path. Default: None.
extensions (list[str]|tuple[str]|None, optional): A list of allowed extensions.
Both :attr:`extensions` and :attr:`is_valid_file` should not be passed.
If this value is not set, the default is to use ('.jpg', '.jpeg', '.png',
'.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'). Default: None.
transform (Callable, optional): A function/transform that takes in
transform (Callable|None, optional): A function/transform that takes in
a sample and returns a transformed version. Default: None.
is_valid_file (Callable, optional): A function that takes path of a file
is_valid_file (Callable|None, optional): A function that takes path of a file
and check if the file is a valid file. Both :attr:`extensions` and
:attr:`is_valid_file` should not be passed. Default: None.
Expand Down Expand Up @@ -198,20 +222,29 @@ class DatasetFolder(Dataset):
>>> for img, label in iter(data_folder_2):
... # do something with img and label
... print(type(img), img.shape, label)
... print(type(img), img.shape, label) # type: ignore
... # <class 'paddle.Tensor'> [3, 64, 64] 0
>>> shutil.rmtree(fake_data_dir)
"""

loader: Callable[..., _ImageDataType] | None
extensions: Sequence[_AllowedExtensions] | None
transform: _Transform[Any, Any] | None
classes: list[str]
class_to_idx: dict[str, int]
samples: list[tuple[str, str]]
targets: list[str]
dtype: _DTypeLiteral

def __init__(
self,
root,
loader=None,
extensions=None,
transform=None,
is_valid_file=None,
):
root: str,
loader: Callable[..., _ImageDataType] | None = None,
extensions: Sequence[_AllowedExtensions] | None = None,
transform: _Transform[Any, Any] | None = None,
is_valid_file: _ImageDataType | None = None,
) -> None:
self.root = root
self.transform = transform
if extensions is None:
Expand All @@ -238,7 +271,7 @@ def __init__(

self.dtype = paddle.get_default_dtype()

def _find_classes(self, dir):
def _find_classes(self, dir: str) -> tuple[list[str], dict[str, int]]:
"""
Finds the class folders in a dataset.
Expand All @@ -255,7 +288,7 @@ def _find_classes(self, dir):
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx

def __getitem__(self, index):
def __getitem__(self, index: int) -> tuple[_ImageDataType, int]:
"""
Args:
index (int): Index
Expand Down Expand Up @@ -318,14 +351,14 @@ class ImageFolder(Dataset):
Args:
root (str): Root directory path.
loader (Callable, optional): A function to load a sample given its path. Default: None.
extensions (list[str]|tuple[str], optional): A list of allowed extensions.
loader (Callable|None, optional): A function to load a sample given its path. Default: None.
extensions (list[str]|tuple[str]|None, optional): A list of allowed extensions.
Both :attr:`extensions` and :attr:`is_valid_file` should not be passed.
If this value is not set, the default is to use ('.jpg', '.jpeg', '.png',
'.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'). Default: None.
transform (Callable, optional): A function/transform that takes in
transform (Callable|None, optional): A function/transform that takes in
a sample and returns a transformed version. Default: None.
is_valid_file (Callable, optional): A function that takes path of a file
is_valid_file (Callable|None, optional): A function that takes path of a file
and check if the file is a valid file. Both :attr:`extensions` and
:attr:`is_valid_file` should not be passed. Default: None.
Expand Down Expand Up @@ -425,20 +458,25 @@ class ImageFolder(Dataset):
>>> for (img,) in iter(image_folder_2):
... # do something with img
... print(type(img), img.shape)
... print(type(img), img.shape) # type: ignore
... # <class 'paddle.Tensor'> [3, 64, 64]
>>> shutil.rmtree(fake_data_dir)
"""

loader: Callable[..., _ImageDataType] | None
extensions: Sequence[_AllowedExtensions] | None
samples: list[str]
transform: _Transform[Any, Any] | None

def __init__(
self,
root,
loader=None,
extensions=None,
transform=None,
is_valid_file=None,
):
root: str,
loader: Callable[..., _ImageDataType] | None = None,
extensions: Sequence[_AllowedExtensions] | None = None,
transform: _Transform[Any, Any] | None = None,
is_valid_file: _ImageDataType | None = None,
) -> None:
self.root = root
if extensions is None:
extensions = IMG_EXTENSIONS
Expand Down Expand Up @@ -470,7 +508,7 @@ def is_valid_file(x):
self.samples = samples
self.transform = transform

def __getitem__(self, index):
def __getitem__(self, index: int) -> list[_ImageDataType]:
"""
Args:
index (int): Index
Expand Down

0 comments on commit 9d25d4a

Please sign in to comment.