Skip to content

Commit 9e1c609

Browse files
KKIEEKKKIEEK
authored and
KKIEEK
committed
Add test cases
1 parent af83b2d commit 9e1c609

File tree

1 file changed

+57
-12
lines changed

1 file changed

+57
-12
lines changed

tests/test_ray/test_spaces.py

+57-12
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1+
import pytest
12
from ray import tune
23

3-
from mmtune.ray.spaces import BaseSpace, Choice, build_space
4-
5-
6-
def test_base_space():
7-
assert hasattr(BaseSpace(), 'space')
4+
from mmtune.ray.spaces import Choice, GridSearch, SampleFrom, build_space
5+
from mmtune.utils import ImmutableContainer
86

97

108
def test_build_space():
@@ -18,11 +16,58 @@ def test_build_space():
1816

1917
def test_choice():
2018

21-
def objective(config):
22-
assert config.get('test') in [1, 2, 3]
23-
tune.report(result=config['test'])
19+
def is_in(config):
20+
assert config['test'] in [0, 1, 2]
21+
22+
choice = Choice(categories=[0, 1, 2])
23+
tune.run(is_in, config=dict(test=choice.space))
24+
25+
# with alias
26+
def is_immutable(config):
27+
assert isinstance(config['test'], ImmutableContainer)
28+
assert config['test'].data in [True, False]
29+
assert config['test'].alias in ['T', 'F']
30+
31+
with pytest.raises(AssertionError):
32+
choice = Choice(categories=[True, False], alias=['TF'])
33+
choice = Choice(categories=[True, False], alias=['T', 'F'])
34+
tune.run(is_immutable, config=dict(test=choice.space))
35+
36+
37+
def test_grid_search():
38+
39+
def is_in(config):
40+
print(config)
41+
assert config['test1'] in [0, 1, 2]
42+
assert config['test2'] in [3, 4, 5]
43+
44+
grid1 = GridSearch(values=[0, 1, 2])
45+
grid2 = GridSearch(values=[3, 4, 5])
46+
tune.run(is_in, config=dict(test1=grid1.space, test2=grid2.space))
47+
48+
# with alias
49+
def is_immutable(config):
50+
for test in ['test1', 'test2']:
51+
assert isinstance(config[test], ImmutableContainer)
52+
assert config[test].data in [True, False]
53+
assert config[test].alias in ['T', 'F']
54+
55+
with pytest.raises(AssertionError):
56+
grid1 = GridSearch(values=[True, False], alias=['TF'])
57+
grid1 = GridSearch(values=[True, False], alias=['T', 'F'])
58+
grid2 = GridSearch(values=[False, True], alias=['F', 'T'])
59+
tune.run(is_immutable, config=dict(test1=grid1.space, test2=grid2.space))
60+
61+
62+
def test_sample_from():
63+
64+
def is_eq(config):
65+
assert config['test'] == config['base']**2
66+
67+
with pytest.raises(AssertionError):
68+
sample_from = SampleFrom('wrong expression')
69+
sample_from = SampleFrom('lambda spec: spec.config.base ** 2')
70+
tune.run(is_eq, config=dict(base=10, test=sample_from.space))
2471

25-
tune.run(
26-
objective,
27-
config=dict(
28-
test=Choice(categories=[1, 2, 3], use_container=False).space))
72+
sample_from = SampleFrom(lambda spec: spec.config.base**2)
73+
tune.run(is_eq, config=dict(base=10, test=sample_from.space))

0 commit comments

Comments
 (0)