Skip to content

Commit

Permalink
Merge pull request #191 from gheinrich/dev/issue_161
Browse files Browse the repository at this point in the history
Add option to select min/max samples per class
  • Loading branch information
lukeyeager committed Aug 6, 2015
2 parents bd13014 + fa39922 commit cb7ff66
Show file tree
Hide file tree
Showing 10 changed files with 341 additions and 24 deletions.
51 changes: 49 additions & 2 deletions digits/dataset/images/classification/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from ..forms import ImageDatasetForm
from digits import utils
from digits.utils.forms import validate_required_iff
from digits.utils.forms import validate_required_iff, validate_greater_than

class ImageClassificationDatasetForm(ImageDatasetForm):
"""
Expand Down Expand Up @@ -72,6 +72,22 @@ def validate_folder_path(form, field):
]
)

# Minimum number of training samples required per class.
# Optional: when left blank, the Optional() validator stops the chain
# and no minimum is enforced by the form.
folder_train_min_per_class = wtforms.IntegerField(u'Minimum samples per class',
default=2,
validators=[
validators.Optional(),
validators.NumberRange(min=1),
]
)

# Maximum number of training samples kept per class.
# Optional; when set, validate_greater_than ties it to the matching
# min field (presumably requires max > min -- confirm in digits.utils.forms).
folder_train_max_per_class = wtforms.IntegerField(u'Maximum samples per class',
validators=[
validators.Optional(),
validators.NumberRange(min=1),
validate_greater_than('folder_train_min_per_class'),
]
)

has_val_folder = wtforms.BooleanField('Separate validation images folder',
default = False,
validators=[
Expand All @@ -84,7 +100,22 @@ def validate_folder_path(form, field):
validate_required_iff(
method='folder',
has_val_folder=True),
validate_folder_path,
]
)

# Minimum number of validation samples required per class.
# Optional: blank means no minimum is enforced by the form.
folder_val_min_per_class = wtforms.IntegerField(u'Minimum samples per class',
default=2,
validators=[
validators.Optional(),
validators.NumberRange(min=1),
]
)

# Maximum number of validation samples kept per class.
# Optional; validate_greater_than ties it to folder_val_min_per_class
# (presumably requires max > min -- confirm in digits.utils.forms).
folder_val_max_per_class = wtforms.IntegerField(u'Maximum samples per class',
validators=[
validators.Optional(),
validators.NumberRange(min=1),
validate_greater_than('folder_val_min_per_class'),
]
)

Expand All @@ -104,6 +135,22 @@ def validate_folder_path(form, field):
]
)

# Minimum number of test samples required per class.
# Optional: blank means no minimum is enforced by the form.
folder_test_min_per_class = wtforms.IntegerField(u'Minimum samples per class',
default=2,
validators=[
validators.Optional(),
validators.NumberRange(min=1)
]
)

# Maximum number of test samples kept per class.
# Optional; validate_greater_than ties it to folder_test_min_per_class
# (presumably requires max > min -- confirm in digits.utils.forms).
folder_test_max_per_class = wtforms.IntegerField(u'Maximum samples per class',
validators=[
validators.Optional(),
validators.NumberRange(min=1),
validate_greater_than('folder_test_min_per_class'),
]
)

### Method - textfile

textfile_use_local_files = wtforms.BooleanField(u'Use local files',
Expand Down
20 changes: 14 additions & 6 deletions digits/dataset/images/classification/test_imageset_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
IMAGE_COUNT = 10 # per category


def create_classification_imageset(folder, image_size=None, image_count=None):
def create_classification_imageset(folder, image_size=None, image_count=None, add_unbalanced_category=False):
"""
Creates a folder of folders of images for classification
If requested to add an unbalanced category then a category is added with
half the number of samples of other categories
"""
if image_size is None:
image_size = IMAGE_SIZE
Expand All @@ -29,11 +32,16 @@ def create_classification_imageset(folder, image_size=None, image_count=None):
# Stores the relative path of each image of the dataset
paths = defaultdict(list)

for class_name, pixel_index, rotation in [
('red-to-right', 0, 0),
('green-to-top', 1, 90),
('blue-to-left', 2, 180),
]:
config = [
('red-to-right', 0, 0, image_count),
('green-to-top', 1, 90, image_count),
('blue-to-left', 2, 180, image_count),
]

if add_unbalanced_category:
config.append( ('blue-to-bottom', 2, 270, image_count/2) )

for class_name, pixel_index, rotation, image_count in config:
os.makedirs(os.path.join(folder, class_name))

colors = np.linspace(200, 255, image_count)
Expand Down
132 changes: 131 additions & 1 deletion digits/dataset/images/classification/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def dataset_exists(cls, job_id):
def dataset_status(cls, job_id):
return cls.job_status(job_id, 'datasets')

@classmethod
def dataset_info(cls, job_id):
    """Return the JSON info dict for the dataset job *job_id*."""
    job_type = 'datasets'
    return cls.job_info(job_id, job_type)

@classmethod
def abort_dataset(cls, job_id):
return cls.abort_job(job_id, job_type='datasets')
Expand All @@ -62,12 +66,15 @@ class BaseViewsTestWithImageset(BaseViewsTest):
IMAGE_WIDTH = 10
IMAGE_CHANNELS = 3

UNBALANCED_CATEGORY = False

@classmethod
def setUpClass(cls):
    """Create a temporary classification image set shared by the tests."""
    super(BaseViewsTestWithImageset, cls).setUpClass()
    cls.imageset_folder = tempfile.mkdtemp()
    # generate the on-disk image set (optionally with an unbalanced class)
    cls.imageset_paths = create_classification_imageset(
        cls.imageset_folder,
        add_unbalanced_category=cls.UNBALANCED_CATEGORY,
    )
    # job ids of datasets created during the tests, for cleanup
    cls.created_datasets = []

@classmethod
Expand Down Expand Up @@ -128,6 +135,10 @@ def create_dataset(cls, **kwargs):
cls.created_datasets.append(job_id)
return job_id

@classmethod
def categoryCount(cls):
    """Return the number of categories in the generated image set."""
    # len() of the dict directly; materializing .keys() is redundant
    return len(cls.imageset_paths)

class BaseViewsTestWithDataset(BaseViewsTestWithImageset):
"""
Provides a dataset and some functions
Expand Down Expand Up @@ -249,6 +260,125 @@ def check_textfiles(self, absolute_path=True, local_path=True):
job_id = self.create_dataset(**data)
assert self.dataset_wait_completion(job_id) == 'Done', 'create failed'

