-
Notifications
You must be signed in to change notification settings - Fork 1
/
convert_to_onnx.py
28 lines (23 loc) · 840 Bytes
/
convert_to_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os
import sys

import torch.onnx

from vision.ssd.config.fd_config import define_img_size
# Input width handed to the project's fd_config. define_img_size() is called
# before the model-factory imports below; that ordering looks deliberate
# (presumably the vision.ssd modules read the configured size at import
# time), so keep these imports after the call rather than moving them up.
input_img_size = 320
define_img_size(input_img_size)
from vision.ssd.mb_tiny_RFB_fd import create_Mb_Tiny_RFB_fd
from vision.ssd.mb_tiny_fd import create_mb_tiny_fd

# Backbone to export: "slim" or "RFB".
net_type = "RFB"

# Class count passed to both model factories.
# NOTE(review): presumably background + face — confirm against the label file.
num_classes = 2

# NOTE(review): both checkpoints are the 320 variants; update these paths if
# input_img_size is changed.
if net_type == 'slim':
    model_path = "models/pretrained/version-slim-320.pth"
    net = create_mb_tiny_fd(num_classes, is_test=True)
elif net_type == 'RFB':
    model_path = "models/pretrained/version-RFB-320.pth"
    net = create_Mb_Tiny_RFB_fd(num_classes, is_test=True)
else:
    # sys.exit() with a string prints it to stderr and exits with status 1.
    sys.exit(f"unsupported network type: {net_type!r}")

net.load(model_path)
net.eval()
net.to("cpu")

# Name the ONNX file after the selected backbone so a "slim" export is not
# written under an "RFB" name ("RFB" still yields models/onnx/RFB-test.onnx).
onnx_path = os.path.join("models", "onnx", f"{net_type}-test.onnx")
# torch.onnx.export does not create missing parent directories.
os.makedirs(os.path.dirname(onnx_path), exist_ok=True)

# Tracing input: one 3-channel image in PyTorch's (N, C, H, W) layout, kept on
# the CPU to match the model. 320 -> 240 reproduces the original
# (1, 3, 240, 320) input; for other sizes this assumes fd_config uses a 4:3
# (W:H) aspect — TODO confirm against define_img_size().
input_height = input_img_size * 3 // 4
dummy_input = torch.randn(1, 3, input_height, input_img_size)

torch.onnx.export(
    net,
    dummy_input,
    onnx_path,
    verbose=False,
    input_names=['input'],
    output_names=['scores', 'boxes'],
)