-
Notifications
You must be signed in to change notification settings - Fork 704
/
rvm.h
151 lines (117 loc) · 5.05 KB
/
rvm.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
//
// Created by DefTruth on 2021/9/20.
//
#ifndef LITE_AI_TOOLKIT_ORT_CV_RVM_H
#define LITE_AI_TOOLKIT_ORT_CV_RVM_H
#include "lite/ort/core/ort_core.h"
namespace ortcv
{
/**
 * ONNX Runtime based wrapper around the Robust Video Matting (RVM) model
 * (https://github.com/PeterL1n/RobustVideoMatting).
 *
 * The model takes six inputs (src, r1i..r4i, downsample_ratio) and yields six
 * outputs (fgr, pha, r1o..r4o). Instances are neither copyable nor movable.
 */
class LITE_EXPORTS RobustVideoMatting
{
private:
  // ---- ONNX Runtime state --------------------------------------------------
  Ort::Env ort_env;
  // Raw session pointer, null until assigned.
  // NOTE(review): presumably created in the constructor and released in the
  // destructor - confirm against the implementation file.
  Ort::Session *ort_session{nullptr};
  Ort::AllocatorWithDefaultOptions allocator;
  // MemoryInfo for CPU memory (arena allocator, default memory type).
  Ort::MemoryInfo memory_info_handler =
      Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);

  // ---- model inputs (names are hardcoded) ----------------------------------
  unsigned int num_inputs{6};
  std::vector<const char *> input_node_names{
      "src", "r1i", "r2i", "r3i", "r4i", "downsample_ratio"};
  // Initial dims for the dynamic inputs. The rXi entries start as (1,1,1,1);
  // inside the inner loop they become (1, 16, ?h, ?w).
  std::vector<std::vector<int64_t>> dynamic_input_node_dims{
      {1, 3, 1280, 720}, // src (b=1,c,h,w)
      {1, 1, 1, 1},      // r1i
      {1, 1, 1, 1},      // r2i
      {1, 1, 1, 1},      // r3i
      {1, 1, 1, 1},      // r4i
      {1}                // downsample_ratio (dsr)
  };

  // ---- model outputs (names are hardcoded) ---------------------------------
  unsigned int num_outputs{6};
  std::vector<const char *> output_node_names{
      "fgr", "pha", "r1o", "r2o", "r3o", "r4o"};

  const LITEORT_CHAR *onnx_path{nullptr};
  const char *log_id{nullptr};
  bool context_is_update{false};

  // ---- backing storage for the input tensor values -------------------------
  std::vector<float> dynamic_src_value_handler;
  // Each rXi buffer starts as a single 0.0f, matching the (1,1,1,1) shape.
  std::vector<float> dynamic_r1i_value_handler{0.0f};
  std::vector<float> dynamic_r2i_value_handler{0.0f};
  std::vector<float> dynamic_r3i_value_handler{0.0f};
  std::vector<float> dynamic_r4i_value_handler{0.0f};
  // downsample_ratio buffer with shape (1), starting at 0.25.
  std::vector<float> dynamic_dsr_value_handler{0.25f};

protected:
  // Thread count; const, so it must be set by the constructor at runtime.
  const unsigned int num_threads;

public:
  /**
   * @param _onnx_path: path of the RVM ONNX model file.
   * @param _num_threads: thread count stored in num_threads, 1 by default.
   */
  explicit RobustVideoMatting(const std::string &_onnx_path, unsigned int _num_threads = 1);

  ~RobustVideoMatting();

protected:
  // Copying and moving are both disabled.
  RobustVideoMatting(const RobustVideoMatting &) = delete;
  RobustVideoMatting(RobustVideoMatting &&) = delete;
  RobustVideoMatting &operator=(const RobustVideoMatting &) = delete;
  RobustVideoMatting &operator=(RobustVideoMatting &&) = delete;

private:
  // Builds the input tensors (normalized src, rXi, dsr) from the given image.
  std::vector<Ort::Value> transform(const cv::Mat &mat);

  // Returns the value size for the given dims.
  int64_t value_size_of(const std::vector<int64_t> &dims);

  // Fills `content` from the model output tensors; the two flags have the same
  // meaning as the identically named parameters of detect().
  void generate_matting(std::vector<Ort::Value> &output_tensors,
                        types::MattingContent &content,
                        bool remove_noise = false,
                        bool minimum_post_process = false);

  // Updates the stored context from the model output tensors.
  // NOTE(review): presumably feeds r1o..r4o back into the rXi inputs - confirm.
  void update_context(std::vector<Ort::Value> &output_tensors);

public:
  /**
   * Image matting with RVM (https://github.com/PeterL1n/RobustVideoMatting).
   * @param mat: input image, cv::Mat in BGR / HWC layout.
   * @param content: types::MattingContent that receives the detected results.
   * @param downsample_ratio: 0.25 by default.
   * @param video_mode: false by default.
   *        See https://github.com/PeterL1n/RobustVideoMatting/blob/master/documentation/inference_zh_Hans.md
   * @param remove_noise: whether to remove small connected areas.
   * @param minimum_post_process: if true, run matting with minimal post
   *        processing in order to speed it up.
   */
  void detect(const cv::Mat &mat, types::MattingContent &content,
              float downsample_ratio = 0.25f, bool video_mode = false,
              bool remove_noise = false, bool minimum_post_process = false);

  /**
   * Video matting with RVM (https://github.com/PeterL1n/RobustVideoMatting).
   * @param video_path: input video, e.g. xxx/xxx/input.mp4
   * @param output_path: output video, e.g. xxx/xxx/output.mp4
   * @param contents: vector of MattingContent that receives the detected results.
   * @param save_contents: whether to save MattingContent, false by default.
   * @param downsample_ratio: 0.25 by default.
   *        See https://github.com/PeterL1n/RobustVideoMatting/blob/master/documentation/inference_zh_Hans.md
   * @param writer_fps: FPS for the VideoWriter, 20 by default.
   * @param remove_noise: whether to remove small connected areas.
   * @param minimum_post_process: if true, run matting with minimal post
   *        processing in order to speed it up.
   * @param background: custom background; when this Mat is not empty the
   *        result is returned on this background instead of a green one.
   */
  void detect_video(const std::string &video_path,
                    const std::string &output_path,
                    std::vector<types::MattingContent> &contents,
                    bool save_contents = false,
                    float downsample_ratio = 0.25f,
                    unsigned int writer_fps = 20,
                    bool remove_noise = false,
                    bool minimum_post_process = false,
                    const cv::Mat &background = cv::Mat());
};
}
#endif //LITE_AI_TOOLKIT_ORT_CV_RVM_H