diff --git a/autogluon/dataset/dataset.py b/autogluon/dataset/dataset.py index a7c9c5923c6..4d61da72d9a 100644 --- a/autogluon/dataset/dataset.py +++ b/autogluon/dataset/dataset.py @@ -2,13 +2,12 @@ class Dataset(object): - def __init__(self, train_path=None, val_path=None): + def __init__(self, name, train_path=None, val_path=None): # TODO (cgraywang): add search space, handle batch_size, num_workers + self.name = name self.train_path = train_path self.val_path = val_path - self._read_dataset() self.search_space = None - self.add_search_space() self.train_data = None self.val_data = None diff --git a/autogluon/task/image_classification/dataset.py b/autogluon/task/image_classification/dataset.py index f104e48f971..127320645d4 100644 --- a/autogluon/task/image_classification/dataset.py +++ b/autogluon/task/image_classification/dataset.py @@ -11,16 +11,12 @@ class Dataset(dataset.Dataset): - def __init__(self, train_path=None, val_path=None): + def __init__(self, name=None, train_path=None, val_path=None): + super(Dataset, self).__init__(name, train_path, val_path) # TODO (cgraywang): add search space, handle batch_size, num_workers - self.train_path = train_path - self.val_path = val_path - self.train_data = None - self.val_data = None self._num_classes = None self._read_dataset() - self.search_space = None self.add_search_space() @property @@ -46,7 +42,7 @@ def _read_dataset(self): [0.2023, 0.1994, 0.2010]) ]) - if 'CIFAR10' in self.train_path or 'CIFAR10' in self.val_path: + if self.name.lower() == 'cifar10': train_dataset = gluon.data.vision.CIFAR10(train=True) test_dataset = gluon.data.vision.CIFAR10(train=False) train_data = gluon.data.DataLoader(