2 files changed: +12 -3 lines changed
Implementation file:

@@ -4,6 +4,7 @@
 #include <iostream>
 #include <stdlib.h>
 
+
 #define CUDA torch::kCUDA
 #define CPU torch::kCPU
 
@@ -85,20 +86,26 @@ void Backend::perform(std::vector<float *> in_buffer,
 
 int Backend::load(std::string path) {
   try {
-    m_model = torch::jit::load(path);
-    m_model.eval();
+    auto model = torch::jit::load(path);
+    model.eval();
     if (m_cuda_available) {
       std::cout << "sending model to gpu" << std::endl;
-      m_model.to(CUDA);
+      model.to(CUDA);
     }
+    m_model = model;
     m_loaded = 1;
+    m_path = path;
     return 0;
   } catch (const std::exception &e) {
     std::cerr << e.what() << '\n';
     return 1;
   }
 }
 
+int Backend::reload() {
+  return load(m_path);
+}
+
 bool Backend::has_method(std::string method_name) {
   for (const auto &m : m_model.get_methods()) {
     if (m.name() == method_name)
Header file:

@@ -8,6 +8,7 @@ class Backend {
 private:
   torch::jit::script::Module m_model;
   int m_loaded;
+  std::string m_path;
 
 public:
   Backend();
@@ -26,6 +27,7 @@ class Backend {
   std::vector<int> get_method_params(std::string method);
   int get_higher_ratio();
   int load(std::string path);
+  int reload();
   bool is_loaded();
   bool m_cuda_available;
   torch::jit::script::Module get_model() { return m_model; }
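
For context, the change loads the TorchScript module into a local variable and only assigns it to m_model after eval() and the optional CUDA transfer have succeeded, and it records the path so the new reload() can re-read the same file. Below is a minimal, self-contained sketch of that pattern; MiniBackend is a hypothetical stand-in for the project's Backend class, and torch::cuda::is_available() stands in for its m_cuda_available flag.

// Minimal sketch of the load/reload pattern (hypothetical MiniBackend,
// not the project's Backend class).
#include <torch/script.h>
#include <torch/torch.h>
#include <iostream>
#include <string>

struct MiniBackend {
  torch::jit::script::Module m_model;
  std::string m_path;    // remembered for reload()
  bool m_loaded = false;

  int load(std::string path) {
    try {
      auto model = torch::jit::load(path);    // throws on a bad/missing file
      model.eval();
      if (torch::cuda::is_available())        // stand-in for m_cuda_available
        model.to(torch::kCUDA);
      m_model = model;                        // commit only after success
      m_path = path;
      m_loaded = true;
      return 0;
    } catch (const std::exception &e) {
      std::cerr << e.what() << '\n';          // previous model stays in place
      return 1;
    }
  }

  // Re-load whatever was last loaded successfully.
  int reload() { return load(m_path); }
};

Because the member is only overwritten once loading, eval(), and the device transfer have all succeeded, a failing reload() (for instance while the file on disk is being rewritten) leaves the previously loaded model usable.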