class TestImageCount(BaseViewsTestWithImageset):
    """
    Verifies that the number of images parsed into each split
    (train/val/test) matches the requested percentages / folders.
    """

    def test_image_count(self):
        # nose-style test generator: one check per split
        for split in ['train', 'val', 'test']:
            yield self.check_image_count, split

    def check_image_count(self, split):
        """
        Create a dataset and check the image counts for *split*.

        split -- one of 'train', 'val', 'test' (renamed from `type`,
        which shadowed the builtin)
        """
        data = {'folder_pct_val': 20,
                'folder_pct_test': 10}
        if split == 'val':
            data['has_val_folder'] = 'True'
            data['folder_val'] = self.imageset_folder
        elif split == 'test':
            data['has_test_folder'] = 'True'
            data['folder_test'] = self.imageset_folder

        job_id = self.create_dataset(**data)
        assert self.dataset_wait_completion(job_id) == 'Done', 'create failed'
        info = self.dataset_info(job_id)

        if split == 'train':
            assert len(info['ParseFolderTasks']) == 1, 'expected exactly one ParseFolderTask'
            parse_info = info['ParseFolderTasks'][0]
            image_count = (parse_info['train_count'] +
                           parse_info['val_count'] +
                           parse_info['test_count'])
            # the requested split percentages must be honored
            assert parse_info['val_count'] == 0.2 * image_count
            assert parse_info['test_count'] == 0.1 * image_count
        else:
            # a separate val/test folder adds a second parse task
            # (message fixed: it previously said "one" while asserting == 2)
            assert len(info['ParseFolderTasks']) == 2, 'expected exactly two ParseFolderTasks'
            parse_info = info['ParseFolderTasks'][1]
            if split == 'val':
                assert parse_info['train_count'] == 0
                assert parse_info['test_count'] == 0
                image_count = parse_info['val_count']
            else:
                assert parse_info['train_count'] == 0
                assert parse_info['val_count'] == 0
                image_count = parse_info['test_count']
        assert self.categoryCount() == parse_info['label_count']
        assert image_count == DUMMY_IMAGE_COUNT * parse_info['label_count'], 'image count mismatch'
        assert self.delete_dataset(job_id) == 200, 'delete failed'
        assert not self.dataset_exists(job_id), 'dataset exists after delete'

