From de680d201f6ff14b44e018b174077a4ed91bb228 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Wed, 30 Dec 2020 16:44:05 -0500 Subject: [PATCH 1/2] 'make_dataset' as staticmethod of 'DatasetFolder' --- torchvision/datasets/folder.py | 98 +++++++++++++++++----------------- 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 565a79c7a19..4ec1dd0a553 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -32,54 +32,6 @@ def is_image_file(filename: str) -> bool: return has_file_allowed_extension(filename, IMG_EXTENSIONS) -def make_dataset( - directory: str, - class_to_idx: Dict[str, int], - extensions: Optional[Tuple[str, ...]] = None, - is_valid_file: Optional[Callable[[str], bool]] = None, -) -> List[Tuple[str, int]]: - """Generates a list of samples of a form (path_to_sample, class). - - Args: - directory (str): root dataset directory - class_to_idx (Dict[str, int]): dictionary mapping class name to class index - extensions (optional): A list of allowed extensions. - Either extensions or is_valid_file should be passed. Defaults to None. - is_valid_file (optional): A function that takes path of a file - and checks if the file is a valid file - (used to check of corrupt files) both extensions and - is_valid_file should not be passed. Defaults to None. - - Raises: - ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None. - - Returns: - List[Tuple[str, int]]: samples of a form (path_to_sample, class) - """ - instances = [] - directory = os.path.expanduser(directory) - both_none = extensions is None and is_valid_file is None - both_something = extensions is not None and is_valid_file is not None - if both_none or both_something: - raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") - if extensions is not None: - def is_valid_file(x: str) -> bool: - return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) - is_valid_file = cast(Callable[[str], bool], is_valid_file) - for target_class in sorted(class_to_idx.keys()): - class_index = class_to_idx[target_class] - target_dir = os.path.join(directory, target_class) - if not os.path.isdir(target_dir): - continue - for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): - for fname in sorted(fnames): - path = os.path.join(root, fname) - if is_valid_file(path): - item = path, class_index - instances.append(item) - return instances - - class DatasetFolder(VisionDataset): """A generic data loader where the samples are arranged in this way: :: @@ -124,7 +76,7 @@ def __init__( super(DatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform) classes, class_to_idx = self._find_classes(self.root) - samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) + samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file) if len(samples) == 0: msg = "Found 0 files in subfolders of: {}\n".format(self.root) if extensions is not None: @@ -139,6 +91,54 @@ def __init__( self.samples = samples self.targets = [s[1] for s in samples] + @staticmethod + def make_dataset( + directory: str, + class_to_idx: Dict[str, int], + extensions: Optional[Tuple[str, ...]] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, + ) -> List[Tuple[str, int]]: + """Generates a list of samples of a form (path_to_sample, class). + + Args: + directory (str): root dataset directory + class_to_idx (Dict[str, int]): dictionary mapping class name to class index + extensions (optional): A list of allowed extensions. + Either extensions or is_valid_file should be passed. Defaults to None. + is_valid_file (optional): A function that takes path of a file + and checks if the file is a valid file + (used to check of corrupt files) both extensions and + is_valid_file should not be passed. Defaults to None. + + Raises: + ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None. + + Returns: + List[Tuple[str, int]]: samples of a form (path_to_sample, class) + """ + instances = [] + directory = os.path.expanduser(directory) + both_none = extensions is None and is_valid_file is None + both_something = extensions is not None and is_valid_file is not None + if both_none or both_something: + raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") + if extensions is not None: + def is_valid_file(x: str) -> bool: + return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) + is_valid_file = cast(Callable[[str], bool], is_valid_file) + for target_class in sorted(class_to_idx.keys()): + class_index = class_to_idx[target_class] + target_dir = os.path.join(directory, target_class) + if not os.path.isdir(target_dir): + continue + for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): + for fname in sorted(fnames): + path = os.path.join(root, fname) + if is_valid_file(path): + item = path, class_index + instances.append(item) + return instances + def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: """ Finds the class folders in a dataset. From 41d18bed04fd33a10bca0ab72a44554e37dd293b Mon Sep 17 00:00:00 2001 From: ain-soph Date: Wed, 30 Dec 2020 16:54:29 -0500 Subject: [PATCH 2/2] a better fix --- torchvision/datasets/folder.py | 89 +++++++++++++++++++--------------- 1 file changed, 49 insertions(+), 40 deletions(-) diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 4ec1dd0a553..ef3ae7af896 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -32,6 +32,54 @@ def is_image_file(filename: str) -> bool: return has_file_allowed_extension(filename, IMG_EXTENSIONS) +def make_dataset( + directory: str, + class_to_idx: Dict[str, int], + extensions: Optional[Tuple[str, ...]] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, +) -> List[Tuple[str, int]]: + """Generates a list of samples of a form (path_to_sample, class). + + Args: + directory (str): root dataset directory + class_to_idx (Dict[str, int]): dictionary mapping class name to class index + extensions (optional): A list of allowed extensions. + Either extensions or is_valid_file should be passed. Defaults to None. + is_valid_file (optional): A function that takes path of a file + and checks if the file is a valid file + (used to check of corrupt files) both extensions and + is_valid_file should not be passed. Defaults to None. + + Raises: + ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None. + + Returns: + List[Tuple[str, int]]: samples of a form (path_to_sample, class) + """ + instances = [] + directory = os.path.expanduser(directory) + both_none = extensions is None and is_valid_file is None + both_something = extensions is not None and is_valid_file is not None + if both_none or both_something: + raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") + if extensions is not None: + def is_valid_file(x: str) -> bool: + return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) + is_valid_file = cast(Callable[[str], bool], is_valid_file) + for target_class in sorted(class_to_idx.keys()): + class_index = class_to_idx[target_class] + target_dir = os.path.join(directory, target_class) + if not os.path.isdir(target_dir): + continue + for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): + for fname in sorted(fnames): + path = os.path.join(root, fname) + if is_valid_file(path): + item = path, class_index + instances.append(item) + return instances + + class DatasetFolder(VisionDataset): """A generic data loader where the samples are arranged in this way: :: @@ -98,46 +146,7 @@ def make_dataset( extensions: Optional[Tuple[str, ...]] = None, is_valid_file: Optional[Callable[[str], bool]] = None, ) -> List[Tuple[str, int]]: - """Generates a list of samples of a form (path_to_sample, class). - - Args: - directory (str): root dataset directory - class_to_idx (Dict[str, int]): dictionary mapping class name to class index - extensions (optional): A list of allowed extensions. - Either extensions or is_valid_file should be passed. Defaults to None. - is_valid_file (optional): A function that takes path of a file - and checks if the file is a valid file - (used to check of corrupt files) both extensions and - is_valid_file should not be passed. Defaults to None. - - Raises: - ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None. - - Returns: - List[Tuple[str, int]]: samples of a form (path_to_sample, class) - """ - instances = [] - directory = os.path.expanduser(directory) - both_none = extensions is None and is_valid_file is None - both_something = extensions is not None and is_valid_file is not None - if both_none or both_something: - raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") - if extensions is not None: - def is_valid_file(x: str) -> bool: - return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) - is_valid_file = cast(Callable[[str], bool], is_valid_file) - for target_class in sorted(class_to_idx.keys()): - class_index = class_to_idx[target_class] - target_dir = os.path.join(directory, target_class) - if not os.path.isdir(target_dir): - continue - for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): - for fname in sorted(fnames): - path = os.path.join(root, fname) - if is_valid_file(path): - item = path, class_index - instances.append(item) - return instances + return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file) def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: """