Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor spaces #54

Merged
merged 4 commits into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/mmtune/mmcls_cifar_100_asynchb_nevergrad_pso.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

space = {
'model': {{_base_.model}},
'model.head.num_classes': dict(type='Constant', value=100),
'model.head.num_classes': 100,
'optimizer': {{_base_.optimizer}},
'data.samples_per_gpu': {{_base_.batch_size}},
}
Expand Down
4 changes: 2 additions & 2 deletions configs/mmtune/mmseg_asynchb_nevergrad_pso.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
'model': {{_base_.model}},
'optimizer': {{_base_.optimizer}},
'data.samples_per_gpu': {{_base_.batch_size}},
'model.decode_head.num_classes': dict(type='Constant', value=21),
'model.auxiliary_head.num_classes': dict(type='Constant', value=21),
'model.decode_head.num_classes': 21,
'model.auxiliary_head.num_classes': 21,
}

task = dict(type='MMSegmentation')
Expand Down
12 changes: 9 additions & 3 deletions mmtune/ray/spaces/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from .base import BaseSpace
from .base import (BaseSpace, Lograndint, Loguniform, Qlograndint, Qloguniform,
Qrandint, Qrandn, Quniform, Randint, Randn, Uniform)
from .builder import SPACES, build_space
from .choice import Choice
from .const import Constant
from .grid_search import GridSearch
from .sample_from import SampleFrom

__all__ = ['BaseSpace', 'SPACES', 'build_space', 'Choice', 'Constant']
__all__ = [
'BaseSpace', 'Uniform', 'Quniform', 'Loguniform', 'Qloguniform', 'Randn',
'Qrandn', 'Randint', 'Qrandint', 'Lograndint', 'Qlograndint', 'SPACES',
'build_space', 'Choice', 'GridSearch', 'SampleFrom'
]
63 changes: 61 additions & 2 deletions mmtune/ray/spaces/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,69 @@
from abc import ABCMeta
from typing import Callable

import ray.tune as tune

from .builder import SPACES


class BaseSpace(metaclass=ABCMeta):
"""Base Space class."""
sample: Callable = None

def __init__(self, **kwargs) -> None:
self.kwargs = kwargs

@property
def space(self) -> callable:
def space(self) -> tune.sample.Domain:
"""Return the space."""
return getattr(self, '_space', None)
return self.sample.__func__(**self.kwargs)


@SPACES.register_module()
class Uniform(BaseSpace):
sample: Callable = tune.uniform


@SPACES.register_module()
class Quniform(BaseSpace):
sample: Callable = tune.quniform


@SPACES.register_module()
class Loguniform(BaseSpace):
sample: Callable = tune.loguniform


@SPACES.register_module()
class Qloguniform(BaseSpace):
sample: Callable = tune.qloguniform


@SPACES.register_module()
class Randn(BaseSpace):
sample: Callable = tune.randn


@SPACES.register_module()
class Qrandn(BaseSpace):
sample: Callable = tune.qrandn


@SPACES.register_module()
class Randint(BaseSpace):
sample: Callable = tune.randint


@SPACES.register_module()
class Qrandint(BaseSpace):
sample: Callable = tune.qrandint


@SPACES.register_module()
class Lograndint(BaseSpace):
sample: Callable = tune.lograndint


@SPACES.register_module()
class Qlograndint(BaseSpace):
sample: Callable = tune.qlograndint
47 changes: 16 additions & 31 deletions mmtune/ray/spaces/builder.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,28 @@
import inspect
from typing import Callable, Dict
from typing import Mapping, Sequence

from mmcv.utils import Registry
from ray.tune import sample

from .base import BaseSpace

SPACES = Registry('spaces')


def _register_space(space: Callable) -> None:
"""Register a space.

Args:
space (Callable): The space to register.
"""

@SPACES.register_module(name=space.__name__.capitalize())
class _ImplicitSpace(BaseSpace):

def __init__(self, *args, **kwargs):
self._space = space(*args, **kwargs)


for space_name in dir(sample):
space = getattr(sample, space_name)
if not inspect.isfunction(space):
continue
_register_space(space)


def build_space(cfgs: Dict) -> Dict:
def build_space(cfg: dict) -> dict:
"""Build a space.

Args:
cfgs (Dict): The configurations of the space.
cfg (dict): The configurations of the space.

Returns:
Dict: The instantiated space.
dict: The instantiated space.
"""

return {key: SPACES.build(cfg).space for key, cfg in cfgs.items()}
cfg = cfg.copy()
for k, v in cfg.items():
if isinstance(v, (int, str, bool, float)):
continue
elif isinstance(v, Sequence):
cfg[k] = [build_space(_) if isinstance(_, dict) else _ for _ in v]
elif isinstance(v, Mapping):
cfg[k] = build_space(v)
typ = cfg[k].get('type', '')
if isinstance(typ, str) and typ in SPACES:
cfg[k] = SPACES.build(cfg[k]).space
return cfg
41 changes: 21 additions & 20 deletions mmtune/ray/spaces/choice.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,34 @@
from typing import Optional, Sequence
from typing import Callable, Optional, Sequence

from ray.tune.sample import choice
import ray.tune as tune

from mmtune.utils import ImmutableContainer
from .base import BaseSpace
from .builder import SPACES


