-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathneural_obj.hpp
62 lines (46 loc) · 1.54 KB
/
neural_obj.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#ifndef NEURAL_OBJ_H
#define NEURAL_OBJ_H
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tracked_object.hpp"
namespace mnn {
class NeuralObj;

/// Shared, reference-counted handle used for every node-to-node link.
using NeuralObj_ptr = std::shared_ptr<NeuralObj>;

/**
 * @brief Abstract base class for a node in an mnn network graph.
 *
 * Concrete nodes implement the pure virtual interface below: input
 * management, calculate(), update(double), and the forward/backward
 * "handoff" messages exchanged with neighbouring nodes.
 *
 * Inheritance:
 *  - std::enable_shared_from_this<NeuralObj> must be a *public* base. The
 *    standard only wires up its internal weak pointer when that base is
 *    unambiguous and accessible. Otherwise shared_from_this() throws
 *    std::bad_weak_ptr, and derived classes cannot call it at all.
 *    connect() is documented to receive shared_ptr<this>, which relies on
 *    shared_from_this() working.
 *  - CounterObj<NeuralObj> is kept as a private base.
 *
 * NOTE(review): the shared_ptr-based API suggests nodes are meant to be
 * owned via NeuralObj_ptr (e.g. std::make_shared<Derived>(...)) — confirm.
 * TODO(review): `inputs` holds owning shared_ptrs. If the graph contains
 * cycles (cf. is_loop_flag), those nodes will never be freed. Consider
 * std::weak_ptr for back-edges.
 */
class NeuralObj : private CounterObj<NeuralObj>,
                  public std::enable_shared_from_this<NeuralObj> {
    /* Member Variables */
protected:
    /// Node name (presumably for identification/debugging).
    std::string name;
    /// Upstream nodes (presumably maintained by add_input()/remove_input()).
    std::vector<NeuralObj_ptr> inputs;
    /// Presumably one weight per entry in `inputs` — TODO confirm.
    std::vector<double> weights;
    /// Presumably marks which entries of `inputs` close a loop — TODO confirm.
    std::vector<bool> is_loop_flag;
    /// Starts false; presumably set while handoffs are pending — TODO confirm.
    bool is_waiting{false};
    /// Starts at 0; presumably the number of pending handoffs — TODO confirm.
    int waiting_on{0};
    /// Forward-propagation values keyed by the neighbour that supplied them;
    /// read through get_forwardprop_responce().
    std::unordered_map<NeuralObj_ptr, double> forward_handoff_map;

public:
    /// Public completion flag, initialised to false.
    bool done_calculating{false};

    /* Member Methods */
protected:
    /// Presumably delivers a backpropagation value (second argument) from the
    /// given neighbour. NOTE: "recieve" misspelling kept for compatibility.
    virtual void recieve_backprop_handoff(NeuralObj_ptr &, double) = 0;
    /// Presumably asks the given neighbour to supply its forward value.
    virtual void request_forwardprop_handoff(NeuralObj_ptr &) = 0;
    /// Presumably hands a forward value (second argument) to the given neighbour.
    virtual void give_forwardprop_handoff(NeuralObj_ptr &, double) = 0;

    /**
     * @brief Returns the forward-propagation value recorded for neighbour @p n.
     *
     * Precondition: forward_handoff_map contains an entry for @p n. This is
     * checked with assert in debug builds. If the entry is missing in a
     * release build, 0.0 is returned and the map is left unmodified.
     * NOTE: "responce" misspelling kept for source compatibility.
     * TODO: not thread-safe; guard forward_handoff_map before concurrent use.
     *
     * @param n neighbour whose value is requested.
     * @return the stored value, or 0.0 if absent (release builds only).
     */
    double get_forwardprop_responce(const NeuralObj_ptr &n) const {
        // One hash lookup; operator[] would also insert on a miss.
        const auto it = forward_handoff_map.find(n);
        assert(it != forward_handoff_map.end() &&
               "no forward handoff recorded for this node");
        return it != forward_handoff_map.end() ? it->second : 0.0;
    }

    /// Called at the end of add_input on the added object, passing
    /// shared_ptr<this> (obtainable via shared_from_this()).
    virtual void connect(NeuralObj_ptr &) = 0;

public:
    NeuralObj() = default;

    /// Polymorphic base: deleting a derived node through NeuralObj* /
    /// NeuralObj_ptr must run the derived destructor.
    virtual ~NeuralObj() = default;

    // Declaring the destructor suppresses the implicit move operations.
    // Restore them, and keep copies, exactly as the compiler would generate.
    NeuralObj(const NeuralObj &) = default;
    NeuralObj &operator=(const NeuralObj &) = default;
    NeuralObj(NeuralObj &&) = default;
    NeuralObj &operator=(NeuralObj &&) = default;

    /// Presumably registers the given node as an input of this node.
    virtual void add_input(NeuralObj_ptr &) = 0;
    /// Presumably unregisters the given node from this node's inputs.
    virtual void remove_input(NeuralObj_ptr &) = 0;
    /// Performs this node's calculation (semantics defined by derived class).
    virtual void calculate() = 0;
    /// Applies an update with the given value (semantics defined by derived
    /// class; presumably a learning-rate/error term — TODO confirm).
    virtual void update(double) = 0;
};
} // namespace mnn
#endif