Skip to content

Commit 43d2d7d

Browse files
authored
Refactor spaces (#54)
* Refactor spaces * Remove Constant space * Fix * Add test cases
1 parent 1bd4d49 commit 43d2d7d

10 files changed

+222
-105
lines changed

configs/mmtune/mmcls_cifar_100_asynchb_nevergrad_pso.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
space = {
88
'model': {{_base_.model}},
9-
'model.head.num_classes': dict(type='Constant', value=100),
9+
'model.head.num_classes': 100,
1010
'optimizer': {{_base_.optimizer}},
1111
'data.samples_per_gpu': {{_base_.batch_size}},
1212
}

configs/mmtune/mmseg_asynchb_nevergrad_pso.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
'model': {{_base_.model}},
99
'optimizer': {{_base_.optimizer}},
1010
'data.samples_per_gpu': {{_base_.batch_size}},
11-
'model.decode_head.num_classes': dict(type='Constant', value=21),
12-
'model.auxiliary_head.num_classes': dict(type='Constant', value=21),
11+
'model.decode_head.num_classes': 21,
12+
'model.auxiliary_head.num_classes': 21,
1313
}
1414

1515
task = dict(type='MMSegmentation')

mmtune/ray/spaces/__init__.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1-
from .base import BaseSpace
1+
from .base import (BaseSpace, Lograndint, Loguniform, Qlograndint, Qloguniform,
2+
Qrandint, Qrandn, Quniform, Randint, Randn, Uniform)
23
from .builder import SPACES, build_space
34
from .choice import Choice
4-
from .const import Constant
5+
from .grid_search import GridSearch
6+
from .sample_from import SampleFrom
57

6-
__all__ = ['BaseSpace', 'SPACES', 'build_space', 'Choice', 'Constant']
8+
__all__ = [
9+
'BaseSpace', 'Uniform', 'Quniform', 'Loguniform', 'Qloguniform', 'Randn',
10+
'Qrandn', 'Randint', 'Qrandint', 'Lograndint', 'Qlograndint', 'SPACES',
11+
'build_space', 'Choice', 'GridSearch', 'SampleFrom'
12+
]

mmtune/ray/spaces/base.py

+61-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,69 @@
11
from abc import ABCMeta
2+
from typing import Callable
3+
4+
import ray.tune as tune
5+
6+
from .builder import SPACES
27

38

49
class BaseSpace(metaclass=ABCMeta):
510
"""Base Space class."""
11+
sample: Callable = None
12+
13+
def __init__(self, **kwargs) -> None:
14+
self.kwargs = kwargs
615

716
@property
8-
def space(self) -> callable:
17+
def space(self) -> tune.sample.Domain:
918
"""Return the space."""
10-
return getattr(self, '_space', None)
19+
return self.sample.__func__(**self.kwargs)
20+
21+
22+
@SPACES.register_module()
23+
class Uniform(BaseSpace):
24+
sample: Callable = tune.uniform
25+
26+
27+
@SPACES.register_module()
28+
class Quniform(BaseSpace):
29+
sample: Callable = tune.quniform
30+
31+
32+
@SPACES.register_module()
33+
class Loguniform(BaseSpace):
34+
sample: Callable = tune.loguniform
35+
36+
37+
@SPACES.register_module()
38+
class Qloguniform(BaseSpace):
39+
sample: Callable = tune.qloguniform
40+
41+
42+
@SPACES.register_module()
43+
class Randn(BaseSpace):
44+
sample: Callable = tune.randn
45+
46+
47+
@SPACES.register_module()
48+
class Qrandn(BaseSpace):
49+
sample: Callable = tune.qrandn
50+
51+
52+
@SPACES.register_module()
53+
class Randint(BaseSpace):
54+
sample: Callable = tune.randint
55+
56+
57+
@SPACES.register_module()
58+
class Qrandint(BaseSpace):
59+
sample: Callable = tune.qrandint
60+
61+
62+
@SPACES.register_module()
63+
class Lograndint(BaseSpace):
64+
sample: Callable = tune.lograndint
65+
66+
67+
@SPACES.register_module()
68+
class Qlograndint(BaseSpace):
69+
sample: Callable = tune.qlograndint

mmtune/ray/spaces/builder.py

+16-31
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,28 @@
1-
import inspect
2-
from typing import Callable, Dict
1+
from typing import Mapping, Sequence
32

43
from mmcv.utils import Registry
5-
from ray.tune import sample
6-
7-
from .base import BaseSpace
84

