Skip to content

Commit

Permalink
support batchsize > 1
Browse files Browse the repository at this point in the history
  • Loading branch information
xiliu8006 committed Apr 5, 2021
1 parent 74e4a84 commit 765de4b
Showing 1 changed file with 35 additions and 5 deletions.
40 changes: 35 additions & 5 deletions mmdet3d/apis/test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import mmcv
import os
import torch
from mmcv.image import tensor2imgs


def single_gpu_test(model, data_loader, show=False, out_dir=None):
def single_gpu_test(model,
data_loader,
show=False,
out_dir=None,
show_score_thr=0.3):
"""Test model with single gpu.
This method tests model with single gpu and gives the 'show' option.
Expand Down Expand Up @@ -33,10 +38,35 @@ def single_gpu_test(model, data_loader, show=False, out_dir=None):
if hasattr(model.module, 'show_results'):
model.module.show_results(data, result, out_dir)
else:
img_file = data['img_metas'][0].data[0][0]['filename']
outfile = os.path.basename(img_file)
outfile = os.path.join(out_dir, outfile)
model.module.show_result(img_file, result[0], out_file=outfile)
batch_size = len(result)
if batch_size == 1 and isinstance(data['img'][0],
torch.Tensor):
img_tensor = data['img'][0]
else:
img_tensor = data['img'][0].data[0]
img_metas = data['img_metas'][0].data[0]
imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
assert len(imgs) == len(img_metas)

for i, (img, img_meta) in enumerate(zip(imgs, img_metas)):
h, w, _ = img_meta['img_shape']
img_show = img[:h, :w, :]

ori_h, ori_w = img_meta['ori_shape'][:-1]
img_show = mmcv.imresize(img_show, (ori_w, ori_h))

if out_dir:
out_file = os.path.join(out_dir,
img_meta['ori_filename'])
else:
out_file = None

model.module.show_result(
img_show,
result[i],
show=show,
out_file=out_file,
score_thr=show_score_thr)
results.extend(result)

batch_size = len(result)
Expand Down

0 comments on commit 765de4b

Please sign in to comment.