-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add test apis * Add test code * Add test code * Add test code * Lint * add ci * fix ci * fix ci * fix ci * fix ci * fix ci * fix ci
- Loading branch information
Showing
18 changed files
with
383 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
name: Build & Unit Test | ||
|
||
on: | ||
pull_request: | ||
branch: | ||
- 'master' | ||
paths-ignore: | ||
- ".github/**.md" | ||
- "README.md" | ||
|
||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.ref }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
unit_test: | ||
runs-on: ubuntu-latest | ||
timeout-minutes: 120 | ||
|
||
steps: | ||
- name: checkout | ||
uses: actions/checkout@v2 | ||
|
||
- name: Build and install | ||
run: | | ||
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html | ||
pip install mmcv-full==1.4.7 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html | ||
pip install nevergrad mmdet mmsegmentation | ||
pip install -e . | ||
- name: Run unittests and generate coverage report | ||
run: | | ||
pip install pytest coverage | ||
coverage run --branch --source mmtune -m pytest tests/ | ||
coverage xml | ||
coverage report -m | ||
- name: Display coverage | ||
uses: ewjoachim/coverage-comment-action@v1 | ||
with: | ||
GITHUB_TOKEN: ${{ github.token }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
from mmcv.utils import Config, Registry | ||
from typing import Dict | ||
|
||
from mmcv.utils import Registry | ||
|
||
REWRITERS = Registry('rewriters') | ||
|
||
|
||
def build_rewriter(cfg: Config) -> object: | ||
def build_rewriter(cfg: Dict) -> object: | ||
return REWRITERS.build(cfg) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import os | ||
import tempfile | ||
from unittest.mock import MagicMock | ||
|
||
import mmcv | ||
from ray.tune.trainable import Trainable | ||
|
||
from mmtune.apis import log_analysis, tune | ||
|
||
|
||
def test_log_analysis(): | ||
mock_analysis = MagicMock() | ||
|
||
task_config = mmcv.Config(dict(model=dict(type='TempModel'))) | ||
|
||
tune_config = mmcv.Config( | ||
dict( | ||
scheduler=dict( | ||
type='AsyncHyperBandScheduler', | ||
time_attr='training_iteration', | ||
max_t=20, | ||
grace_period=2), | ||
metric='accuracy', | ||
mode='max', | ||
)) | ||
|
||
mock_analysis.best_config = task_config | ||
mock_analysis.best_result = dict(accuracy=50) | ||
mock_analysis.best_logdir = 'temp_log_dir' | ||
mock_analysis.results = [dict(accuracy=50)] | ||
|
||
with tempfile.TemporaryDirectory() as tmpdir: | ||
log_analysis(mock_analysis, tune_config, task_config, tmpdir) | ||
assert os.path.exists(os.path.join(tmpdir, 'tune_config.py')) | ||
assert os.path.exists(os.path.join(tmpdir, 'task_config.py')) | ||
|
||
|
||
def test_tune(): | ||
|
||
class TestTrainable(Trainable): | ||
|
||
def step(self): | ||
result = {'name': self.trial_name, 'trial_id': self.trial_id} | ||
return result | ||
|
||
tune_config = mmcv.Config( | ||
dict( | ||
scheduler=dict( | ||
type='AsyncHyperBandScheduler', | ||
time_attr='training_iteration', | ||
max_t=3, | ||
grace_period=1), | ||
metric='accuracy', | ||
mode='max', | ||
num_samples=1)) | ||
|
||
mock_task_processor = MagicMock() | ||
mock_task_processor.create_trainable.return_value = TestTrainable | ||
with tempfile.TemporaryDirectory() as tmpdir: | ||
mock_task_processor.args.work_dir = tmpdir | ||
mock_task_processor.args.num_workers = 1 | ||
mock_task_processor.args.num_cpus_per_worker = 1 | ||
mock_task_processor.args.num_gpus_per_worker = 0 | ||
tune(mock_task_processor, tune_config, 'exp_name') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from unittest.mock import MagicMock | ||
|
||
import mmcv | ||
import pytest | ||
|
||
from mmtune.mm.context import ContextManager, build_rewriter | ||
|
||
|
||
def test_contextmanager(): | ||
base_cfg = dict() | ||
args = MagicMock() | ||
|
||
rewriters = [dict(type='Decouple'), build_rewriter(dict(type='Dump'))] | ||
context_manager = ContextManager(base_cfg, args, rewriters) | ||
|
||
rewriters = [dict(type='Decouple')] | ||
context_manager = ContextManager(base_cfg, args, rewriters) | ||
func = lambda **kargs: 1 # noqa | ||
inner = context_manager(func) | ||
config = mmcv.Config(dict()) | ||
context = dict(cfg=config) | ||
inner(config, context) | ||
|
||
with pytest.raises(TypeError): | ||
rewriters = [dict(type='Decouple'), []] | ||
context_manager = ContextManager(base_cfg, args, rewriters) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import os | ||
from unittest.mock import MagicMock, patch | ||
|
||
import torch | ||
import torch.distributed as dist | ||
|
||
from mmtune.mm.hooks import RayCheckpointHook, RayTuneLoggerHook | ||
|
||
|
||
def test_raycheckpointhook(): | ||
os.environ['MASTER_ADDR'] = '127.0.0.1' | ||
os.environ['MASTER_PORT'] = '29500' | ||
dist.init_process_group('gloo', rank=0, world_size=1) | ||
|
||
hook = RayCheckpointHook( | ||
interval=1, | ||
by_epoch=True, | ||
out_dir='/tmp/ray_checkpoint', | ||
mode='min', | ||
metric_name='loss', | ||
max_concurrent=1, | ||
checkpoint_metric=True, | ||
checkpoint_at_end=True, | ||
) | ||
mock_runner = MagicMock() | ||
mock_runner.inner_iter = 3 | ||
mock_runner.iter = 5 | ||
|
||
cur_iter = hook.get_iter(mock_runner, False) | ||
assert cur_iter == 6 | ||
cur_iter = hook.get_iter(mock_runner, True) | ||
assert cur_iter == 4 | ||
|
||
mock_runner.model = torch.nn.Linear(2, 2) | ||
|
||
hook._save_checkpoint(mock_runner) | ||
|
||
|
||
@patch.object(RayTuneLoggerHook, 'get_loggable_tags') | ||
def test_raytuneloggerhook(mock_get_loggable_tags): | ||
mock_get_loggable_tags.return_value = {'train/Loss': 0.55, 'val/mAP': 0.6} | ||
|
||
mock_runner = MagicMock() | ||
mock_runner.iter = 5 | ||
|
||
loggerhook = RayTuneLoggerHook() | ||
loggerhook.log(mock_runner) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from unittest.mock import MagicMock, patch | ||
|
||
import mmcv | ||
|
||
from mmtune.mm.context.rewriters import (BatchConfigPathcer, ConfigMerger, | ||
CustomHookRegister, Decouple, Dump, | ||
SequeunceConfigPathcer, SetEnv) | ||
|
||
|
||
def test_decouple(): | ||
decouple = Decouple(keys=['searched_cfg', 'base_cfg']) | ||
|
||
context = dict( | ||
base_cfg=dict(model=dict(type='DummyModel')), | ||
searched_cfg=dict(model=[ | ||
dict(type='DummyModel'), | ||
dict(type='DummyModel2'), | ||
]), | ||
) | ||
decouple(context) | ||
|
||
|
||
def test_dump(): | ||
dump = Dump() | ||
config = mmcv.Config(dict()) | ||
args = MagicMock() | ||
args.config = config | ||
context = dict(cfg=config, args=args) | ||
dump(context) | ||
|
||
|
||
@patch('ray.tune.get_trial_id') | ||
def test_setenv(mock_get_trial_id): | ||
mock_get_trial_id.return_value = 'sdfkj234' | ||
setenv = SetEnv() | ||
|
||
args = MagicMock() | ||
args.work_dir = 'tmpdir' | ||
context = dict(args=args) | ||
setenv(context) | ||
|
||
|
||
def test_merge(): | ||
merger = ConfigMerger() | ||
|
||
context = dict( | ||
base_cfg=mmcv.Config(dict(model=dict(type='DummyModel'))), | ||
searched_cfg=mmcv.Config( | ||
dict(model=[ | ||
dict(type='DummyModel6'), | ||
dict(type='DummyModel2'), | ||
]))) | ||
merger(context) | ||
|
||
|
||
def test_patch(): | ||
context = dict( | ||
base_cfg=mmcv.Config(dict(model=dict(type='DummyModel'))), | ||
searched_cfg=mmcv.Config( | ||
dict(model=[ | ||
dict(type='DummyModel6'), | ||
dict(type='DummyModel2'), | ||
]))) | ||
patcher = BatchConfigPathcer() | ||
patcher(context) | ||
|
||
patcher = SequeunceConfigPathcer() | ||
patcher(context) | ||
|
||
|
||
def test_register(): | ||
post_custom_hooks = ['a', 'b'] | ||
register = CustomHookRegister(post_custom_hooks) | ||
cfg = MagicMock() | ||
cfg.custom_hooks = [] | ||
context = dict(cfg=cfg) | ||
context = register(context) |
Oops, something went wrong.