Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add argument img_folder to OpenFire #34

Merged
merged 6 commits into from
Sep 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions pyronear/datasets/openfire.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,21 @@ class OpenFire(VisionDataset):
threads (int, optional): If download is set to True, use this amount of threads
for downloading the dataset.
num_samples (int, optional): Number of samples to download (all by default)
img_folder (str or Path, optional): Location of image folder. Default: <root>/OpenFire/images
**kwargs: optional arguments of torchvision.datasets.VisionDataset
"""

url = 'https://gist.githubusercontent.com/frgfm/f53b4f53a1b2dc3bb4f18c006a32ec0d/raw/c0351134e333710c6ce0c631af5198e109ed7a92/openfire_binary.json' # noqa: E501
classes = [False, True]

def __init__(self, root, train=True, download=False, threads=None, num_samples=None, **kwargs):
def __init__(self, root, train=True, download=False, threads=None, num_samples=None,
img_folder=None, **kwargs):
super(OpenFire, self).__init__(root, **kwargs)
self.train = train
if img_folder is None:
self.img_folder = Path(self.root, self.__class__.__name__, 'images')
else:
self.img_folder = Path(img_folder)

if download:
self.download(threads, num_samples)
Expand All @@ -50,7 +56,7 @@ def __init__(self, root, train=True, download=False, threads=None, num_samples=N

@property
def _images(self):
return Path(self.root, self.__class__.__name__, 'images')
return self.img_folder

@property
def _annotations(self):
Expand Down
5 changes: 4 additions & 1 deletion references/classification/OpenFire/fastai/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def main(args):

# Aggregate path and labels into list for fastai ImageDataBunch
fnames, labels, is_valid = [], [], []
dataset = OpenFire(root=args.data_path, train=True, download=True)
dataset = OpenFire(root=args.data_path, train=True, download=True,
img_folder=args.img_folder)
for sample in dataset.data:
fnames.append(dataset._images.joinpath(sample['name']).relative_to(dataset.root))
labels.append(sample['target'])
Expand Down Expand Up @@ -128,6 +129,8 @@ def main(args):
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Input / Output
parser.add_argument('--data-path', default='./data', help='dataset root folder')
parser.add_argument('--img-folder', default=None,
help='Folder containing images. Default: <data_path>/OpenFire/images')
parser.add_argument('--checkpoint', default='checkpoint', type=str, help='name of output file')
parser.add_argument('--resume', default=None, help='checkpoint name to resume from')
# Architecture
Expand Down
6 changes: 4 additions & 2 deletions references/classification/OpenFire/torch/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ def main(args):

# Train & test sets
train_set = OpenFire(root=args.data_path, train=True, download=True,
transform=data_transforms)
transform=data_transforms, img_folder=args.img_folder)
val_set = OpenFire(root=args.data_path, train=False, download=True,
transform=data_transforms)
transform=data_transforms, img_folder=args.img_folder)
num_classes = len(train_set.classes)
if args.binary:
if num_classes == 2:
Expand Down Expand Up @@ -264,6 +264,8 @@ def main(args):
# Input / Output
parser.add_argument('--data-path', default='./data', help='dataset root folder')
parser.add_argument('--resume', default=None, help='checkpoint file to resume from')
parser.add_argument('--img-folder', default=None,
help='Folder containing images. Default: <data_path>/OpenFire/images')
parser.add_argument('--output-dir', default=None, help='path for output saving')
parser.add_argument('--checkpoint', default=None, type=str, help='name of output file')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
Expand Down
19 changes: 15 additions & 4 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,21 @@ def test_downloadurls(self):
def test_openfire(self):
num_samples = 200

# Test img_folder argument: wrong type and default (None)
with tempfile.TemporaryDirectory() as root:
self.assertRaises(TypeError, datasets.OpenFire, root, download=True, img_folder=1)
ds = datasets.OpenFire(root=root, download=True, num_samples=num_samples,
img_folder=None)
self.assertIsInstance(ds.img_folder, Path)

with tempfile.TemporaryDirectory() as root, tempfile.TemporaryDirectory() as img_folder:

# Working case
train_set = datasets.OpenFire(root=root, train=True, download=True, num_samples=num_samples)
test_set = datasets.OpenFire(root=root, train=False, download=True, num_samples=num_samples)
# Test img_folder as Path and str
train_set = datasets.OpenFire(root=root, train=True, download=True, num_samples=num_samples,
img_folder=Path(img_folder))
test_set = datasets.OpenFire(root=root, train=False, download=True, num_samples=num_samples,
img_folder=img_folder)
# Check inherited properties
self.assertIsInstance(train_set, VisionDataset)

Expand All @@ -73,8 +83,9 @@ def test_openfire(self):
datasets.utils.download_url(train_set.url, root, filename='extract.json', verbose=False)
with open(Path(root).joinpath('extract.json'), 'rb') as f:
extract = json.load(f)[:num_samples]
# Uncomment when download issues are resolved
# self.assertEqual(len(train_set) + len(test_set), len(extract))
# Tolerate at most 15 failed downloads (delta=15).
# Switch to assertEqual once the download issues are resolved.
self.assertAlmostEqual(len(train_set) + len(test_set), len(extract), delta=15)

# Check integrity of samples
img, target = train_set[0]
Expand Down