From ce56e68d303da6a214298eb6e4dd1fc6ed42e7fc Mon Sep 17 00:00:00 2001 From: Ziyi Wu Date: Thu, 29 Apr 2021 16:01:34 +0800 Subject: [PATCH] [Enhance] Replace data_dict calling 'img' key to support MMDet3D (#514) * remove dict calling img key for compatibility * fix unit test * infer batch size using len(result) to be consistent with mmcv --- mmseg/apis/test.py | 2 +- mmseg/models/segmentors/base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mmseg/apis/test.py b/mmseg/apis/test.py index 2b9cc17033..1597df6aa3 100644 --- a/mmseg/apis/test.py +++ b/mmseg/apis/test.py @@ -97,7 +97,7 @@ def single_gpu_test(model, result = np2tmp(result) results.append(result) - batch_size = data['img'][0].size(0) + batch_size = len(result) for _ in range(batch_size): prog_bar.update() return results diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 58c31887f3..7b53757537 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -155,7 +155,7 @@ def train_step(self, data_batch, optimizer, **kwargs): outputs = dict( loss=loss, log_vars=log_vars, - num_samples=len(data_batch['img'].data)) + num_samples=len(data_batch['img_metas'])) return outputs