From eb87ace6bbe3f79595e4705636bfe4c4e2a510fe Mon Sep 17 00:00:00 2001 From: zhangyajie Date: Sat, 25 Sep 2021 10:34:15 +0800 Subject: [PATCH 1/3] [Fix] Fix loss parse in val_step --- mmseg/models/segmentors/base.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 906c6fe564..944da0f2e4 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -145,15 +145,22 @@ def train_step(self, data_batch, optimizer, **kwargs): return outputs - def val_step(self, data_batch, **kwargs): + def val_step(self, data_batch, optimizer=None, **kwargs): """The iteration step during validation. This method shares the same signature as :func:`train_step`, but used during val epochs. Note that the evaluation after training epochs is not implemented with this method, but an evaluation hook. """ - output = self(**data_batch, **kwargs) - return output + losses = self(**data_batch) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, + log_vars=log_vars, + num_samples=len(data_batch['img_metas'])) + + return outputs @staticmethod def _parse_losses(losses): From 0ecffa151cc608f5e75a5f2e3fcd32c725a0fb7c Mon Sep 17 00:00:00 2001 From: zhangyajie Date: Sat, 25 Sep 2021 13:33:27 +0800 Subject: [PATCH 2/3] Add val_step unittest --- tests/test_models/test_segmentors/utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_models/test_segmentors/utils.py b/tests/test_models/test_segmentors/utils.py index 0f51a4b1f5..1a4c825f23 100644 --- a/tests/test_models/test_segmentors/utils.py +++ b/tests/test_models/test_segmentors/utils.py @@ -101,6 +101,14 @@ def _segmentor_forward_train_test(segmentor): imgs, img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True) assert isinstance(losses, dict) + # Test val_step + with torch.no_grad(): + segmentor.eval() + data_batch = dict( + img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg) + outputs = segmentor.val_step(data_batch) + assert isinstance(outputs, dict) + # Test forward simple test with torch.no_grad(): segmentor.eval() From d496e1de673a60059f4e065e6aefbfd7156992de Mon Sep 17 00:00:00 2001 From: zhangyajie Date: Sat, 25 Sep 2021 14:07:52 +0800 Subject: [PATCH 3/3] Add train_step unittest --- tests/test_models/test_segmentors/utils.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/test_models/test_segmentors/utils.py b/tests/test_models/test_segmentors/utils.py index 1a4c825f23..1826dbf859 100644 --- a/tests/test_models/test_segmentors/utils.py +++ b/tests/test_models/test_segmentors/utils.py @@ -101,13 +101,25 @@ def _segmentor_forward_train_test(segmentor): imgs, img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True) assert isinstance(losses, dict) + # Test train_step + data_batch = dict( + img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg) + outputs = segmentor.train_step(data_batch, None) + assert isinstance(outputs, dict) + assert 'loss' in outputs + assert 'log_vars' in outputs + assert 'num_samples' in outputs + # Test val_step with torch.no_grad(): segmentor.eval() data_batch = dict( img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg) - outputs = segmentor.val_step(data_batch) + outputs = segmentor.val_step(data_batch, None) assert isinstance(outputs, dict) + assert 'loss' in outputs + assert 'log_vars' in outputs + assert 'num_samples' in outputs # Test forward simple test with torch.no_grad():