Skip to content

Commit 9f02b72

Browse files
KKIEEKKKIEEK
authored and
KKIEEK
committed
Refactor spaces
1 parent bb04144 commit 9f02b72

File tree

6 files changed

+123
-73
lines changed

6 files changed

+123
-73
lines changed

mmtune/ray/spaces/__init__.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
from .base import BaseSpace
1+
from .base import (BaseSpace, Lograndint, Loguniform, Qlograndint, Qloguniform,
2+
Qrandint, Qrandn, Quniform, Randint, Randn, Uniform)
23
from .builder import SPACES, build_space
34
from .choice import Choice
4-
from .const import Constant
55

6-
__all__ = ['BaseSpace', 'SPACES', 'build_space', 'Choice', 'Constant']
6+
__all__ = [
7+
'BaseSpace', 'Uniform', 'Quniform', 'Loguniform', 'Qloguniform', 'Randn',
8+
'Qrandn', 'Randint', 'Qrandint', 'Lograndint', 'Qlograndint', 'SPACES',
9+
'build_space', 'Choice'
10+
]

mmtune/ray/spaces/base.py

+62-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,69 @@
11
from abc import ABCMeta
2+
from numbers import Number
3+
from typing import Callable, Union
4+
5+
import ray.tune as tune
6+
7+
from .builder import SPACES
28

39

410
class BaseSpace(metaclass=ABCMeta):
511
"""Base Space class."""
12+
sample: Callable
13+
14+
def __init__(self, **kwargs):
15+
self.kwargs = kwargs
616

717
@property
8-
def space(self) -> callable:
9-
"""Return the space."""
10-
return getattr(self, '_space', None)
18+
def space(self) -> Union[Number, list]:
19+
return self.sample.__func__(**self.kwargs)
20+
21+
22+
@SPACES.register_module()
23+
class Uniform(BaseSpace):
24+
sample: Callable = tune.uniform
25+
26+
27+
@SPACES.register_module()
28+
class Quniform(BaseSpace):
29+
sample: Callable = tune.quniform
30+
31+
32+
@SPACES.register_module()
33+
class Loguniform(BaseSpace):
34+
sample: Callable = tune.loguniform
35+
36+
37+
@SPACES.register_module()
38+
class Qloguniform(BaseSpace):
39+
sample: Callable = tune.qloguniform
40+
41+
42+
@SPACES.register_module()
43+
class Randn(BaseSpace):
44+
sample: Callable = tune.randn
45+
46+
47+
@SPACES.register_module()
48+
class Qrandn(BaseSpace):
49+
sample: Callable = tune.qrandn
50+
51+
52+
@SPACES.register_module()
53+
class Randint(BaseSpace):
54+
sample: Callable = tune.randint
55+
56+
57+
@SPACES.register_module()
58+
class Qrandint(BaseSpace):
59+
sample: Callable = tune.qrandint
60+
61+
62+
@SPACES.register_module()
63+
class Lograndint(BaseSpace):
64+
sample: Callable = tune.lograndint
65+
66+
67+
@SPACES.register_module()
68+
class Qlograndint(BaseSpace):
69+
sample: Callable = tune.qlograndint

mmtune/ray/spaces/builder.py

+14-31
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,26 @@
1-
import inspect
2-
from typing import Callable, Dict
1+
from typing import Mapping, Sequence
32

43
from mmcv.utils import Registry
5-
from ray.tune import sample
6-
7-
from .base import BaseSpace
84

95
SPACES = Registry('spaces')
106

117

12-
def _register_space(space: Callable) -> None:
13-
"""Register a space.
14-
15-
Args:
16-
space (Callable): The space to register.
17-
"""
18-
19-
@SPACES.register_module(name=space.__name__.capitalize())
20-
class _ImplicitSpace(BaseSpace):
21-
22-
def __init__(self, *args, **kwargs):
23-
self._space = space(*args, **kwargs)
24-
25-
26-
for space_name in dir(sample):
27-
space = getattr(sample, space_name)
28-
if not inspect.isfunction(space):
29-
continue
30-
_register_space(space)
31-
32-
33-
def build_space(cfgs: Dict) -> Dict:
8+
def build_space(cfg: dict) -> dict:
349
"""Build a space.
3510
3611
Args:
37-
cfgs (Dict): The configurations of the space.
12+
cfg (dict): The configurations of the space.
3813
3914
Returns:
40-
Dict: The instantiated space.
15+
dict: The instantiated space.
4116
"""
42-
43-
return {key: SPACES.build(cfg).space for key, cfg in cfgs.items()}
17+
# inductive step
18+
cfg = cfg.copy()
19+
for key, val in cfg.items():
20+
if isinstance(val, Sequence):
21+
cfg[key] = [build_space(v) for v in val if isinstance(v, Mapping)]
22+
elif isinstance(val, Mapping):
23+
cfg[key] = build_space(val)
24+
if cfg.get('type', '') in SPACES:
25+
cfg[key] = SPACES.build(cfg).space
26+
return cfg

mmtune/ray/spaces/choice.py

+16-9
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1-
from typing import Optional, Sequence
1+
from numbers import Number
2+
from typing import Callable, Optional, Sequence, Union
23

3-
from ray.tune.sample import choice
4+
import ray.tune as tune
45

56
from mmtune.utils import ImmutableContainer
67
from .base import BaseSpace
78
from .builder import SPACES
89

910

10-
@SPACES.register_module(force=True)
11+
@SPACES.register_module()
1112
class Choice(BaseSpace):
13+
sample: Callable = tune.choice
1214

1315
def __init__(self,
1416
categories: Sequence,
@@ -23,11 +25,16 @@ def __init__(self,
2325
use_container (bool):
2426
Whether to use containers. Defaults to True.
2527
"""
26-
2728
if alias is not None:
2829
assert len(categories) == len(alias)
29-
categories = [
30-
ImmutableContainer(c, None if alias is None else alias[idx])
31-
if use_container else c for idx, c in enumerate(categories)
32-
]
33-
self._space = choice(categories)
30+
31+
if use_container:
32+
aliases = alias or [None] * len(categories)
33+
categories = [
34+
ImmutableContainer(*it) for it in zip(categories, aliases)
35+
]
36+
self.categories = categories
37+
38+
@property
39+
def space(self) -> Union[Number, list]:
40+
return self.sample.__func__(self.categories)

mmtune/ray/spaces/const.py

-27
This file was deleted.

mmtune/ray/spaces/sample_from.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from numbers import Number
2+
from typing import Callable, Union
3+
4+
import ray.tune as tune
5+
6+
from .base import BaseSpace
7+
from .builder import SPACES
8+
9+
10+
@SPACES.register_module()
11+
class SampleFrom(BaseSpace):
12+
sample: Callable = tune.sample_from
13+
14+
def __init__(self, func: Union[str, Callable], imports=None):
15+
if isinstance(func, str):
16+
func = eval(func)
17+
self.func = func
18+
self.imports = imports or []
19+
20+
@property
21+
def space(self) -> Union[Number, list]:
22+
for module in self.imports:
23+
exec(f'import {module}')
24+
return self.sample.__func__(self.func)

0 commit comments

Comments
 (0)