class TestMaxPerClass(BaseViewsTestWithImageset):
    """
    Verifies that the 'maximum samples per class' option caps the number
    of images taken from each category.
    """

    def test_max_per_class(self):
        # nose-style test generator: one check per split
        for split in ['train', 'val', 'test']:
            yield self.check_max_per_class, split

    def check_max_per_class(self, split):
        """
        Create a dataset asking for at most DUMMY_IMAGE_COUNT/2 images
        per class in *split*, and check the resulting counts.

        split -- one of 'train', 'val', 'test' (renamed from `type`,
        which shadowed the builtin)
        """
        assert DUMMY_IMAGE_COUNT % 2 == 0
        # floor division: keeps the count an int under Python 3 as well
        max_per_class = DUMMY_IMAGE_COUNT // 2
        data = {'folder_pct_val': 0}
        if split == 'train':
            data['folder_train_max_per_class'] = max_per_class
        if split == 'val':
            data['has_val_folder'] = 'True'
            data['folder_val'] = self.imageset_folder
            data['folder_val_max_per_class'] = max_per_class
        elif split == 'test':
            data['has_test_folder'] = 'True'
            data['folder_test'] = self.imageset_folder
            data['folder_test_max_per_class'] = max_per_class

        job_id = self.create_dataset(**data)
        assert self.dataset_wait_completion(job_id) == 'Done', 'create failed'
        info = self.dataset_info(job_id)

        if split == 'train':
            assert len(info['ParseFolderTasks']) == 1, 'expected exactly one ParseFolderTask'
            parse_info = info['ParseFolderTasks'][0]
        else:
            # a separate val/test folder adds a second parse task
            # (message fixed: it previously said "one" while asserting == 2)
            assert len(info['ParseFolderTasks']) == 2, 'expected exactly two ParseFolderTasks'
            parse_info = info['ParseFolderTasks'][1]

        image_count = (parse_info['train_count'] +
                       parse_info['val_count'] +
                       parse_info['test_count'])
        assert image_count == max_per_class * parse_info['label_count'], 'image count mismatch'
        assert self.delete_dataset(job_id) == 200, 'delete failed'
        assert not self.dataset_exists(job_id), 'dataset exists after delete'

class TestMinPerClass(BaseViewsTestWithImageset):
    """
    Verifies that the 'minimum samples per class' option drops categories
    that do not have enough images (the unbalanced category has half the
    images of the others).
    """

    UNBALANCED_CATEGORY = True

    def test_min_per_class(self):
        # nose-style test generator: one check per split
        for split in ['train', 'val', 'test']:
            yield self.check_min_per_class, split

    def check_min_per_class(self, split):
        """
        Create a dataset requiring one more image per class than the
        unbalanced category has, then check the category was dropped.

        split -- one of 'train', 'val', 'test' (renamed from `type`,
        which shadowed the builtin)
        """
        # floor division: keeps the count an int under Python 3 as well
        min_per_class = DUMMY_IMAGE_COUNT // 2 + 1
        data = {'folder_pct_val': 0}
        if split == 'train':
            data['folder_train_min_per_class'] = min_per_class
        if split == 'val':
            data['has_val_folder'] = 'True'
            data['folder_val'] = self.imageset_folder
            data['folder_val_min_per_class'] = min_per_class
        elif split == 'test':
            data['has_test_folder'] = 'True'
            data['folder_test'] = self.imageset_folder
            data['folder_test_min_per_class'] = min_per_class

        job_id = self.create_dataset(**data)
        assert self.dataset_wait_completion(job_id) == 'Done', 'create failed'
        info = self.dataset_info(job_id)

        if split == 'train':
            assert len(info['ParseFolderTasks']) == 1, 'expected exactly one ParseFolderTask'
            parse_info = info['ParseFolderTasks'][0]
        else:
            # a separate val/test folder adds a second parse task
            assert len(info['ParseFolderTasks']) == 2, 'expected exactly two ParseFolderTasks'
            parse_info = info['ParseFolderTasks'][1]

        # the unbalanced category must have been excluded
        assert self.categoryCount() == parse_info['label_count'] + 1
        assert self.delete_dataset(job_id) == 200, 'delete failed'
        assert not self.dataset_exists(job_id), 'dataset exists after delete'


