#!/usr/bin/env python3
# create_d2go.py (forked from fbelderink/flutter_pytorch_mobile)
# Exports a quantized D2Go object-detection model to TorchScript for the mobile app.
import contextlib
import copy
import os
from typing import Dict

import torch
from d2go.export.api import convert_and_export_predictor
from d2go.export.d2_meta_arch import patch_d2_meta_arch
from d2go.runner import create_runner
from d2go.model_zoo import model_zoo
from mobile_cv.common.misc.file_utils import make_temp_directory
from d2go.utils.testing.data_loader_helper import LocalImageGenerator, _register_toy_dataset

# Patch Detectron2 meta-architectures so D2Go can export them.
patch_d2_meta_arch()

@contextlib.contextmanager
def create_fake_detection_data_loader(height, width, is_train):
    """Yield a data loader over a small toy dataset of generated images.

    The loader only feeds example inputs through the model during tracing and
    int8 calibration, so synthetic images are sufficient.
    """
    runner = create_runner("d2go.runner.GeneralizedRCNNRunner")
    cfg = runner.get_default_cfg()
    cfg.DATASETS.TRAIN = ["default_dataset_train"]
    cfg.DATASETS.TEST = ["default_dataset_test"]

    with make_temp_directory("detectron2go_tmp_dataset") as dataset_dir:
        image_dir = os.path.join(dataset_dir, "images")
        os.makedirs(image_dir)
        image_generator = LocalImageGenerator(image_dir, width=width, height=height)

        if is_train:
            with _register_toy_dataset(
                "default_dataset_train", image_generator, num_images=3
            ):
                yield runner.build_detection_train_loader(cfg)
        else:
            with _register_toy_dataset(
                "default_dataset_test", image_generator, num_images=3
            ):
                yield runner.build_detection_test_loader(
                    cfg, dataset_name="default_dataset_test"
                )
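
# Optional smoke test for the helper above (a sketch added here, not part of
# the original script): flip the guard to True to confirm locally that the
# fake test loader yields a batch. Assumes a working d2go installation.
if False:
    with create_fake_detection_data_loader(320, 640, is_train=False) as loader:
        first_batch = next(iter(loader))
        print("fake test loader yielded a batch of type", type(first_batch))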

def test_export_torchvision_format():
    cfg_name = 'faster_rcnn_fbnetv3a_dsmask_C4.yaml'
    pytorch_model = model_zoo.get(cfg_name, trained=True)

    class Wrapper(torch.nn.Module):
        """Adapt the exported D2Go model to the torchvision-style output
        format (a dict with "boxes", "labels" and "scores") expected by the
        app."""

        def __init__(self, model):
            super().__init__()
            self.model = model
            # Map the model's contiguous class indices (0-79) back to the
            # original, non-contiguous COCO category ids.
            coco_idx_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
                             27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51,
                             52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77,
                             78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91]
            self.coco_idx = torch.tensor(coco_idx_list)

        def forward(self, inputs: torch.Tensor):
            # Scale [0, 1] floats to [0, 255] pixel values, then resize so the
            # shorter side is 320 px, matching the export resolution.
            x = inputs * 255
            scale = 320.0 / min(x.shape[-2], x.shape[-1])
            x = torch.nn.functional.interpolate(
                x,
                scale_factor=scale,
                mode="bilinear",
                align_corners=True,
                recompute_scale_factor=True,
            )
            out = self.model(x[0])
            res: Dict[str, torch.Tensor] = {}
            res["boxes"] = out[0] / scale  # map boxes back to input coordinates
            res["labels"] = torch.index_select(self.coco_idx, 0, out[1])
            res["scores"] = out[2]
            return res

    # The backbone requires input dimensions divisible by size_divisibility;
    # derive a small valid height/width for the fake calibration images.
    size_divisibility = max(pytorch_model.backbone.size_divisibility, 10)
    h, w = size_divisibility, size_divisibility * 2

    with create_fake_detection_data_loader(h, w, is_train=False) as data_loader:
        predictor_path = convert_and_export_predictor(
            model_zoo.get_config(cfg_name),
            copy.deepcopy(pytorch_model),
            "torchscript_int8@tracing",
            './',
            data_loader,
        )

    orig_model = torch.jit.load(os.path.join(predictor_path, "model.jit"))
    wrapped_model = Wrapper(orig_model)
    # Optionally run a forward pass to verify the wrapper before scripting.
    wrapped_model(torch.rand(1, 3, 600, 600))
    scripted_model = torch.jit.script(wrapped_model)
    scripted_model.save("ObjectDetection/app/src/main/assets/d2go.pt")


if __name__ == '__main__':
    test_export_torchvision_format()
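    # Optional follow-up check (a sketch, not part of the original script):
    # reload the saved TorchScript file and run it on a random input to confirm
    # the artifact is loadable and returns the expected keys. Assumes the
    # export above succeeded and the assets path is readable.
    reloaded = torch.jit.load("ObjectDetection/app/src/main/assets/d2go.pt")
    out = reloaded(torch.rand(1, 3, 600, 600))
    print("d2go.pt outputs:", {k: tuple(v.shape) for k, v in out.items()})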