95
SPACES = Registry('spaces')
106

117

12-
def _register_space(space: Callable) -> None:
13-
"""Register a space.
14-
15-
Args:
16-
space (Callable): The space to register.
17-
"""
18-
19-
@SPACES.register_module(name=space.__name__.capitalize())
20-
class _ImplicitSpace(BaseSpace):
21-
22-
def __init__(self, *args, **kwargs):
23-
self._space = space(*args, **kwargs)
24-
25-
26-
for space_name in dir(sample):
27-
space = getattr(sample, space_name)
28-
if not inspect.isfunction(space):
29-
continue
30-
_register_space(space)
31-
32-
33-
def build_space(cfgs: Dict) -> Dict:
8+
def build_space(cfg: dict) -> dict:
349
"""Build a space.
3510
3611
Args:
37-
cfgs (Dict): The configurations of the space.
12+
cfg (dict): The configurations of the space.
3813
3914
Returns:
40-
Dict: The instantiated space.
15+
dict: The instantiated space.
4116
"""
42-
43-
return {key: SPACES.build(cfg).space for key, cfg in cfgs.items()}
17+
cfg = cfg.copy()
18+
for k, v in cfg.items():
19+
if isinstance(v, (int, str, bool, float)):
20+
continue
21+
elif isinstance(v, Sequence):
22+
cfg[k] = [build_space(_) if isinstance(_, dict) else _ for _ in v]
23+
elif isinstance(v, Mapping):
24+
cfg[k] = build_space(v)
25+
typ = cfg[k].get('type', '')
26+
if isinstance(typ, str) and typ in SPACES:
27+
cfg[k] = SPACES.build(cfg[k]).space
28+
return cfg

mmtune/ray/spaces/choice.py

+21-20
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,34 @@
1-
from typing import Optional, Sequence
1+
from typing import Callable, Optional, Sequence
22

3-
from ray.tune.sample import choice
3+
import ray.tune as tune
44

55
from mmtune.utils import ImmutableContainer
66
from .base import BaseSpace
77
from .builder import SPACES
88

99

10-
@SPACES.register_module(force=True)
10+
@SPACES.register_module()
1111
class Choice(BaseSpace):
12+
"""Sample a categorical value.
13+
14+
Args:
15+
categories (Sequence): The categories.
16+
alias (Sequence, optional): A alias to be expressed.
17+
Defaults to None.
18+
"""
19+
20+
sample: Callable = tune.choice
1221

1322
def __init__(self,
1423
categories: Sequence,
15-
alias: Optional[Sequence] = None,
16-
use_container: bool = True):
17-
"""Initialize Choice.
18-
19-
Args:
20-
categories (Sequence): The categories.
21-
alias (Optional[Sequence]):
22-
A alias to be expressed. Defaults to None.
23-
use_container (bool):
24-
Whether to use containers. Defaults to True.
25-
"""
26-
24+
alias: Optional[Sequence] = None) -> None:
2725
if alias is not None:
2826
assert len(categories) == len(alias)
29-
categories = [
30-
ImmutableContainer(c, None if alias is None else alias[idx])
31-
if use_container else c for idx, c in enumerate(categories)
32-
]
33-
self._space = choice(categories)
27+
categories = [
28+
ImmutableContainer(*it) for it in zip(categories, alias)
29+
]
30+
self.categories = categories
31+
32+
@property
33+
def space(self) -> tune.sample.Domain:
34+
return self.sample.__func__(self.categories)

mmtune/ray/spaces/const.py

-27
This file was deleted.

mmtune/ray/spaces/grid_search.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import Callable, Optional, Sequence
2+
3+
import ray.tune as tune
4+
5+
from mmtune.utils import ImmutableContainer
6+
from .base import BaseSpace
7+
from .builder import SPACES
8+
9+
10+
@SPACES.register_module()
11+
class GridSearch(BaseSpace):
12+
"""Grid search over a value.
13+
14+
Args:
15+
values (Sequence): An iterable whose parameters will be gridded.
16+
alias (Sequence, optional): A alias to be expressed.
17+
Defaults to None.
18+
"""
19+
20+
sample: Callable = tune.grid_search
21+
22+
def __init__(self,
23+
values: Sequence,
24+
alias: Optional[Sequence] = None) -> None:
25+
if alias is not None:
26+
assert len(values) == len(alias)
27+
values = [ImmutableContainer(*it) for it in zip(values, alias)]
28+
self.values = values
29+
30+
@property
31+
def space(self) -> dict:
32+
return self.sample.__func__(self.values)

