Skip to content

Commit 816529c

Browse files
backend add locking mechanism
1 parent c9af769 commit 816529c

File tree

4 files changed

+42
-10
lines changed

4 files changed

+42
-10
lines changed

src/backend/backend.cpp

+38-8
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include <iostream>
55
#include <stdlib.h>
66

7-
87
#define CUDA torch::kCUDA
98
#define CPU torch::kCPU
109

@@ -51,6 +50,7 @@ void Backend::perform(std::vector<float *> in_buffer,
5150

5251
// PROCESS TENSOR
5352
at::Tensor tensor_out;
53+
std::unique_lock<std::mutex> model_lock(m_model_mutex);
5454
try {
5555
tensor_out = m_model.get_method(method)(inputs).toTensor();
5656
tensor_out = tensor_out.repeat_interleave(out_ratio).reshape(
@@ -59,6 +59,8 @@ void Backend::perform(std::vector<float *> in_buffer,
5959
std::cerr << e.what() << '\n';
6060
return;
6161
}
62+
model_lock.unlock();
63+
6264
int out_batches(tensor_out.size(0)), out_channels(tensor_out.size(1)),
6365
out_n_vec(tensor_out.size(2));
6466

@@ -94,6 +96,7 @@ int Backend::load(std::string path) {
9496
}
9597
m_model = model;
9698
m_loaded = 1;
99+
m_available_methods = get_available_methods();
97100
m_path = path;
98101
return 0;
99102
} catch (const std::exception &e) {
@@ -102,11 +105,10 @@ int Backend::load(std::string path) {
102105
}
103106
}
104107

105-
int Backend::reload(){
106-
return load(m_path);
107-
}
108+
int Backend::reload() { return load(m_path); }
108109

109110
bool Backend::has_method(std::string method_name) {
111+
std::unique_lock<std::mutex> model_lock(m_model_mutex);
110112
for (const auto &m : m_model.get_methods()) {
111113
if (m.name() == method_name)
112114
return true;
@@ -126,25 +128,32 @@ std::vector<std::string> Backend::get_available_methods() {
126128
std::vector<std::string> methods;
127129
try {
128130
std::vector<c10::IValue> dumb_input = {};
131+
132+
std::unique_lock<std::mutex> model_lock(m_model_mutex);
129133
auto methods_from_model =
130134
m_model.get_method("get_methods")(dumb_input).toList();
135+
model_lock.unlock();
136+
131137
for (int i = 0; i < methods_from_model.size(); i++) {
132138
methods.push_back(methods_from_model.get(i).toStringRef());
133139
}
134140
} catch (...) {
141+
std::unique_lock<std::mutex> model_lock(m_model_mutex);
135142
for (const auto &m : m_model.get_methods()) {
136143
try {
137144
auto method_params = m_model.attr(m.name() + "_params");
138145
methods.push_back(m.name());
139146
} catch (...) {
140147
}
141148
}
149+
model_lock.unlock();
142150
}
143151
return methods;
144152
}
145153

146154
std::vector<std::string> Backend::get_available_attributes() {
147155
std::vector<std::string> attributes;
156+
std::unique_lock<std::mutex> model_lock(m_model_mutex);
148157
for (const auto &attribute : m_model.named_attributes())
149158
attributes.push_back(attribute.name);
150159
return attributes;
@@ -154,44 +163,56 @@ std::vector<std::string> Backend::get_settable_attributes() {
154163
std::vector<std::string> attributes;
155164
try {
156165
std::vector<c10::IValue> dumb_input = {};
166+
std::unique_lock<std::mutex> model_lock(m_model_mutex);
157167
auto methods_from_model =
158168
m_model.get_method("get_attributes")(dumb_input).toList();
169+
model_lock.unlock();
159170
for (int i = 0; i < methods_from_model.size(); i++) {
160171
attributes.push_back(methods_from_model.get(i).toStringRef());
161172
}
162173
} catch (...) {
174+
std::unique_lock<std::mutex> model_lock(m_model_mutex);
163175
for (const auto &a : m_model.named_attributes()) {
164176
try {
165177
auto method_params = m_model.attr(a.name + "_params");
166178
attributes.push_back(a.name);
167179
} catch (...) {
168180
}
169181
}
182+
model_lock.unlock();
170183
}
171184
return attributes;
172185
}
173186

174187
std::vector<c10::IValue> Backend::get_attribute(std::string attribute_name) {
175188
std::string attribute_getter_name = "get_" + attribute_name;
176189
try {
190+
std::unique_lock<std::mutex> model_lock(m_model_mutex);
177191
auto attribute_getter = m_model.get_method(attribute_getter_name);
192+
model_lock.unlock();
178193
} catch (...) {
179194
throw "getter for attribute " + attribute_name + " not found in model";
180195
}
181196
std::vector<c10::IValue> getter_inputs = {}, attributes;
182197
try {
183198
try {
199+
std::unique_lock<std::mutex> model_lock(m_model_mutex);
184200
attributes = m_model.get_method(attribute_getter_name)(getter_inputs)
185201
.toList()
186202
.vec();
203+
model_lock.unlock();
187204
} catch (...) {
205+
std::unique_lock<std::mutex> model_lock(m_model_mutex);
188206
auto output_tuple =
189207
m_model.get_method(attribute_getter_name)(getter_inputs).toTuple();
190208
attributes = (*output_tuple.get()).elements();
209+
model_lock.unlock();
191210
}
192211
} catch (...) {
212+
std::unique_lock<std::mutex> model_lock(m_model_mutex);
193213
attributes.push_back(
194214
m_model.get_method(attribute_getter_name)(getter_inputs));
215+
model_lock.unlock();
195216
}
196217
return attributes;
197218
}
@@ -201,7 +222,9 @@ std::string Backend::get_attribute_as_string(std::string attribute_name) {
201222
// finstringd arguments
202223
torch::Tensor setter_params;
203224
try {
225+
std::unique_lock<std::mutex> model_lock(m_model_mutex);
204226
setter_params = m_model.attr(attribute_name + "_params").toTensor();
227+
model_lock.unlock();
205228
} catch (...) {
206229
throw "parameters to set attribute " + attribute_name +
207230
" not found in model";
@@ -248,14 +271,18 @@ void Backend::set_attribute(std::string attribute_name,
248271
// find setter
249272
std::string attribute_setter_name = "set_" + attribute_name;
250273
try {
274+
std::unique_lock<std::mutex> model_lock(m_model_mutex);
251275
auto attribute_setter = m_model.get_method(attribute_setter_name);
276+
model_lock.unlock();
252277
} catch (...) {
253278
throw "setter for attribute " + attribute_name + " not found in model";
254279
}
255280
// find arguments
256281
torch::Tensor setter_params;
257282
try {
283+
std::unique_lock<std::mutex> model_lock(m_model_mutex);
258284
setter_params = m_model.attr(attribute_name + "_params").toTensor();
285+
model_lock.unlock();
259286
} catch (...) {
260287
throw "parameters to set attribute " + attribute_name +
261288
" not found in model";
@@ -288,7 +315,9 @@ void Backend::set_attribute(std::string attribute_name,
288315
}
289316
}
290317
try {
318+
std::unique_lock<std::mutex> model_lock(m_model_mutex);
291319
auto setter_out = m_model.get_method(attribute_setter_name)(setter_inputs);
320+
model_lock.unlock();
292321
int setter_result = setter_out.toInt();
293322
if (setter_result != 0) {
294323
throw "setter returned -1";
@@ -299,12 +328,14 @@ void Backend::set_attribute(std::string attribute_name,
299328
}
300329

301330
std::vector<int> Backend::get_method_params(std::string method) {
302-
auto am = get_available_methods();
303331
std::vector<int> params;
304332

305-
if (std::find(am.begin(), am.end(), method) != am.end()) {
333+
if (std::find(m_available_methods.begin(), m_available_methods.end(),
334+
method) != m_available_methods.end()) {
306335
try {
336+
std::unique_lock<std::mutex> model_lock(m_model_mutex);
307337
auto p = m_model.attr(method + "_params").toTensor();
338+
model_lock.unlock();
308339
for (int i(0); i < 4; i++)
309340
params.push_back(p[i].item().to<int>());
310341
} catch (...) {
@@ -315,8 +346,7 @@ std::vector<int> Backend::get_method_params(std::string method) {
315346

316347
int Backend::get_higher_ratio() {
317348
int higher_ratio = 1;
318-
auto model_methods = get_available_methods();
319-
for (const auto &method : model_methods) {
349+
for (const auto &method : m_available_methods) {
320350
auto params = get_method_params(method);
321351
if (!params.size())
322352
continue; // METHOD NOT USABLE, SKIPPING

src/backend/backend.h

+3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#pragma once
2+
#include <mutex>
23
#include <string>
34
#include <torch/script.h>
45
#include <torch/torch.h>
@@ -9,6 +10,8 @@ class Backend {
910
torch::jit::script::Module m_model;
1011
int m_loaded;
1112
std::string m_path;
13+
std::mutex m_model_mutex;
14+
std::vector<std::string> m_available_methods;
1215

1316
public:
1417
Backend();

src/frontend/maxmsp/mcs.nn_tilde/mcs.nn_tilde.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class mc_bnn_tilde : public object<mc_bnn_tilde>, public mc_operator<> {
3838
// INLETS OUTLETS
3939
std::vector<std::unique_ptr<inlet<>>> m_inlets;
4040
std::vector<std::unique_ptr<outlet<>>> m_outlets;
41+
4142
// CHANNELS
4243
std::vector<int> input_chans;
4344
int get_batches();

src/frontend/maxmsp/nn_tilde/nn_tilde.cpp

-2
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ class nn : public object<nn>, public vector_operator<> {
5454
void operator()(audio_bundle input, audio_bundle output);
5555
void perform(audio_bundle input, audio_bundle output);
5656

57-
// using vector_operator::operator();
58-
5957
// ONLY FOR DOCUMENTATION
6058
argument<symbol> path_arg{this, "model path",
6159
"Absolute path to the pretrained model."};

0 commit comments

Comments
 (0)