Skip to content

Add ImageFolderWithoutTargets #47

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from .lsun import LSUN, LSUNClass
from .folder import ImageFolder
from .folder import ImageClassFolder, ImageFolder
from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100
from .mnist import MNIST

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100',
'MNIST')
__all__ = ('LSUN', 'LSUNClass', 'ImageFolder', 'ImageClassFolder',
'CocoCaptions', 'CocoDetection', 'CIFAR10', 'CIFAR100', 'MNIST')
53 changes: 46 additions & 7 deletions torchvision/datasets/folder.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,34 @@
import torch.utils.data as data

from PIL import Image
import os
import os.path

import torch.utils.data as data
from PIL import Image

IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
'.jpg',
'.JPG',
'.jpeg',
'.JPEG',
'.png',
'.PNG',
'.ppm',
'.PPM',
'.bmp',
'.BMP',
]


def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def find_classes(dir):
classes = os.listdir(dir)
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx


def make_dataset(dir, class_to_idx):
images = []
for target in os.listdir(dir):
Expand All @@ -38,8 +49,11 @@ def default_loader(path):
return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, target_transform=None,
class ImageClassFolder(data.Dataset):
def __init__(self,
root,
transform=None,
target_transform=None,
loader=default_loader):
classes, class_to_idx = find_classes(root)
imgs = make_dataset(root, class_to_idx)
Expand All @@ -64,3 +78,28 @@ def __getitem__(self, index):

def __len__(self):
return len(self.imgs)


class ImageFolder(data.Dataset):
""" ImageFolder can be used to load images where there are no labels."""

def __init__(self, root, transform=None, loader=default_loader):
images = []
for filename in os.listdir(root):
if is_image_file(filename):
images.append('{}'.format(filename))

self.root = root
self.imgs = images
self.transform = transform
self.loader = loader

def __getitem__(self, index):
filename = self.imgs[index]
img = self.loader(os.path.join(self.root, filename))
if self.transform is not None:
img = self.transform(img)
return img, filename

def __len__(self):
return len(self.imgs)