-
Notifications
You must be signed in to change notification settings - Fork 54
/
example1_main.cc
65 lines (57 loc) · 1.83 KB
/
example1_main.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
// Copyright 2019 the deepx authors.
// Author: Yafei Zhang (kimmyzhang@tencent.com)
//
#include <deepx_core/dx_log.h>
#include <deepx_core/graph/graph.h>
#include <deepx_core/graph/graph_node.h>
#include <deepx_core/graph/op_context.h>
#include <deepx_core/graph/tensor_map.h>
#include <deepx_core/tensor/data_type.h>
#include <iostream>
#include <vector>
namespace deepx_core {
namespace {
class Main : public DataType {
public:
static int main() {
Graph graph;
TensorMap param;
// Initialize graph: Z = X * W + B.
InstanceNode X("X", Shape(1), TENSOR_TYPE_TSR);
InstanceNode W("W", Shape(1), TENSOR_TYPE_TSR);
InstanceNode B("B", Shape(1), TENSOR_TYPE_TSR);
MulNode XW("XW", &X, &W);
AddNode Z("Z", &XW, &B);
DXCHECK_THROW(graph.Compile({&Z}, 0));
// Initialize op context.
OpContext op_context;
op_context.Init(&graph, ¶m);
DXCHECK_THROW(op_context.InitOp(std::vector<int>{0}, -1));
auto& _X = op_context.mutable_inst()->insert<tsr_t>(X.name());
auto& _W = op_context.mutable_inst()->insert<tsr_t>(W.name());
auto& _B = op_context.mutable_inst()->insert<tsr_t>(B.name());
_X.resize(X.shape());
_W.resize(W.shape());
_B.resize(B.shape());
op_context.InitForward();
// Input, forward, output.
auto compute = [&op_context, &_X, &_W, &_B, &Z](float_t x, float_t w,
float_t b) {
_X.data(0) = x;
_W.data(0) = w;
_B.data(0) = b;
op_context.Forward();
const auto& _Z = op_context.hidden().get<tsr_t>(Z.name());
float_t z = _Z.data(0);
std::cout << "Z=" << z << std::endl;
};
compute(1, 2, 3);
compute(2, 3, 4);
compute(4, 5, 6);
compute(10, 20, 30);
return 0;
}
};
} // namespace
} // namespace deepx_core
int main() { return deepx_core::Main::main(); }