5
5
import mmdet # noqa: F401
6
6
import mmseg # noqa: F401
7
7
import pytest
8
+ import ray
8
9
import torch
9
10
from mmcv .utils import Config
10
11
from ray import tune
11
12
from ray .air import session
13
+ from ray .tune .result_grid import ResultGrid
12
14
13
15
from siatune .mm .tasks import (TASKS , BaseTask , BlackBoxTask ,
14
16
ContinuousTestFunction , DiscreteTestFunction ,
15
17
MMClassification , MMDetection , MMSegmentation ,
16
18
MMTrainBasedTask , build_task_processor )
17
19
from siatune .utils .config import dump_cfg
18
20
19
- _session = dict ()
20
-
21
21
22
- def report_to_session (* args , ** kwargs ):
23
- _session = get_session ()
24
- _session .update (kwargs )
25
- for arg in args :
26
- if isinstance (arg , dict ):
27
- _session .update (arg )
22
+ @pytest .fixture
23
+ def init_ray ():
24
+ if ray .is_initialized ():
25
+ ray .shutdown ()
26
+ return ray .init (num_cpus = 1 )
28
27
29
28
30
- def get_session ():
31
- global _session
32
- return _session
33
-
34
-
35
- @patch ('ray.tune.report' , side_effect = report_to_session )
36
- def test_base_task (mock_report ):
29
+ def test_base_task (init_ray ):
37
30
with pytest .raises (TypeError ):
38
31
BaseTask ()
39
32
40
33
class TestRewriter :
41
34
42
35
def __call__ (self , context ):
43
- context .get ('args' ).test = - 1
44
- return context
36
+ args = context .pop ('args' )
37
+ args .test = 'success'
38
+ return dict (args = args )
45
39
46
40
class TestTask (BaseTask ):
47
41
48
- def parse_args (self , * args , ** kwargs ):
49
- return argparse .Namespace (test = 1 )
42
+ def parse_args (self , args ):
43
+ return argparse .Namespace (test = 'init' )
50
44
51
- def run (self , * , args , ** kwargs ):
45
+ def run (self , args ):
52
46
tune .report (test = args .test )
53
- return args .test
54
47
55
48
def create_trainable (self ):
56
49
return self .context_aware_run
57
50
58
51
task = TestTask ([TestRewriter ()])
59
- task .set_args ('' )
60
- assert task .args == argparse .Namespace (test = 1 )
61
- assert isinstance (task .rewriters , list )
62
- task .context_aware_run ({})
63
- assert get_session ().get ('test' ) == - 1
52
+ task .set_args ([])
53
+ assert task .args == argparse .Namespace (test = 'init' )
64
54
65
- tune .run (task .create_trainable (), config = {})
55
+ trainable = task .create_trainable ()
56
+ result = ResultGrid (tune .run (trainable , config = {}))[0 ]
57
+ assert result .metrics ['test' ] == 'success'
66
58
67
59
68
- def test_black_box_task ():
60
+ def test_black_box_task (init_ray ):
69
61
with pytest .raises (TypeError ):
70
62
BlackBoxTask ()
71
63
72
64
class TestTask (BlackBoxTask ):
73
65
74
66
def run (self , * args , ** kwargs ):
75
- tune .report (test = 1 )
67
+ tune .report (test = 'success' )
76
68
77
69
task = TestTask ()
78
- task .set_args ('' )
70
+ task .set_args ([] )
79
71
assert task .args == argparse .Namespace ()
80
- tune .run (task .create_trainable (), config = {})
72
+
73
+ trainable = task .create_trainable ()
74
+ result = ResultGrid (tune .run (trainable , config = {}))[0 ]
75
+ assert result .metrics ['test' ] == 'success'
81
76
82
77
83
78
def test_build_task_processor ():
84
79
80
+ @TASKS .register_module ()
85
81
class TestTaks (BaseTask ):
86
82
87
- def parse_args (self , * args , ** kwargs ):
83
+ def parse_args (self , args ):
88
84
pass
89
85
90
86
def run (self , * args , ** kwargs ):
@@ -93,12 +89,22 @@ def run(self, *args, **kwargs):
93
89
def create_trainable (self , * args , ** kwargs ):
94
90
pass
95
91
96
- TASKS .register_module (TestTaks )
97
- assert isinstance (build_task_processor (dict (type = 'TestTaks' )), TestTaks )
92
+ task = build_task_processor (dict (type = 'TestTaks' ))
93
+ assert isinstance (task , TestTaks )
94
+
95
+
96
+ _session = dict ()
97
+
98
+
99
+ def report_to_session (* args , ** kwargs ):
100
+ _session .update (kwargs )
101
+ for arg in args :
102
+ if isinstance (arg , dict ):
103
+ _session .update (arg )
98
104
99
105
100
106
@patch ('ray.tune.report' , side_effect = report_to_session )
101
- def test_continuous_test_function (mock_report ):
107
+ def test_continuous_test_function (init_ray ):
102
108
func = ContinuousTestFunction ()
103
109
predefined_cont_funcs = [
104
110
'delayedsphere' ,
@@ -147,11 +153,11 @@ def test_continuous_test_function(mock_report):
147
153
'test.py' )
148
154
args = argparse .Namespace (config = 'test.py' )
149
155
func .run (args = args )
150
- assert isinstance (get_session (). get ( 'result' ) , float )
156
+ assert isinstance (_session [ 'result' ] , float )
151
157
152
158
153
159
@patch ('ray.tune.report' , side_effect = report_to_session )
154
- def test_discrete_test_function (mock_report ):
160
+ def test_discrete_test_function (init_ray ):
155
161
func = DiscreteTestFunction ()
156
162
157
163
predefined_discrete_funcs = ['onemax' , 'leadingones' , 'jump' ]
@@ -161,7 +167,7 @@ def test_discrete_test_function(mock_report):
161
167
'test.py' )
162
168
args = argparse .Namespace (config = 'test.py' )
163
169
func .run (args = args )
164
- assert isinstance (get_session (). get ( 'result' ) , float )
170
+ assert isinstance (_session [ 'result' ] , float )
165
171
166
172
167
173
@patch ('mmcls.apis.train_model' )
@@ -194,8 +200,7 @@ def test_mmseg(*not_used):
194
200
task .run (args = task .args )
195
201
196
202
197
- @patch ('ray.air.session.report' , side_effect = report_to_session )
198
- def test_mm_train_based_task (mock_report ):
203
+ def test_mm_train_based_task (init_ray ):
199
204
with pytest .raises (TypeError ):
200
205
MMTrainBasedTask ()
201
206
@@ -250,7 +255,7 @@ def train_model(self, model, dataset, cfg):
250
255
loss .backward ()
251
256
optimizer .step ()
252
257
total_loss += loss .item ()
253
- session .report (loss = total_loss / (batch_idx + 1 ))
258
+ session .report (dict ( loss = total_loss / (batch_idx + 1 ) ))
254
259
255
260
def run (self , * , searched_cfg , ** kwargs ):
256
261
cfg = searched_cfg .get ('cfg' )
@@ -274,7 +279,7 @@ def run(self, *, searched_cfg, **kwargs):
274
279
task = TestTask ()
275
280
task .set_resource ()
276
281
task .context_aware_run (searched_cfg = dict (cfg = cfg ))
277
- assert 'loss' in get_session ()
278
282
279
283
trainable = task .create_trainable ()
280
- tune .Tuner (trainable ).fit ()
284
+ result = tune .Tuner (trainable ).fit ()[0 ]
285
+ assert 'loss' in result .metrics
0 commit comments