add unittest
zhouzaida committed Aug 11, 2021
1 parent 2ee69e3 commit 2f934a8
Showing 1 changed file with 42 additions and 9 deletions.
51 changes: 42 additions & 9 deletions tests/test_runner/test_eval_hook.py
@@ -1,3 +1,4 @@
+import json
 import os.path as osp
 import tempfile
 import unittest.mock as mock
@@ -7,13 +8,14 @@
 import pytest
 import torch
 import torch.nn as nn
+import torch.optim as optim
 from torch.utils.data import DataLoader, Dataset
 
 from mmcv.runner import DistEvalHook as BaseDistEvalHook
 from mmcv.runner import EpochBasedRunner
 from mmcv.runner import EvalHook as BaseEvalHook
 from mmcv.runner import IterBasedRunner
-from mmcv.utils import get_logger
+from mmcv.utils import get_logger, scandir
 
 
 class ExampleDataset(Dataset):
@@ -48,18 +50,16 @@ class Model(nn.Module):
 
     def __init__(self):
         super().__init__()
-        self.linear = nn.Linear(2, 1)
+        self.param = nn.Parameter(torch.tensor([1.0]))
 
     def forward(self, x, **kwargs):
-        return x
+        return self.param * x
 
     def train_step(self, data_batch, optimizer, **kwargs):
-        if not isinstance(data_batch, dict):
-            data_batch = dict(x=data_batch)
-        return data_batch
+        return {'loss': torch.sum(self(data_batch['x']))}
 
-    def val_step(self, x, optimizer, **kwargs):
-        return dict(loss=self(x))
+    def val_step(self, data_batch, optimizer, **kwargs):
+        return {'loss': torch.sum(self(data_batch['x']))}
 
 
 def _build_epoch_runner():
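
For context (not part of the commit): EpochBasedRunner drives training through the model's train_step and expects the returned dict to carry a 'loss' tensor it can call backward() on, which is why the dummy model now owns a real nn.Parameter and computes a real loss instead of echoing its input. A minimal sketch of that contract, assuming a batch dict shaped like {'x': tensor}, as the test datasets yield:

    model = Model()
    batch = {'x': torch.ones(1, 2)}
    out = model.train_step(batch, optimizer=None)
    assert out['loss'].dim() == 0  # backward() needs a scalar loss
    out['loss'].backward()         # gradients now flow into model.param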
@@ -307,7 +307,7 @@ def test_eval_hook():
                          (_build_iter_runner, False)])
 def test_start_param(EvalHookParam, _build_demo_runner, by_epoch):
     # create dummy data
-    dataloader = DataLoader(torch.ones((5, 2)))
+    dataloader = DataLoader(EvalDataset())
 
     # 0.1. dataloader is not a DataLoader object
     with pytest.raises(TypeError):
@@ -389,3 +389,36 @@ def test_start_param(EvalHookParam, _build_demo_runner, by_epoch):
     runner._iter = 1
     runner.run([dataloader], [('train', 1)], 3)
     assert evalhook.evaluate.call_count == 2  # after epoch 2 & 3
+
+
+@pytest.mark.parametrize('runner,by_epoch,eval_hook_priority',
+                         [(EpochBasedRunner, True, 'NORMAL'),
+                          (EpochBasedRunner, True, 'LOW'),
+                          (IterBasedRunner, False, 'LOW')])
+def test_logger(runner, by_epoch, eval_hook_priority):
+    loader = DataLoader(EvalDataset())
+    model = Model()
+    data_loader = DataLoader(EvalDataset())
+    eval_hook = EvalHook(
+        data_loader, interval=1, by_epoch=by_epoch, save_best='acc')
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        logger = get_logger('test_logger')
+        optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
+        runner = EpochBasedRunner(
+            model=model, optimizer=optimizer, work_dir=tmpdir, logger=logger)
+        runner.register_logger_hooks(
+            dict(
+                interval=1,
+                hooks=[dict(type='TextLoggerHook', by_epoch=by_epoch)]))
+        runner.register_timer_hook(dict(type='IterTimerHook'))
+        runner.register_hook(eval_hook, priority=eval_hook_priority)
+        runner.run([loader], [('train', 1)], 1)
+
+        path = osp.join(tmpdir, next(scandir(tmpdir, '.json')))
+        with open(path) as fr:
+            fr.readline()  # skip first line which is hook_msg
+            train_log = json.loads(fr.readline())
+            assert train_log['mode'] == 'train' and 'time' in train_log
+            val_log = json.loads(fr.readline())
+            assert val_log['mode'] == 'val' and 'time' not in val_log
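
A note on the new test's log parsing (not part of the commit): TextLoggerHook writes one JSON record per line to a .json log under work_dir, with a metadata line first (the hook_msg the test skips), and mmcv.utils.scandir yields a directory's file names filtered by suffix, so next(scandir(tmpdir, '.json')) picks up the single log the run just produced. One way to run just this test locally, assuming pytest is installed and the command is issued from the repository root:

    pytest tests/test_runner/test_eval_hook.py -k test_logger -q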
