demo.py
import torch
from models import vgg19
import gdown
from PIL import Image
from torchvision import transforms
import gradio as gr
import cv2
import numpy as np
import scipy

# download the pretrained DM-Count weights (trained on UCF-QNRF) from Google Drive
model_path = "pretrained_models/model_qnrf.pth"
url = "https://drive.google.com/uc?id=1nnIHPaV9RGqK8JHL645zmRvkNrahD9ru"
gdown.download(url, model_path, quiet=False)

device = torch.device('cpu')  # set to 'cuda' to run on a GPU
model = vgg19()
model.to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

def predict(inp):
    # convert the uploaded array to a normalized tensor with a batch dimension
    inp = Image.fromarray(inp.astype('uint8'), 'RGB')
    inp = transforms.ToTensor()(inp).unsqueeze(0)
    inp = inp.to(device)
    with torch.set_grad_enabled(False):
        outputs, _ = model(inp)
    # the predicted count is the sum over the density map
    count = torch.sum(outputs).item()
    vis_img = outputs[0, 0].cpu().numpy()
    # normalize density map values to [0, 1], then map them to 0-255 for display
    vis_img = (vis_img - vis_img.min()) / (vis_img.max() - vis_img.min() + 1e-5)
    vis_img = (vis_img * 255).astype(np.uint8)
    vis_img = cv2.applyColorMap(vis_img, cv2.COLORMAP_JET)
    vis_img = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)
    return vis_img, int(count)

title = "Distribution Matching for Crowd Counting"
desc = "A demo of DM-Count, a NeurIPS 2020 paper by Wang et al. Outperforms the state-of-the-art methods by a " \
"large margin on four challenging crowd counting datasets: UCF-QNRF, NWPU, ShanghaiTech, and UCF-CC50. " \
"This demo uses the QNRF trained model. Try it by uploading an image or clicking on an example " \
"(could take up to 20s if running on CPU)."
examples = [
["example_images/3.png"],
["example_images/2.png"],
["example_images/1.png"],
]
inputs = gr.inputs.Image(label="Image of Crowd")
outputs = [gr.outputs.Image(label="Predicted Density Map"), gr.outputs.Label(label="Predicted Count")]
gr.Interface(fn=predict, inputs=inputs, outputs=outputs, title=title, description=desc, examples=examples,
             allow_flagging=False).launch()
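
# --- Usage sketch (added; not part of the original demo) --------------------
# predict() can also be called directly, without the Gradio UI, e.g. as a
# quick sanity check. This assumes one of the bundled example images listed
# in `examples` above is present on disk. Uncomment to run it (and drop the
# .launch() call above, which blocks until the server is stopped):
#
#   img = np.array(Image.open("example_images/1.png").convert("RGB"))
#   density_map, count = predict(img)  # RGB density-map image and integer count
#   print("Estimated crowd count:", count)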