1
+ import pytest
1
2
from ray import tune
2
3
3
- from mmtune .ray .spaces import BaseSpace , Choice , Constant , build_space
4
-
5
-
6
- def test_base_space ():
7
- assert hasattr (BaseSpace (), 'space' )
4
+ from mmtune .ray .spaces import Choice , GridSearch , SampleFrom , build_space
5
+ from mmtune .utils import ImmutableContainer
8
6
9
7
10
8
def test_build_space ():
@@ -18,22 +16,57 @@ def test_build_space():
18
16
19
17
def test_choice ():
20
18
21
- def objective (config ):
22
- assert config .get ('test' ) in [1 , 2 , 3 ]
23
- tune .report (result = config ['test' ])
19
+ def is_in (config ):
20
+ assert config ['test' ] in [0 , 1 , 2 ]
21
+
22
+ choice = Choice (categories = [0 , 1 , 2 ])
23
+ tune .run (is_in , config = dict (test = choice .space ))
24
+
25
+ # with alias
26
+ def is_immutable (config ):
27
+ assert isinstance (config ['test' ], ImmutableContainer )
28
+ assert config ['test' ].data in [True , False ]
29
+ assert config ['test' ].alias in ['T' , 'F' ]
30
+
31
+ with pytest .raises (AssertionError ):
32
+ choice = Choice (categories = [True , False ], alias = ['TF' ])
33
+ choice = Choice (categories = [True , False ], alias = ['T' , 'F' ])
34
+ tune .run (is_immutable , config = dict (test = choice .space ))
35
+
36
+
37
+ def test_grid_search ():
38
+
39
+ def is_in (config ):
40
+ assert config ['test1' ] in [0 , 1 , 2 ]
41
+ assert config ['test2' ] in [3 , 4 , 5 ]
42
+
43
+ grid1 = GridSearch (values = [0 , 1 , 2 ])
44
+ grid2 = GridSearch (values = [3 , 4 , 5 ])
45
+ tune .run (is_in , config = dict (test1 = grid1 .space , test2 = grid2 .space ))
46
+
47
+ # with alias
48
+ def is_immutable (config ):
49
+ for test in ['test1' , 'test2' ]:
50
+ assert isinstance (config [test ], ImmutableContainer )
51
+ assert config [test ].data in [True , False ]
52
+ assert config [test ].alias in ['T' , 'F' ]
53
+
54
+ with pytest .raises (AssertionError ):
55
+ grid1 = GridSearch (values = [True , False ], alias = ['TF' ])
56
+ grid1 = GridSearch (values = [True , False ], alias = ['T' , 'F' ])
57
+ grid2 = GridSearch (values = [False , True ], alias = ['F' , 'T' ])
58
+ tune .run (is_immutable , config = dict (test1 = grid1 .space , test2 = grid2 .space ))
24
59
25
- tune .run (
26
- objective ,
27
- config = dict (
28
- test = Choice (categories = [1 , 2 , 3 ], use_container = False ).space ))
29
60
61
+ def test_sample_from ():
30
62
31
- def test_constant ():
63
+ def is_eq (config ):
64
+ assert config ['test' ] == config ['base' ]** 2
32
65
33
- def objective (config ):
34
- assert config .get ('test' ) == - 1
35
- tune .report (result = config ['test' ])
66
+ with pytest .raises (AssertionError ):
67
+ sample_from = SampleFrom ('wrong expression' )
68
+ sample_from = SampleFrom ('lambda spec: spec.config.base ** 2' )
69
+ tune .run (is_eq , config = dict (base = 10 , test = sample_from .space ))
36
70
37
- tune .run (
38
- objective ,
39
- config = dict (test = Constant (value = - 1 , use_container = False ).space ))
71
+ sample_from = SampleFrom (lambda spec : spec .config .base ** 2 )
72
+ tune .run (is_eq , config = dict (base = 10 , test = sample_from .space ))
0 commit comments