Skip to content

Commit 8dd747a

Browse files
add reload option (non threaded)
1 parent bc2276b commit 8dd747a

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

src/backend/backend.cpp

+10-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <iostream>
55
#include <stdlib.h>
66

7+
78
#define CUDA torch::kCUDA
89
#define CPU torch::kCPU
910

@@ -85,20 +86,26 @@ void Backend::perform(std::vector<float *> in_buffer,
8586

8687
int Backend::load(std::string path) {
8788
try {
88-
m_model = torch::jit::load(path);
89-
m_model.eval();
89+
auto model = torch::jit::load(path);
90+
model.eval();
9091
if (m_cuda_available) {
9192
std::cout << "sending model to gpu" << std::endl;
92-
m_model.to(CUDA);
93+
model.to(CUDA);
9394
}
95+
m_model = model;
9496
m_loaded = 1;
97+
m_path = path;
9598
return 0;
9699
} catch (const std::exception &e) {
97100
std::cerr << e.what() << '\n';
98101
return 1;
99102
}
100103
}
101104

105+
int Backend::reload(){
106+
return load(m_path);
107+
}
108+
102109
bool Backend::has_method(std::string method_name) {
103110
for (const auto &m : m_model.get_methods()) {
104111
if (m.name() == method_name)

src/backend/backend.h

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ class Backend {
88
private:
99
torch::jit::script::Module m_model;
1010
int m_loaded;
11+
std::string m_path;
1112

1213
public:
1314
Backend();
@@ -26,6 +27,7 @@ class Backend {
2627
std::vector<int> get_method_params(std::string method);
2728
int get_higher_ratio();
2829
int load(std::string path);
30+
int reload();
2931
bool is_loaded();
3032
bool m_cuda_available;
3133
torch::jit::script::Module get_model() { return m_model; }

0 commit comments

Comments
 (0)