Skip to content

Commit 29c82ea

Browse files
authored
[Fix] Fix loss parse in val_step (#906)
* [Fix] Fix loss parse in val_step * Add val_step unittest * Add train_step unittest
1 parent e171e80 commit 29c82ea

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

mmseg/models/segmentors/base.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,22 @@ def train_step(self, data_batch, optimizer, **kwargs):
145145

146146
return outputs
147147

148-
def val_step(self, data_batch, **kwargs):
148+
def val_step(self, data_batch, optimizer=None, **kwargs):
149149
"""The iteration step during validation.
150150
151151
This method shares the same signature as :func:`train_step`, but used
152152
during val epochs. Note that the evaluation after training epochs is
153153
not implemented with this method, but an evaluation hook.
154154
"""
155-
output = self(**data_batch, **kwargs)
156-
return output
155+
losses = self(**data_batch)
156+
loss, log_vars = self._parse_losses(losses)
157+
158+
outputs = dict(
159+
loss=loss,
160+
log_vars=log_vars,
161+
num_samples=len(data_batch['img_metas']))
162+
163+
return outputs
157164

158165
@staticmethod
159166
def _parse_losses(losses):

tests/test_models/test_segmentors/utils.py

+20
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,26 @@ def _segmentor_forward_train_test(segmentor):
101101
imgs, img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True)
102102
assert isinstance(losses, dict)
103103

104+
# Test train_step
105+
data_batch = dict(
106+
img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg)
107+
outputs = segmentor.train_step(data_batch, None)
108+
assert isinstance(outputs, dict)
109+
assert 'loss' in outputs
110+
assert 'log_vars' in outputs
111+
assert 'num_samples' in outputs
112+
113+
# Test val_step
114+
with torch.no_grad():
115+
segmentor.eval()
116+
data_batch = dict(
117+
img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg)
118+
outputs = segmentor.val_step(data_batch, None)
119+
assert isinstance(outputs, dict)
120+
assert 'loss' in outputs
121+
assert 'log_vars' in outputs
122+
assert 'num_samples' in outputs
123+
104124
# Test forward simple test
105125
with torch.no_grad():
106126
segmentor.eval()

0 commit comments

Comments
 (0)