class TestCreated(BaseViewsTestWithDataset):
"""
Tests on a dataset that has already been created
Expand Down
23 changes: 19 additions & 4 deletions digits/dataset/images/classification/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,16 @@ def from_folders(job, form):
if form.has_test_folder.data:
percent_test = 0

min_per_class = form.folder_train_min_per_class.data
max_per_class = form.folder_train_max_per_class.data

parse_train_task = tasks.ParseFolderTask(
job_dir = job.dir(),
folder = form.folder_train.data,
percent_val = percent_val,
percent_test = percent_test,
job_dir = job.dir(),
folder = form.folder_train.data,
percent_val = percent_val,
percent_test = percent_test,
min_per_category = min_per_class if min_per_class>0 else 1,
max_per_category = max_per_class if max_per_class>0 else None
)
job.tasks.append(parse_train_task)

Expand All @@ -46,23 +51,33 @@ def from_folders(job, form):
test_parents = [parse_train_task]

if form.has_val_folder.data:
min_per_class = form.folder_val_min_per_class.data
max_per_class = form.folder_val_max_per_class.data

parse_val_task = tasks.ParseFolderTask(
job_dir = job.dir(),
parents = parse_train_task,
folder = form.folder_val.data,
percent_val = 100,
percent_test = 0,
min_per_category = min_per_class if min_per_class>0 else 1,
max_per_category = max_per_class if max_per_class>0 else None
)
job.tasks.append(parse_val_task)
val_parents = [parse_val_task]

if form.has_test_folder.data:
min_per_class = form.folder_test_min_per_class.data
max_per_class = form.folder_test_max_per_class.data

parse_test_task = tasks.ParseFolderTask(
job_dir = job.dir(),
parents = parse_train_task,
folder = form.folder_test.data,
percent_val = 0,
percent_test = 100,
min_per_category = min_per_class if min_per_class>0 else 1,
max_per_category = max_per_class if max_per_class>0 else None
)
job.tasks.append(parse_test_task)
test_parents = [parse_test_task]
Expand Down
17 changes: 17 additions & 0 deletions digits/dataset/job.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) 2014-2015, NVIDIA CORPORATION. All rights reserved.

from digits.job import Job
from digits.utils import subclass, override
from . import tasks

# NOTE: Increment this everytime the pickled object changes
PICKLE_VERSION = 1

@subclass
class DatasetJob(Job):
"""
A Job that creates a dataset
Expand All @@ -17,6 +19,21 @@ def __init__(self, **kwargs):
super(DatasetJob, self).__init__(**kwargs)
self.pickver_job_dataset = PICKLE_VERSION

@override
def json_dict(self, verbose=False):
    """
    Return a JSON-serializable dict describing this job.

    When verbose, also include per-task image/label counts for every
    ParseFolderTask of the job.
    """
    d = super(DatasetJob, self).json_dict(verbose)
    if not verbose:
        return d
    task_summaries = []
    for task in self.parse_folder_tasks():
        task_summaries.append({
            "name": task.name(),
            "label_count": task.label_count,
            "train_count": task.train_count,
            "val_count": task.val_count,
            "test_count": task.test_count,
        })
    d['ParseFolderTasks'] = task_summaries
    return d

def parse_folder_tasks(self):
"""
Return all ParseFolderTasks for this job
Expand Down
Loading

0 comments on commit cb7ff66

Please sign in to comment.