Skip to content

Commit

Permalink
feat: add img-folder argument to training scripts
Browse files Browse the repository at this point in the history
Adding argument --img-folder to fastai/train.py and torch/train.py
  • Loading branch information
Bruno Lenzi authored and Bruno Lenzi committed Oct 23, 2019
1 parent bce369a commit e103f02
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
8 changes: 6 additions & 2 deletions references/classification/fastai/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ def main(args):

# Aggregate path and labels into list for fastai ImageDataBunch
fnames, labels, is_valid = [], [], []
for sample in OpenFire(root=args.data_path, train=True, download=True).data:
for sample in OpenFire(root=args.data_path, train=True, download=True,
img_folder=args.img_folder).data:
fnames.append(sample['path'])
labels.append(sample['target'])
is_valid.append(False)
for sample in OpenFire(root=args.data_path, train=False, download=True).data:
for sample in OpenFire(root=args.data_path, train=False, download=True,
img_folder=args.img_folder).data:
fnames.append(sample['path'])
labels.append(sample['target'])
is_valid.append(True)
Expand Down Expand Up @@ -61,6 +63,8 @@ def main(args):
import argparse
parser = argparse.ArgumentParser(description='PyroNear Classification Training with Fastai')
parser.add_argument('--data-path', default='./data', help='dataset')
parser.add_argument('--img-folder', default=None,
help='Folder containing images. Default: <data_path>/OpenFire/images')
parser.add_argument('--model', default='resnet18', type=str, help='model')
parser.add_argument('--device', default='cuda', help='device')
parser.add_argument('-b', '--batch-size', default=32, type=int)
Expand Down
6 changes: 4 additions & 2 deletions references/classification/torch/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ def main(args):

# Train & test sets
train_set = OpenFire(root=args.data_path, train=True, download=True,
transform=train_transforms)
transform=train_transforms, img_folder=args.img_folder)
val_set = OpenFire(root=args.data_path, train=False, download=True,
transform=test_transforms)
transform=test_transforms, img_folder=args.img_folder)
num_classes = len(train_set.classes)
# Samplers
train_sampler = torch.utils.data.RandomSampler(train_set)
Expand Down Expand Up @@ -230,6 +230,8 @@ def main(args):
import argparse
parser = argparse.ArgumentParser(description='PyroNear Classification Training')
parser.add_argument('--data-path', default='./data', help='dataset')
parser.add_argument('--img-folder', default=None,
help='Folder containing images. Default: <data_path>/OpenFire/images')
parser.add_argument('--model', default='resnet18', help='model')
parser.add_argument('--device', default=None, help='device')
parser.add_argument('-b', '--batch-size', default=32, type=int)
Expand Down

0 comments on commit e103f02

Please sign in to comment.