forked from QiuJueqin/SqueezeDet-PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo.py
52 lines (41 loc) · 1.65 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import glob
import tqdm
import numpy as np
import skimage.io
import torch
import torch.utils.data
from datasets.kitti import KITTI
from engine.detector import Detector
from model.squeezedet import SqueezeDet
from utils.config import Config
from utils.model import load_model
def demo(cfg):
    """Run the SqueezeDet detector over the bundled KITTI sample images.

    The function overrides a few fields on ``cfg`` (checkpoint path, GPU
    list, debug level), builds the model and detector, then feeds every
    ``*.png`` found in the sample directory through ``detector.detect``.
    Detection results are not returned; with ``cfg.debug = 2`` the detector
    is expected to visualize the boxes itself.

    Args:
        cfg: Project configuration object. It is mutated in place and then
            replaced by the result of ``Config().update_dataset_info``.
            ``cfg.device`` must already be set by the caller; it is only
            read here.
    """
    # prepare configurations
    cfg.load_model = '../models/squeezedet_kitti_epoch280.pth'
    cfg.gpus = [-1]  # -1 to use CPU
    cfg.debug = 2  # to visualize detection boxes
    # NOTE(review): cfg.device is never assigned in this function, so the
    # gpus override above only takes effect if the device is derived from
    # cfg.gpus after this point -- confirm against Config.

    dataset = KITTI('val', cfg)
    cfg = Config().update_dataset_info(cfg, dataset)

    # preprocess image to match model's input resolution
    # (only the preprocess callable is needed from the dataset from here on)
    preprocess_func = dataset.preprocess
    del dataset

    # prepare model & detector
    model = SqueezeDet(cfg)
    model = load_model(model, cfg.load_model)
    detector = Detector(model.to(cfg.device), cfg)

    # prepare images
    sample_images_dir = '../data/samples/kitti/testing/image_2'
    # glob returns paths in arbitrary order; sort so that the demo visits
    # the images in the same sequence on every run and every platform.
    sample_image_paths = sorted(
        glob.glob(os.path.join(sample_images_dir, '*.png'))
    )

    # detection
    for path in tqdm.tqdm(sample_image_paths):
        image = skimage.io.imread(path).astype(np.float32)

        # Use splitext rather than slicing off a fixed number of characters
        # so the id stays correct regardless of extension length.
        image_id = os.path.splitext(os.path.basename(path))[0]
        image_meta = {
            'image_id': image_id,
            'orig_size': np.array(image.shape, dtype=np.int32),
        }

        image, image_meta, _ = preprocess_func(image, image_meta)

        # Move the last axis to the front and add a leading batch axis of
        # size 1 before handing the image to the detector.
        image = (
            torch.from_numpy(image.transpose(2, 0, 1))
            .unsqueeze(0)
            .to(cfg.device)
        )

        # Give the metadata the same leading batch axis: arrays become
        # tensors on the target device, everything else is wrapped in a
        # one-element list.
        image_meta = {
            k: (
                torch.from_numpy(v).unsqueeze(0).to(cfg.device)
                if isinstance(v, np.ndarray)
                else [v]
            )
            for k, v in image_meta.items()
        }

        inp = {
            'image': image,
            'image_meta': image_meta,
        }

        # The return value is discarded; the demo relies on the detector's
        # own debug visualization.
        _ = detector.detect(inp)