Skip to content

Commit

Permalink
Demo patch 1 (open-mmlab#17)
Browse files Browse the repository at this point in the history
* mv dataset inside

* patch
  • Loading branch information
zhanghang1989 authored May 22, 2019
1 parent c1dd7d8 commit 6be0d44
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 10 deletions.
5 changes: 2 additions & 3 deletions autogluon/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@


class Dataset(object):
def __init__(self, train_path=None, val_path=None):
def __init__(self, name, train_path=None, val_path=None):
# TODO (cgraywang): add search space, handle batch_size, num_workers
self.name = name
self.train_path = train_path
self.val_path = val_path
self._read_dataset()
self.search_space = None
self.add_search_space()
self.train_data = None
self.val_data = None

Expand Down
10 changes: 3 additions & 7 deletions autogluon/task/image_classification/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,12 @@


class Dataset(dataset.Dataset):
def __init__(self, train_path=None, val_path=None):
def __init__(self, name=None, train_path=None, val_path=None):
super(Dataset, self).__init__(name, train_path, val_path)
# TODO (cgraywang): add search space, handle batch_size, num_workers

self.train_path = train_path
self.val_path = val_path
self.train_data = None
self.val_data = None
self._num_classes = None
self._read_dataset()
self.search_space = None
self.add_search_space()

@property
Expand All @@ -46,7 +42,7 @@ def _read_dataset(self):
[0.2023, 0.1994, 0.2010])
])

if 'CIFAR10' in self.train_path or 'CIFAR10' in self.val_path:
if self.name.lower() == 'cifar10':
train_dataset = gluon.data.vision.CIFAR10(train=True)
test_dataset = gluon.data.vision.CIFAR10(train=False)
train_data = gluon.data.DataLoader(
Expand Down

0 comments on commit 6be0d44

Please sign in to comment.