4
4
#include < iostream>
5
5
#include < stdlib.h>
6
6
7
-
8
7
#define CUDA torch::kCUDA
9
8
#define CPU torch::kCPU
10
9
@@ -51,6 +50,7 @@ void Backend::perform(std::vector<float *> in_buffer,
51
50
52
51
// PROCESS TENSOR
53
52
at::Tensor tensor_out;
53
+ std::unique_lock<std::mutex> model_lock (m_model_mutex);
54
54
try {
55
55
tensor_out = m_model.get_method (method)(inputs).toTensor ();
56
56
tensor_out = tensor_out.repeat_interleave (out_ratio).reshape (
@@ -59,6 +59,8 @@ void Backend::perform(std::vector<float *> in_buffer,
59
59
std::cerr << e.what () << ' \n ' ;
60
60
return ;
61
61
}
62
+ model_lock.unlock ();
63
+
62
64
int out_batches (tensor_out.size (0 )), out_channels (tensor_out.size (1 )),
63
65
out_n_vec (tensor_out.size (2 ));
64
66
@@ -94,6 +96,7 @@ int Backend::load(std::string path) {
94
96
}
95
97
m_model = model;
96
98
m_loaded = 1 ;
99
+ m_available_methods = get_available_methods ();
97
100
m_path = path;
98
101
return 0 ;
99
102
} catch (const std::exception &e) {
@@ -102,11 +105,10 @@ int Backend::load(std::string path) {
102
105
}
103
106
}
104
107
105
- int Backend::reload (){
106
- return load (m_path);
107
- }
108
+ int Backend::reload () { return load (m_path); }
108
109
109
110
bool Backend::has_method (std::string method_name) {
111
+ std::unique_lock<std::mutex> model_lock (m_model_mutex);
110
112
for (const auto &m : m_model.get_methods ()) {
111
113
if (m.name () == method_name)
112
114
return true ;
@@ -126,25 +128,32 @@ std::vector<std::string> Backend::get_available_methods() {
126
128
std::vector<std::string> methods;
127
129
try {
128
130
std::vector<c10::IValue> dumb_input = {};
131
+
132
+ std::unique_lock<std::mutex> model_lock (m_model_mutex);
129
133
auto methods_from_model =
130
134
m_model.get_method (" get_methods" )(dumb_input).toList ();
135
+ model_lock.unlock ();
136
+
131
137
for (int i = 0 ; i < methods_from_model.size (); i++) {
132
138
methods.push_back (methods_from_model.get (i).toStringRef ());
133
139
}
134
140
} catch (...) {
141
+ std::unique_lock<std::mutex> model_lock (m_model_mutex);
135
142
for (const auto &m : m_model.get_methods ()) {
136
143
try {
137
144
auto method_params = m_model.attr (m.name () + " _params" );
138
145
methods.push_back (m.name ());
139
146
} catch (...) {
140
147
}
141
148
}
149
+ model_lock.unlock ();
142
150
}
143
151
return methods;
144
152
}
145
153
146
154
std::vector<std::string> Backend::get_available_attributes () {
147
155
std::vector<std::string> attributes;
156
+ std::unique_lock<std::mutex> model_lock (m_model_mutex);
148
157
for (const auto &attribute : m_model.named_attributes ())
149
158
attributes.push_back (attribute.name );
150
159
return attributes;
@@ -154,44 +163,56 @@ std::vector<std::string> Backend::get_settable_attributes() {
154
163
std::vector<std::string> attributes;
155
164
try {
156
165
std::vector<c10::IValue> dumb_input = {};
166
+ std::unique_lock<std::mutex> model_lock (m_model_mutex);
157
167
auto methods_from_model =
158
168
m_model.get_method (" get_attributes" )(dumb_input).toList ();
169
+ model_lock.unlock ();
159
170
for (int i = 0 ; i < methods_from_model.size (); i++) {
160
171
attributes.push_back (methods_from_model.get (i).toStringRef ());
161
172
}
162
173
} catch (...) {
174
+ std::unique_lock<std::mutex> model_lock (m_model_mutex);
163
175
for (const auto &a : m_model.named_attributes ()) {
164
176
try {
165
177
auto method_params = m_model.attr (a.name + " _params" );
166
178
attributes.push_back (a.name );
167
179
} catch (...) {
168
180
}
169
181
}
182
+ model_lock.unlock ();
170
183
}
171
184
return attributes;
172
185
}
173
186
174
187
std::vector<c10::IValue> Backend::get_attribute (std::string attribute_name) {
175
188
std::string attribute_getter_name = " get_" + attribute_name;
176
189
try {
190
+ std::unique_lock<std::mutex> model_lock (m_model_mutex);
177
191
auto attribute_getter = m_model.get_method (attribute_getter_name);
192
+ model_lock.unlock ();
178
193
} catch (...) {
179
194
throw " getter for attribute " + attribute_name + " not found in model" ;
180
195
}
181
196
std::vector<c10::IValue> getter_inputs = {}, attributes;
182
197
try {
183
198
try {
199
+ std::unique_lock<std::mutex> model_lock (m_model_mutex);
184
200
attributes = m_model.get_method (attribute_getter_name)(getter_inputs)
185
201
.toList ()
186
202
.vec ();
203
+ model_lock.unlock ();
187
204
} catch (...) {
205
+ std::unique_lock<std::mutex> model_lock (m_model_mutex);
188
206
auto output_tuple =
189
207
m_model.get_method (attribute_getter_name)(getter_inputs).toTuple ();
190
208
attributes = (*output_tuple.get ()).elements ();
209
+ model_lock.unlock ();
191
210
}
192
211
} catch (...) {
212
+ std::unique_lock<std::mutex> model_lock (m_model_mutex);
193
213
attributes.push_back (
194
214
m_model.get_method (attribute_getter_name)(getter_inputs));
215
+ model_lock.unlock ();
195
216
}
196
217
return attributes;
197
218
}
@@ -201,7 +222,9 @@ std::string Backend::get_attribute_as_string(std::string attribute_name) {
201
222
// finstringd arguments
202
223
torch::Tensor setter_params;
203
224
try {
225
+ std::unique_lock<std::mutex> model_lock (m_model_mutex);
204
226
setter_params = m_model.attr (attribute_name + " _params" ).toTensor ();
227
+ model_lock.unlock ();
205
228
} catch (...) {
206
229
throw " parameters to set attribute " + attribute_name +
207
230
" not found in model" ;
@@ -248,14 +271,18 @@ void Backend::set_attribute(std::string attribute_name,
248
271
// find setter
249
272
std::string attribute_setter_name = " set_" + attribute_name;
250
273
try {
274
+ std::unique_lock<std::mutex> model_lock (m_model_mutex);
251
275
auto attribute_setter = m_model.get_method (attribute_setter_name);
276
+ model_lock.unlock ();
252
277
} catch (...) {
253
278
throw " setter for attribute " + attribute_name + " not found in model" ;
254
279
}
255
280
// find arguments
256
281
torch::Tensor setter_params;
257
282
try {
283
+ std::unique_lock<std::mutex> model_lock (m_model_mutex);
258
284
setter_params = m_model.attr (attribute_name + " _params" ).toTensor ();
285
+ model_lock.unlock ();
259
286
} catch (...) {
260
287
throw " parameters to set attribute " + attribute_name +
261
288
" not found in model" ;
@@ -288,7 +315,9 @@ void Backend::set_attribute(std::string attribute_name,
288
315
}
289
316
}
290
317
try {
318
+ std::unique_lock<std::mutex> model_lock (m_model_mutex);
291
319
auto setter_out = m_model.get_method (attribute_setter_name)(setter_inputs);
320
+ model_lock.unlock ();
292
321
int setter_result = setter_out.toInt ();
293
322
if (setter_result != 0 ) {
294
323
throw " setter returned -1" ;
@@ -299,12 +328,14 @@ void Backend::set_attribute(std::string attribute_name,
299
328
}
300
329
301
330
std::vector<int > Backend::get_method_params (std::string method) {
302
- auto am = get_available_methods ();
303
331
std::vector<int > params;
304
332
305
- if (std::find (am.begin (), am.end (), method) != am.end ()) {
333
+ if (std::find (m_available_methods.begin (), m_available_methods.end (),
334
+ method) != m_available_methods.end ()) {
306
335
try {
336
+ std::unique_lock<std::mutex> model_lock (m_model_mutex);
307
337
auto p = m_model.attr (method + " _params" ).toTensor ();
338
+ model_lock.unlock ();
308
339
for (int i (0 ); i < 4 ; i++)
309
340
params.push_back (p[i].item ().to <int >());
310
341
} catch (...) {
@@ -315,8 +346,7 @@ std::vector<int> Backend::get_method_params(std::string method) {
315
346
316
347
int Backend::get_higher_ratio () {
317
348
int higher_ratio = 1 ;
318
- auto model_methods = get_available_methods ();
319
- for (const auto &method : model_methods) {
349
+ for (const auto &method : m_available_methods) {
320
350
auto params = get_method_params (method);
321
351
if (!params.size ())
322
352
continue ; // METHOD NOT USABLE, SKIPPING
0 commit comments