Skip to content

Commit

Permalink
[Enhancement] Unit Test (#25)
Browse files Browse the repository at this point in the history
* add test apis

* Add test code

* Add test code

* Add test code

* Lint

* add ci

* fix ci

* fix ci

* fix ci

* fix ci

* fix ci

* fix ci
  • Loading branch information
nijkah authored May 27, 2022
1 parent efea853 commit ad52fe5
Show file tree
Hide file tree
Showing 18 changed files with 383 additions and 9 deletions.
40 changes: 40 additions & 0 deletions .github/workflows/build_unit_test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
name: Build & Unit Test

on:
pull_request:
branch:
- 'master'
paths-ignore:
- ".github/**.md"
- "README.md"


concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
unit_test:
runs-on: ubuntu-latest
timeout-minutes: 120

steps:
- name: checkout
uses: actions/checkout@v2

- name: Build and install
run: |
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
pip install mmcv-full==1.4.7 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html
pip install nevergrad mmdet mmsegmentation
pip install -e .
- name: Run unittests and generate coverage report
run: |
pip install pytest coverage
coverage run --branch --source mmtune -m pytest tests/
coverage xml
coverage report -m
- name: Display coverage
uses: ewjoachim/coverage-comment-action@v1
with:
GITHUB_TOKEN: ${{ github.token }}
1 change: 0 additions & 1 deletion configs/mmtune/_base_/space/mmdet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@
_delete_=True,
type='FasterRCNN',
backbone=dict(
_delete_=True,
type='SwinTransformer',
embed_dims=96,
depths=[2, 2, 18, 2],
Expand Down
6 changes: 4 additions & 2 deletions mmtune/mm/context/rewriters/builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from mmcv.utils import Config, Registry
from typing import Dict

from mmcv.utils import Registry

REWRITERS = Registry('rewriters')


def build_rewriter(cfg: Config) -> object:
def build_rewriter(cfg: Dict) -> object:
return REWRITERS.build(cfg)
7 changes: 5 additions & 2 deletions mmtune/mm/context/rewriters/decouple.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import Dict, List

from mmtune.utils import ImmutableContainer
from .builder import REWRITERS
Expand All @@ -10,7 +10,10 @@ class Decouple:
def __init__(self, keys: List[str] = []):
self.keys = keys

def __call__(self, context: dict):
def __call__(self, context: Dict):
assert set(context).issuperset(set(
self.keys)), ('context should have superset of keys!')

for key in self.keys:
context[key] = ImmutableContainer.decouple(context[key])
return context
28 changes: 28 additions & 0 deletions mmtune/mm/context/rewriters/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,41 @@ class ConfigMerger:
def merge_dict(src: dict,
dst: dict,
allow_list_keys: Union[list, dict, bool] = False):
"""merge dict ``a`` into dict ``b`` (non-inplace).
Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
in-place modifications.
Args:
a (dict): The source dict to be merged into ``b``.
b (dict): The origin dict to be fetch keys from ``a``.
allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
are allowed in source ``a`` and will replace the element of the
corresponding index in b if b is a list. Default: False.
Returns:
dict: The modified dict of ``b`` using ``a``.
Examples:
# Normally merge a into b.
>>> Config._merge_a_into_b(
... dict(obj=dict(a=2)), dict(obj=dict(a=1)))
{'obj': {'a': 2}}
# Delete b first and merge a into b.
>>> Config._merge_a_into_b(
... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
{'obj': {'a': 2}}
# b is a list
>>> Config._merge_a_into_b(
... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
[{'a': 2}, {'b': 2}]
"""
dst = dst.copy()
for k, v in src.items():
if allow_list_keys and k.isdigit() and isinstance(dst, list):
k = int(k)
if len(dst) <= k:
raise KeyError(
f'Index {k} exceeds the length of list {dst}')
# modified from the mmcv.config.Config._merge_a_into_b
# this allows merging with primitives such as int, float
dst[k] = ConfigMerger.merge_dict(v, dst[k],
allow_list_keys) if hasattr(
dst[k], 'copy') else v
Expand Down
4 changes: 2 additions & 2 deletions mmtune/mm/tasks/mmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,10 @@ def run(self, *args, **kwargs):
val_dataset.pipeline = cfg.data.train.pipeline
datasets.append(self.build_dataset(val_dataset))
if cfg.checkpoint_config is not None:
# save mmseg version, config file content and class names in
# save mmdet version, config file content and class names in
# checkpoints as meta data
cfg.checkpoint_config.meta = dict(
mmseg_version=f'{__version__}+{get_git_hash()[:7]}',
mmdet_version=f'{__version__}+{get_git_hash()[:7]}',
config=cfg.pretty_text,
CLASSES=datasets[0].CLASSES,
PALETTE=datasets[0].PALETTE)
Expand Down
4 changes: 3 additions & 1 deletion mmtune/ray/stoppers/early_drop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ def __init__(self,
metric_threshold: float,
grace_period: int = 0):

assert mode in ['min', 'max']
if mode not in ['min', 'max']:
raise ValueError('mode must be either "min" or "max".')

self._metric = metric
self._mode = mode
self._metric_threshold = metric_threshold
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ line_length = 79
multi_line_output = 0
extra_standard_library = setuptools
known_first_party = mmtune
known_third_party = mmcv,numpy,ray,setuptools,torch
known_third_party = mmcv,numpy,pytest,ray,setuptools,torch
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
64 changes: 64 additions & 0 deletions tests/test_apis/test_apis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import os
import tempfile
from unittest.mock import MagicMock

import mmcv
from ray.tune.trainable import Trainable

from mmtune.apis import log_analysis, tune


def test_log_analysis():
mock_analysis = MagicMock()

task_config = mmcv.Config(dict(model=dict(type='TempModel')))

tune_config = mmcv.Config(
dict(
scheduler=dict(
type='AsyncHyperBandScheduler',
time_attr='training_iteration',
max_t=20,
grace_period=2),
metric='accuracy',
mode='max',
))

mock_analysis.best_config = task_config
mock_analysis.best_result = dict(accuracy=50)
mock_analysis.best_logdir = 'temp_log_dir'
mock_analysis.results = [dict(accuracy=50)]

with tempfile.TemporaryDirectory() as tmpdir:
log_analysis(mock_analysis, tune_config, task_config, tmpdir)
assert os.path.exists(os.path.join(tmpdir, 'tune_config.py'))
assert os.path.exists(os.path.join(tmpdir, 'task_config.py'))


def test_tune():

class TestTrainable(Trainable):

def step(self):
result = {'name': self.trial_name, 'trial_id': self.trial_id}
return result

tune_config = mmcv.Config(
dict(
scheduler=dict(
type='AsyncHyperBandScheduler',
time_attr='training_iteration',
max_t=3,
grace_period=1),
metric='accuracy',
mode='max',
num_samples=1))

mock_task_processor = MagicMock()
mock_task_processor.create_trainable.return_value = TestTrainable
with tempfile.TemporaryDirectory() as tmpdir:
mock_task_processor.args.work_dir = tmpdir
mock_task_processor.args.num_workers = 1
mock_task_processor.args.num_cpus_per_worker = 1
mock_task_processor.args.num_gpus_per_worker = 0
tune(mock_task_processor, tune_config, 'exp_name')
26 changes: 26 additions & 0 deletions tests/test_mm/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from unittest.mock import MagicMock

import mmcv
import pytest

from mmtune.mm.context import ContextManager, build_rewriter


def test_contextmanager():
base_cfg = dict()
args = MagicMock()

rewriters = [dict(type='Decouple'), build_rewriter(dict(type='Dump'))]
context_manager = ContextManager(base_cfg, args, rewriters)

rewriters = [dict(type='Decouple')]
context_manager = ContextManager(base_cfg, args, rewriters)
func = lambda **kargs: 1 # noqa
inner = context_manager(func)
config = mmcv.Config(dict())
context = dict(cfg=config)
inner(config, context)

with pytest.raises(TypeError):
rewriters = [dict(type='Decouple'), []]
context_manager = ContextManager(base_cfg, args, rewriters)
47 changes: 47 additions & 0 deletions tests/test_mm/test_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import os
from unittest.mock import MagicMock, patch

import torch
import torch.distributed as dist

from mmtune.mm.hooks import RayCheckpointHook, RayTuneLoggerHook


def test_raycheckpointhook():
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group('gloo', rank=0, world_size=1)

hook = RayCheckpointHook(
interval=1,
by_epoch=True,
out_dir='/tmp/ray_checkpoint',
mode='min',
metric_name='loss',
max_concurrent=1,
checkpoint_metric=True,
checkpoint_at_end=True,
)
mock_runner = MagicMock()
mock_runner.inner_iter = 3
mock_runner.iter = 5

cur_iter = hook.get_iter(mock_runner, False)
assert cur_iter == 6
cur_iter = hook.get_iter(mock_runner, True)
assert cur_iter == 4

mock_runner.model = torch.nn.Linear(2, 2)

hook._save_checkpoint(mock_runner)


@patch.object(RayTuneLoggerHook, 'get_loggable_tags')
def test_raytuneloggerhook(mock_get_loggable_tags):
mock_get_loggable_tags.return_value = {'train/Loss': 0.55, 'val/mAP': 0.6}

mock_runner = MagicMock()
mock_runner.iter = 5

loggerhook = RayTuneLoggerHook()
loggerhook.log(mock_runner)
77 changes: 77 additions & 0 deletions tests/test_mm/test_rewriters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from unittest.mock import MagicMock, patch

import mmcv

from mmtune.mm.context.rewriters import (BatchConfigPathcer, ConfigMerger,
CustomHookRegister, Decouple, Dump,
SequeunceConfigPathcer, SetEnv)


def test_decouple():
decouple = Decouple(keys=['searched_cfg', 'base_cfg'])

context = dict(
base_cfg=dict(model=dict(type='DummyModel')),
searched_cfg=dict(model=[
dict(type='DummyModel'),
dict(type='DummyModel2'),
]),
)
decouple(context)


def test_dump():
dump = Dump()
config = mmcv.Config(dict())
args = MagicMock()
args.config = config
context = dict(cfg=config, args=args)
dump(context)


@patch('ray.tune.get_trial_id')
def test_setenv(mock_get_trial_id):
mock_get_trial_id.return_value = 'sdfkj234'
setenv = SetEnv()

args = MagicMock()
args.work_dir = 'tmpdir'
context = dict(args=args)
setenv(context)


def test_merge():
merger = ConfigMerger()

context = dict(
base_cfg=mmcv.Config(dict(model=dict(type='DummyModel'))),
searched_cfg=mmcv.Config(
dict(model=[
dict(type='DummyModel6'),
dict(type='DummyModel2'),
])))
merger(context)


def test_patch():
context = dict(
base_cfg=mmcv.Config(dict(model=dict(type='DummyModel'))),
searched_cfg=mmcv.Config(
dict(model=[
dict(type='DummyModel6'),
dict(type='DummyModel2'),
])))
patcher = BatchConfigPathcer()
patcher(context)

patcher = SequeunceConfigPathcer()
patcher(context)


def test_register():
post_custom_hooks = ['a', 'b']
register = CustomHookRegister(post_custom_hooks)
cfg = MagicMock()
cfg.custom_hooks = []
context = dict(cfg=cfg)
context = register(context)
Loading

0 comments on commit ad52fe5

Please sign in to comment.