forked from open-mmlab/mmdetection
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refine auto Dataset, Nets, Losses, Metrics, Optimizers and Pipeline (o…
…pen-mmlab#38) * add comments * fix * refine dataset
- Loading branch information
Showing
24 changed files
with
343 additions
and
413 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,53 @@ | ||
from abc import abstractmethod | ||
|
||
from ..core import * | ||
|
||
__all__ = ['Dataset'] | ||
|
||
|
||
class Dataset(object): | ||
def __init__(self, name, train_path=None, val_path=None, batch_size=None, num_workers=None): | ||
class Dataset(BaseAutoObject): | ||
def __init__(self, name, train_path=None, val_path=None, batch_size=None, num_workers=None, | ||
transform_train_fn=None, transform_val_fn=None, | ||
transform_train_list=None, transform_val_list=None, **kwargs): | ||
# TODO (cgraywang): add search space, handle batch_size, num_workers | ||
super(Dataset, self).__init__() | ||
self.name = name | ||
self.train_path = train_path | ||
self.val_path = val_path | ||
self.batch_size = batch_size | ||
self.num_workers = num_workers | ||
self.search_space = None | ||
self.train = None | ||
self.val = None | ||
self.train_data = None | ||
self.val_data = None | ||
self.transform_train_fn = transform_train_fn | ||
self.transform_val_fn = transform_val_fn | ||
self.transform_train_list = transform_train_list | ||
self.transform_val_list = transform_val_list | ||
self._train = None | ||
self._val = None | ||
self._num_classes = None | ||
|
||
def _read_dataset(self): | ||
pass | ||
@property | ||
def train(self): | ||
return self._train | ||
|
||
def _set_search_space(self, cs): | ||
self.search_space = cs | ||
@train.setter | ||
def train(self, value): | ||
self._train = value | ||
|
||
def add_search_space(self): | ||
pass | ||
@property | ||
def val(self): | ||
return self._val | ||
|
||
def get_search_space(self): | ||
return self.search_space | ||
@val.setter | ||
def val(self, value): | ||
self._val = value | ||
|
||
@property | ||
def num_classes(self): | ||
return self._num_classes | ||
|
||
@num_classes.setter | ||
def num_classes(self, value): | ||
self._num_classes = value | ||
|
||
@abstractmethod | ||
def _read_dataset(self): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,48 +1,43 @@ | ||
import ConfigSpace as CS | ||
|
||
import autogluon as ag | ||
from ..core import * | ||
from ..space import * | ||
from .utils import Loss | ||
|
||
__all__ = ['Losses'] | ||
|
||
|
||
class Losses(object): | ||
class Losses(BaseAutoObject): | ||
def __init__(self, loss_list): | ||
# TODO(cgraywang): add model instance, for now, use a list of model names | ||
# TODO(cgraywang): add instance, for now, use a list | ||
assert isinstance(loss_list, list), type(loss_list) | ||
super(Losses, self).__init__() | ||
self.loss_list = loss_list | ||
self.search_space = None | ||
self.add_search_space() | ||
self._add_search_space() | ||
|
||
def _set_search_space(self, cs): | ||
self.search_space = cs | ||
|
||
def add_search_space(self): | ||
def _add_search_space(self): | ||
cs = CS.ConfigurationSpace() | ||
# TODO (cgraywang): add more hparams for loss, e.g., weight | ||
loss_list_hyper_param = List('loss', | ||
choices=self.get_loss_strs()) \ | ||
.get_hyper_param() | ||
choices=self._get_search_space_strs()).get_hyper_param() | ||
cs.add_hyperparameter(loss_list_hyper_param) | ||
# TODO (cgraywang): do not add hyper-params for loss | ||
self._set_search_space(cs) | ||
|
||
def get_search_space(self): | ||
return self.search_space | ||
self.search_space = cs | ||
|
||
def get_loss_strs(self): | ||
def _get_search_space_strs(self): | ||
loss_strs = [] | ||
for loss in self.loss_list: | ||
if isinstance(loss, Loss): | ||
loss_strs.append(loss.name) | ||
elif isinstance(loss, str): | ||
loss_strs.append(loss) | ||
else: | ||
pass | ||
raise NotImplementedError | ||
return loss_strs | ||
|
||
def __repr__(self): | ||
return "AutoGluon Losses %s with %s" % (str(self.get_loss_strs()), str(self.search_space)) | ||
return "AutoGluon Losses %s with %s" % ( | ||
str(self._get_search_space_strs()), str(self.search_space)) | ||
|
||
def __str__(self): | ||
return "AutoGluon Losses %s with %s" % (str(self.get_loss_strs()), str(self.search_space)) | ||
return "AutoGluon Losses %s with %s" % ( | ||
str(self._get_search_space_strs()), str(self.search_space)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,50 +1,44 @@ | ||
import ConfigSpace as CS | ||
|
||
import autogluon as ag | ||
from ..core import * | ||
from ..space import * | ||
from .utils import Metric | ||
|
||
__all__ = ['Metrics'] | ||
|
||
|
||
class Metrics(object): | ||
class Metrics(BaseAutoObject): | ||
def __init__(self, metric_list): | ||
# TODO(cgraywang): add model instance, for now, use a list of model names | ||
# TODO(cgraywang): add instance, for now, use a list | ||
# TODO(cgraywang): use all option | ||
assert isinstance(metric_list, list), type(metric_list) | ||
super(Metrics, self).__init__() | ||
self.metric_list = metric_list | ||
self.search_space = None | ||
self.add_search_space() | ||
self._add_search_space() | ||
|
||
def _set_search_space(self, cs): | ||
self.search_space = cs | ||
|
||
def add_search_space(self): | ||
def _add_search_space(self): | ||
cs = CS.ConfigurationSpace() | ||
# TODO (cgraywang): add more hparams for metric, e.g., weight | ||
metric_list_hyper_param = List('metric', | ||
choices=self.get_metric_strs()) \ | ||
.get_hyper_param() | ||
choices=self._get_search_space_strs()).get_hyper_param() | ||
cs.add_hyperparameter(metric_list_hyper_param) | ||
# TODO (cgraywang): do not add hyper-params for metric | ||
self._set_search_space(cs) | ||
|
||
def get_search_space(self): | ||
return self.search_space | ||
self.search_space = cs | ||
|
||
def get_metric_strs(self): | ||
def _get_search_space_strs(self): | ||
metric_strs = [] | ||
for metric in self.metric_list: | ||
if isinstance(metric, Metric): | ||
metric_strs.append(metric.name) | ||
elif isinstance(metric, str): | ||
metric_strs.append(metric) | ||
else: | ||
pass | ||
raise NotImplementedError | ||
return metric_strs | ||
|
||
def __repr__(self): | ||
return "AutoGluon Metrics %s with %s" % ( | ||
str(self.get_metric_strs()), str(self.search_space)) | ||
str(self._get_search_space_strs()), str(self.search_space)) | ||
|
||
def __str__(self): | ||
return "AutoGluon Metrics %s with %s" % ( | ||
str(self.get_metric_strs()), str(self.search_space)) | ||
str(self._get_search_space_strs()), str(self.search_space)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.