load_graph.cpp
#define COMPILER_MSVC
#define NOMINMAX
#include <cstring>
#include <iostream>
#include <string>
#include <vector>
#include <fstream>
#include <tensorflow/cc/client/client_session.h>
#include <tensorflow/core/public/session.h>
#include <tensorflow/cc/ops/standard_ops.h>
#include <tensorflow/core/platform/env.h>
#include <tensorflow/cc/ops/image_ops.h>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>

using namespace std;
using namespace tensorflow;
using namespace tensorflow::ops;
// Read raw int16 images from disk and preprocess them into float tensors:
// clip to the window [-70, 180], shift to [0, 250], then scale to [0, 1].
void ReadImage(const vector<string> &file_names, const int height, const int width, vector<Tensor> &out_tensors)
{
    vector<Tensor> temp;
    for (auto &file_name : file_names)
    {
        // Each file holds height * width little-endian 16-bit pixels with no header.
        ifstream file_in(file_name, ios::binary);
        Tensor image_tensor(DT_INT16, TensorShape({ 1, height, width, 1 }));
        short *image_data = image_tensor.flat<short>().data();
        file_in.read((char *)image_data, height * width * sizeof(short));
        file_in.close();
        temp.push_back(image_tensor);
    }

    // Construct a small graph that adjusts window level and width.
    Scope root = Scope::NewRootScope();
    auto input_ = Placeholder(root, DT_INT16, Placeholder::Shape({ 1, height, width, 1 }));
    auto cliped = ClipByValue(root, input_, (short)-70, (short)180);
    auto shifted = Sub(root, cliped, (short)-70);
    auto casted = Cast(root, shifted, DT_FLOAT);
    auto scaled = Div(root, casted, 250.0f);

    GraphDef graphdef;
    TF_CHECK_OK(root.ToGraphDef(&graphdef));

    // Run every image through the preprocessing graph.
    ClientSession session(root);
    vector<Tensor> outputs;
    for (auto &img : temp)
    {
        TF_CHECK_OK(session.Run({ { input_, img } }, { scaled }, &outputs));
        out_tensors.push_back(outputs[0]);
        outputs.clear();
    }
}
// Convert int32 mask tensors to 8-bit images and write them to disk with OpenCV.
void SaveImage(const vector<string> &names, const vector<Tensor> &masks)
{
    for (size_t i(0); i < names.size(); ++i)
    {
        const int *p = masks[i].flat<int>().data();
        auto shape = masks[i].shape().dim_sizes();
        int *arr = new int[shape[1] * shape[2]];
        memcpy(arr, p, shape[1] * shape[2] * sizeof(int));
        cv::Mat image_mat((int)shape[1], (int)shape[2], CV_32S, arr);
        // Map the binary mask {0, 1} to {0, 255} so it is visible as an 8-bit image.
        image_mat.convertTo(image_mat, CV_8U, 255, 0);
        bool flag = cv::imwrite(names[i], image_mat);
        cout << names[i] << ": " << flag << endl;
        delete[] arr;
    }
}
int main(int argc, char **argv)
{
    if (argc < 3)
    {
        cout << "Usage: " << argv[0] << " <model.pb> <image1> [<image2> ...]" << endl;
        return -1;
    }

    // Load the frozen model from disk.
    string model_path(argv[1]);
    GraphDef graphdef;
    Status load_status = ReadBinaryProto(Env::Default(), model_path, &graphdef);
    if (!load_status.ok())
    {
        cout << "ERROR: Loading model failed! " << model_path << endl;
        cout << load_status.ToString() << endl;
        return -1;
    }
    cout << "INFO: Model loaded." << endl;

    // Import the graph into a new session.
    SessionOptions options;
    unique_ptr<Session> session(NewSession(options));
    Status create_status = session->Create(graphdef);
    if (!create_status.ok())
    {
        cout << "ERROR: Creating graph in session failed!" << endl;
        cout << create_status.ToString() << endl;
        return -1;
    }
    cout << "INFO: Session successfully created." << endl;

    // Read and preprocess the input images (raw 512x512 int16 files).
    vector<string> image_paths;
    for (int i(2); i < argc; ++i)
        image_paths.push_back(string(argv[i]));
    vector<Tensor> image_tensors;
    ReadImage(image_paths, 512, 512, image_tensors);

    // Run inference one image at a time and collect the binary masks.
    vector<Tensor> outputs;
    vector<Tensor> masks;
    for (auto &image_tensor : image_tensors)
    {
        TF_CHECK_OK(session->Run({ { "Image", image_tensor } }, { "FCN-DenseNet/BinaryPred/Pred2Binary" }, {}, &outputs));
        masks.push_back(outputs[0]);
        outputs.clear();
    }
    cout << "INFO: Run inference finished." << endl;

    // Save each mask next to its input, replacing the 3-character extension with "jpg".
    vector<string> out_paths;
    for (auto &path : image_paths)
        out_paths.push_back(string(path.begin(), path.end() - 3) + "jpg");
    SaveImage(out_paths, masks);

    return 0;
}
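
// Usage sketch, assuming the binary is built as "load_graph" (the model and image
// file names below are hypothetical placeholders, not from the original repository):
//
//   load_graph frozen_fcn_densenet.pb slice_001.raw slice_002.raw
//
// Each input file is expected to contain 512 * 512 raw little-endian int16 pixels;
// the predicted binary mask is written next to each input as slice_001.jpg, etc.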