Skip to content

Commit ee00f3f

Browse files
KKIEEKKKIEEK
KKIEEK
authored and
KKIEEK
committed
Fix blocking issue at test_tasks.py
1 parent 789ca62 commit ee00f3f

File tree

1 file changed

+48
-43
lines changed

1 file changed

+48
-43
lines changed

tests/test_mm/test_tasks.py

+48-43
Original file line numberDiff line numberDiff line change
@@ -5,86 +5,82 @@
55
import mmdet # noqa: F401
66
import mmseg # noqa: F401
77
import pytest
8+
import ray
89
import torch
910
from mmcv.utils import Config
1011
from ray import tune
1112
from ray.air import session
13+
from ray.tune.result_grid import ResultGrid
1214

1315
from siatune.mm.tasks import (TASKS, BaseTask, BlackBoxTask,
1416
ContinuousTestFunction, DiscreteTestFunction,
1517
MMClassification, MMDetection, MMSegmentation,
1618
MMTrainBasedTask, build_task_processor)
1719
from siatune.utils.config import dump_cfg
1820

19-
_session = dict()
20-
2121

22-
def report_to_session(*args, **kwargs):
23-
_session = get_session()
24-
_session.update(kwargs)
25-
for arg in args:
26-
if isinstance(arg, dict):
27-
_session.update(arg)
22+
@pytest.fixture
23+
def init_ray():
24+
if ray.is_initialized():
25+
ray.shutdown()
26+
return ray.init(num_cpus=1)
2827

2928

30-
def get_session():
31-
global _session
32-
return _session
33-
34-
35-
@patch('ray.tune.report', side_effect=report_to_session)
36-
def test_base_task(mock_report):
29+
def test_base_task(init_ray):
3730
with pytest.raises(TypeError):
3831
BaseTask()
3932

4033
class TestRewriter:
4134

4235
def __call__(self, context):
43-
context.get('args').test = -1
44-
return context
36+
args = context.pop('args')
37+
args.test = 'success'
38+
return dict(args=args)
4539

4640
class TestTask(BaseTask):
4741

48-
def parse_args(self, *args, **kwargs):
49-
return argparse.Namespace(test=1)
42+
def parse_args(self, args):
43+
return argparse.Namespace(test='init')
5044

51-
def run(self, *, args, **kwargs):
45+
def run(self, args):
5246
tune.report(test=args.test)
53-
return args.test
5447

5548
def create_trainable(self):
5649
return self.context_aware_run
5750

5851
task = TestTask([TestRewriter()])
59-
task.set_args('')
60-
assert task.args == argparse.Namespace(test=1)
61-
assert isinstance(task.rewriters, list)
62-
task.context_aware_run({})
63-
assert get_session().get('test') == -1
52+
task.set_args([])
53+
assert task.args == argparse.Namespace(test='init')
6454

65-
tune.run(task.create_trainable(), config={})
55+
trainable = task.create_trainable()
56+
result = ResultGrid(tune.run(trainable, config={}))[0]
57+
assert result.metrics['test'] == 'success'
6658

6759

68-
def test_black_box_task():
60+
def test_black_box_task(init_ray):
6961
with pytest.raises(TypeError):
7062
BlackBoxTask()
7163

7264
class TestTask(BlackBoxTask):
7365

7466
def run(self, *args, **kwargs):
75-
tune.report(test=1)
67+
tune.report(test='success')
7668

7769
task = TestTask()
78-
task.set_args('')
70+
task.set_args([])
7971
assert task.args == argparse.Namespace()
80-
tune.run(task.create_trainable(), config={})
72+
73+
trainable = task.create_trainable()
74+
result = ResultGrid(tune.run(trainable, config={}))[0]
75+
assert result.metrics['test'] == 'success'
8176

8277

8378
def test_build_task_processor():
8479

80+
@TASKS.register_module()
8581
class TestTaks(BaseTask):
8682

87-
def parse_args(self, *args, **kwargs):
83+
def parse_args(self, args):
8884
pass
8985

9086
def run(self, *args, **kwargs):
@@ -93,12 +89,22 @@ def run(self, *args, **kwargs):
9389
def create_trainable(self, *args, **kwargs):
9490
pass
9591

96-
TASKS.register_module(TestTaks)
97-
assert isinstance(build_task_processor(dict(type='TestTaks')), TestTaks)
92+
task = build_task_processor(dict(type='TestTaks'))
93+
assert isinstance(task, TestTaks)
94+
95+
96+
_session = dict()
97+
98+
99+
def report_to_session(*args, **kwargs):
100+
_session.update(kwargs)
101+
for arg in args:
102+
if isinstance(arg, dict):
103+
_session.update(arg)
98104

99105

100106
@patch('ray.tune.report', side_effect=report_to_session)
101-
def test_continuous_test_function(mock_report):
107+
def test_continuous_test_function(init_ray):
102108
func = ContinuousTestFunction()
103109
predefined_cont_funcs = [
104110
'delayedsphere',
@@ -147,11 +153,11 @@ def test_continuous_test_function(mock_report):
147153
'test.py')
148154
args = argparse.Namespace(config='test.py')
149155
func.run(args=args)
150-
assert isinstance(get_session().get('result'), float)
156+
assert isinstance(_session['result'], float)
151157

152158

153159
@patch('ray.tune.report', side_effect=report_to_session)
154-
def test_discrete_test_function(mock_report):
160+
def test_discrete_test_function(init_ray):
155161
func = DiscreteTestFunction()
156162

157163
predefined_discrete_funcs = ['onemax', 'leadingones', 'jump']
@@ -161,7 +167,7 @@ def test_discrete_test_function(mock_report):
161167
'test.py')
162168
args = argparse.Namespace(config='test.py')
163169
func.run(args=args)
164-
assert isinstance(get_session().get('result'), float)
170+
assert isinstance(_session['result'], float)
165171

166172

167173
@patch('mmcls.apis.train_model')
@@ -194,8 +200,7 @@ def test_mmseg(*not_used):
194200
task.run(args=task.args)
195201

196202

197-
@patch('ray.air.session.report', side_effect=report_to_session)
198-
def test_mm_train_based_task(mock_report):
203+
def test_mm_train_based_task(init_ray):
199204
with pytest.raises(TypeError):
200205
MMTrainBasedTask()
201206

@@ -250,7 +255,7 @@ def train_model(self, model, dataset, cfg):
250255
loss.backward()
251256
optimizer.step()
252257
total_loss += loss.item()
253-
session.report(loss=total_loss / (batch_idx + 1))
258+
session.report(dict(loss=total_loss / (batch_idx + 1)))
254259

255260
def run(self, *, searched_cfg, **kwargs):
256261
cfg = searched_cfg.get('cfg')
@@ -274,7 +279,7 @@ def run(self, *, searched_cfg, **kwargs):
274279
task = TestTask()
275280
task.set_resource()
276281
task.context_aware_run(searched_cfg=dict(cfg=cfg))
277-
assert 'loss' in get_session()
278282

279283
trainable = task.create_trainable()
280-
tune.Tuner(trainable).fit()
284+
result = tune.Tuner(trainable).fit()[0]
285+
assert 'loss' in result.metrics

0 commit comments

Comments
 (0)