-
Notifications
You must be signed in to change notification settings - Fork 97
/
Copy pathconv1x1.cc
142 lines (113 loc) · 3.35 KB
/
conv1x1.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
// Conv 1x1 PE
#include "net_hls.h"
#include <stdio.h>
#include <math.h>
#include <ap_fixed.h>
#include "hls_stream.h"
FIX_FM_acc compute_engine_16(FIX_WT w0, FIX_FM b0,
FIX_WT w1, FIX_FM b1,
FIX_WT w2, FIX_FM b2,
FIX_WT w3, FIX_FM b3,
FIX_WT w4, FIX_FM b4,
FIX_WT w5, FIX_FM b5,
FIX_WT w6, FIX_FM b6,
FIX_WT w7, FIX_FM b7,
FIX_WT w8, FIX_FM b8,
FIX_WT w9, FIX_FM b9,
FIX_WT w10, FIX_FM b10,
FIX_WT w11, FIX_FM b11,
FIX_WT w12, FIX_FM b12,
FIX_WT w13, FIX_FM b13,
FIX_WT w14, FIX_FM b14,
FIX_WT w15, FIX_FM b15)
{
FIX_32_10 mul0, mul1, mul2, mul3, mul4, mul5, mul6, mul7;
FIX_32_10 mul8, mul9, mul10, mul11, mul12, mul13, mul14, mul15;
FIX_32_10 add0, add1, add2, add3, add4, add5, add6;
FIX_32_10 add7, add8, add9, add10, add11, add12, add13, add14;
mul0 = w0 * b0;
mul1 = w1 * b1;
mul2 = w2 * b2;
mul3 = w3 * b3;
mul4 = w4 * b4;
mul5 = w5 * b5;
mul6 = w6 * b6;
mul7 = w7 * b7;
mul8 = w8 * b8;
mul9 = w9 * b9;
mul10 = w10 * b10;
mul11 = w11 * b11;
mul12 = w12 * b12;
mul13 = w13 * b13;
mul14 = w14 * b14;
mul15 = w15 * b15;
add0 = mul0 + mul1;
add1 = mul2 + mul3;
add2 = mul4 + mul5;
add3 = mul6 + mul7;
add4 = mul8 + mul9;
add5 = mul10 + mul11;
add6 = mul12 + mul13;
add7 = mul14 + mul15;
add8 = add0 + add1;
add9 = add2 + add3;
add10 = add4 + add5;
add11 = add6 + add7;
add12 = add8 + add9;
add13 = add10 + add11;
add14 = add12 + add13;
return add14;
}
void load_weights(FIX_WT weights[32][32], FIX_WT weight_buf[32][16], int CI)
{
for(int ci = 0; ci < 16; ci++) {
//#pragma HLS pipeline
for(int co = 0; co < 32; co++) {
//#pragma HLS unroll
#pragma HLS pipeline
weight_buf[co][ci] = weights[co][ci + CI];
}
}
}
void CONV_1x1(FIX_FM bottom[32][44][84],
FIX_FM_acc top[32][44][84],
FIX_WT weights[32][32])
{
#pragma HLS array_partition variable=top dim=1 complete
#pragma HLS array_partition variable=bottom dim=1 complete
#pragma HLS array_partition variable=weights dim=1 complete
FIX_WT weight_buf[32][16];
#pragma HLS array_partition variable=weight_buf dim=1 complete
#pragma HLS array_partition variable=weight_buf dim=2 complete
for(int ci = 0; ci < 32; ci += 16) {
//for(int co = 0; co < 32; co += 32) {
load_weights(weights, weight_buf, ci);
int buf_idx = ci/16;
for(int h = 1; h <= 42; h++){
for(int w = 1; w <= 82; w++) {
#pragma HLS pipeline II=2
for(int coo = 0; coo < 32; coo++) {
#pragma HLS unroll
top[coo][h][w] += compute_engine_16(
weight_buf[coo][0], bottom[ci+0][h][w],
weight_buf[coo][1], bottom[ci+1][h][w],
weight_buf[coo][2], bottom[ci+2][h][w],
weight_buf[coo][3], bottom[ci+3][h][w],
weight_buf[coo][4], bottom[ci+4][h][w],
weight_buf[coo][5], bottom[ci+5][h][w],
weight_buf[coo][6], bottom[ci+6][h][w],
weight_buf[coo][7], bottom[ci+7][h][w],
weight_buf[coo][8], bottom[ci+8][h][w],
weight_buf[coo][9], bottom[ci+9][h][w],
weight_buf[coo][10], bottom[ci+10][h][w],
weight_buf[coo][11], bottom[ci+11][h][w],
weight_buf[coo][12], bottom[ci+12][h][w],
weight_buf[coo][13], bottom[ci+13][h][w],
weight_buf[coo][14], bottom[ci+14][h][w],
weight_buf[coo][15], bottom[ci+15][h][w]);
}
}
}
}
//}
}