mmtune/ray/spaces/sample_from.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import Callable, Union
2+
3+
import ray.tune as tune
4+
5+
from .base import BaseSpace
6+
from .builder import SPACES
7+
8+
9+
@SPACES.register_module()
10+
class SampleFrom(BaseSpace):
11+
"""Specify that tune should sample configuration values from this function.
12+
13+
Args:
14+
func (str | Callable): An string or callable function
15+
to draw a sample from.
16+
"""
17+
18+
sample: Callable = tune.sample_from
19+
20+
def __init__(self, func: Union[str, Callable]) -> None:
21+
if isinstance(func, str):
22+
assert func.startswith('lambda')
23+
func = eval(func)
24+
self.func = func
25+
26+
@property
27+
def space(self) -> tune.sample.Domain:
28+
return self.sample.__func__(self.func)

tests/test_ray/test_spaces.py

+52-19
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1+
import pytest
12
from ray import tune
23

3-
from mmtune.ray.spaces import BaseSpace, Choice, Constant, build_space
4-
5-
6-
def test_base_space():
7-
assert hasattr(BaseSpace(), 'space')
4+
from mmtune.ray.spaces import Choice, GridSearch, SampleFrom, build_space
5+
from mmtune.utils import ImmutableContainer
86

97

108
def test_build_space():
@@ -18,22 +16,57 @@ def test_build_space():
1816

1917
def test_choice():
2018

21-
def objective(config):
22-
assert config.get('test') in [1, 2, 3]
23-
tune.report(result=config['test'])
19+
def is_in(config):
20+
assert config['test'] in [0, 1, 2]
21+
22+
choice = Choice(categories=[0, 1, 2])
23+
tune.run(is_in, config=dict(test=choice.space))
24+
25+
# with alias
26+
def is_immutable(config):
27+
assert isinstance(config['test'], ImmutableContainer)
28+
assert config['test'].data in [True, False]
29+
assert config['test'].alias in ['T', 'F']
30+
31+
with pytest.raises(AssertionError):
32+
choice = Choice(categories=[True, False], alias=['TF'])
33+
choice = Choice(categories=[True, False], alias=['T', 'F'])
34+
tune.run(is_immutable, config=dict(test=choice.space))
35+
36+
37+
def test_grid_search():
38+
39+
def is_in(config):
40+
assert config['test1'] in [0, 1, 2]
41+
assert config['test2'] in [3, 4, 5]
42+
43+
grid1 = GridSearch(values=[0, 1, 2])
44+
grid2 = GridSearch(values=[3, 4, 5])
45+
tune.run(is_in, config=dict(test1=grid1.space, test2=grid2.space))
46+
47+
# with alias
48+
def is_immutable(config):
49+
for test in ['test1', 'test2']:
50+
assert isinstance(config[test], ImmutableContainer)
51+
assert config[test].data in [True, False]
52+
assert config[test].alias in ['T', 'F']
53+
54+
with pytest.raises(AssertionError):
55+
grid1 = GridSearch(values=[True, False], alias=['TF'])
56+
grid1 = GridSearch(values=[True, False], alias=['T', 'F'])
57+
grid2 = GridSearch(values=[False, True], alias=['F', 'T'])
58+
tune.run(is_immutable, config=dict(test1=grid1.space, test2=grid2.space))
2459

25-
tune.run(
26-
objective,
27-
config=dict(
28-
test=Choice(categories=[1, 2, 3], use_container=False).space))
2960

61+
def test_sample_from():
3062

31-
def test_constant():
63+
def is_eq(config):
64+
assert config['test'] == config['base']**2
3265

33-
def objective(config):
34-
assert config.get('test') == -1
35-
tune.report(result=config['test'])
66+
with pytest.raises(AssertionError):
67+
sample_from = SampleFrom('wrong expression')
68+
sample_from = SampleFrom('lambda spec: spec.config.base ** 2')
69+
tune.run(is_eq, config=dict(base=10, test=sample_from.space))
3670

37-
tune.run(
38-
objective,
39-
config=dict(test=Constant(value=-1, use_container=False).space))
71+
sample_from = SampleFrom(lambda spec: spec.config.base**2)
72+
tune.run(is_eq, config=dict(base=10, test=sample_from.space))

0 commit comments

Comments
 (0)