@SPACES.register_module(force=True)
@SPACES.register_module()
class Choice(BaseSpace):
"""Sample a categorical value.

Args:
categories (Sequence): The categories.
alias (Sequence, optional): A alias to be expressed.
Defaults to None.
"""

sample: Callable = tune.choice

def __init__(self,
categories: Sequence,
alias: Optional[Sequence] = None,
use_container: bool = True):
"""Initialize Choice.

Args:
categories (Sequence): The categories.
alias (Optional[Sequence]):
A alias to be expressed. Defaults to None.
use_container (bool):
Whether to use containers. Defaults to True.
"""

alias: Optional[Sequence] = None) -> None:
if alias is not None:
assert len(categories) == len(alias)
categories = [
ImmutableContainer(c, None if alias is None else alias[idx])
if use_container else c for idx, c in enumerate(categories)
]
self._space = choice(categories)
categories = [
ImmutableContainer(*it) for it in zip(categories, alias)
]
self.categories = categories

@property
def space(self) -> tune.sample.Domain:
return self.sample.__func__(self.categories)
27 changes: 0 additions & 27 deletions mmtune/ray/spaces/const.py

This file was deleted.

32 changes: 32 additions & 0 deletions mmtune/ray/spaces/grid_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Callable, Optional, Sequence

import ray.tune as tune

from mmtune.utils import ImmutableContainer
from .base import BaseSpace
from .builder import SPACES


@SPACES.register_module()
class GridSearch(BaseSpace):
"""Grid search over a value.

Args:
values (Sequence): An iterable whose parameters will be gridded.
alias (Sequence, optional): A alias to be expressed.
Defaults to None.
"""

sample: Callable = tune.grid_search

def __init__(self,
values: Sequence,
alias: Optional[Sequence] = None) -> None:
if alias is not None:
assert len(values) == len(alias)
values = [ImmutableContainer(*it) for it in zip(values, alias)]
self.values = values

@property
def space(self) -> dict:
return self.sample.__func__(self.values)
28 changes: 28 additions & 0 deletions mmtune/ray/spaces/sample_from.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Callable, Union

import ray.tune as tune

from .base import BaseSpace
from .builder import SPACES


@SPACES.register_module()
class SampleFrom(BaseSpace):
"""Specify that tune should sample configuration values from this function.

Args:
func (str | Callable): An string or callable function
to draw a sample from.
"""

sample: Callable = tune.sample_from

def __init__(self, func: Union[str, Callable]) -> None:
if isinstance(func, str):
assert func.startswith('lambda')
func = eval(func)
self.func = func

@property
def space(self) -> tune.sample.Domain:
return self.sample.__func__(self.func)
71 changes: 52 additions & 19 deletions tests/test_ray/test_spaces.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import pytest
from ray import tune

from mmtune.ray.spaces import BaseSpace, Choice, Constant, build_space


def test_base_space():
assert hasattr(BaseSpace(), 'space')
from mmtune.ray.spaces import Choice, GridSearch, SampleFrom, build_space
from mmtune.utils import ImmutableContainer


def test_build_space():
Expand All @@ -18,22 +16,57 @@ def test_build_space():

def test_choice():

def objective(config):
assert config.get('test') in [1, 2, 3]
tune.report(result=config['test'])
def is_in(config):
assert config['test'] in [0, 1, 2]

choice = Choice(categories=[0, 1, 2])
tune.run(is_in, config=dict(test=choice.space))

# with alias
def is_immutable(config):
assert isinstance(config['test'], ImmutableContainer)
assert config['test'].data in [True, False]
assert config['test'].alias in ['T', 'F']

with pytest.raises(AssertionError):
choice = Choice(categories=[True, False], alias=['TF'])
choice = Choice(categories=[True, False], alias=['T', 'F'])
tune.run(is_immutable, config=dict(test=choice.space))


def test_grid_search():

def is_in(config):
assert config['test1'] in [0, 1, 2]
assert config['test2'] in [3, 4, 5]

grid1 = GridSearch(values=[0, 1, 2])
grid2 = GridSearch(values=[3, 4, 5])
tune.run(is_in, config=dict(test1=grid1.space, test2=grid2.space))

# with alias
def is_immutable(config):
for test in ['test1', 'test2']:
assert isinstance(config[test], ImmutableContainer)
assert config[test].data in [True, False]
assert config[test].alias in ['T', 'F']

with pytest.raises(AssertionError):
grid1 = GridSearch(values=[True, False], alias=['TF'])
grid1 = GridSearch(values=[True, False], alias=['T', 'F'])
grid2 = GridSearch(values=[False, True], alias=['F', 'T'])
tune.run(is_immutable, config=dict(test1=grid1.space, test2=grid2.space))

tune.run(
objective,
config=dict(
test=Choice(categories=[1, 2, 3], use_container=False).space))

def test_sample_from():

def test_constant():
def is_eq(config):
assert config['test'] == config['base']**2

def objective(config):
assert config.get('test') == -1
tune.report(result=config['test'])
with pytest.raises(AssertionError):
sample_from = SampleFrom('wrong expression')
sample_from = SampleFrom('lambda spec: spec.config.base ** 2')
tune.run(is_eq, config=dict(base=10, test=sample_from.space))

tune.run(
objective,
config=dict(test=Constant(value=-1, use_container=False).space))
sample_from = SampleFrom(lambda spec: spec.config.base**2)
tune.run(is_eq, config=dict(base=10, test=sample_from.space))