-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathunet.cc
73 lines (51 loc) · 3.02 KB
/
unet.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include "../../include/ceras.hpp"
#include "../../include/utils/range.hpp"
#include "../../include/utils/better_assert.hpp"
int main()
{
using namespace ceras;
auto input = Input( {256, 256, 3} ); // 3D tensor input, (batch_size, 256, 256, 3)
auto l0 = relu( Conv2D( 64, {3, 3}, "same" )(input) ); // 256, 256, 64
auto l1 = max_pooling_2d( 2 ) ( l0 ); // 128, 128, 64
auto l2 = relu( Conv2D( 128, {3, 3}, "same" )( l1 ) ); // 128, 128, 128
auto l3 = relu( Conv2D( 128, {3, 3}, "same" )( l2 ) ); // 128, 128, 128
auto l4 = max_pooling_2d( 2 ) ( l3 ); // 64, 64, 128
auto l5 = relu( Conv2D( 256, {3, 3}, "same" )( l4 ) ); // 64, 64, 256
auto l6 = relu( Conv2D( 256, {3, 3}, "same" )( l5 ) ); // 64, 64, 256
auto l7 = relu( Conv2D( 256, {3, 3}, "same" )( l6 ) ); // 64, 64, 256
auto l8 = max_pooling_2d( 2 ) ( l7 ); // 32, 32, 256
auto l9 = relu( Conv2D( 512, {3, 3}, "same" )( l8 ) ); // 32, 32, 512
auto l10 = relu( Conv2D( 512, {3, 3}, "same" )( l9 ) ); // 32, 32, 512
auto l11 = relu( Conv2D( 512, {3, 3}, "same" )( l10 ) ); // 32, 32, 512
auto l12 = max_pooling_2d( 2 ) ( l11 ); // 16, 16, 512
auto l13 = relu( Conv2D( 512, {3, 3}, "same" )( l12 ) ); // 16, 16, 512
auto l14 = relu( Conv2D( 512, {3, 3}, "same" )( l13 ) ); // 16, 16, 512
auto l15 = relu( Conv2D( 512, {3, 3}, "same" )( l14 ) ); // 16, 16, 512
auto l16 = max_pooling_2d( 2 ) ( l15 ); // 8, 8, 512
auto l17 = relu( Conv2D( 512, {3, 3}, "same" )( l16 ) ); // 8, 8, 512
auto l18 = relu( Conv2D( 512, {3, 3}, "same" )( l17 ) ); // 8, 8, 512
auto l19 = up_sampling_2d( 2 )( l18 ); // 16, 16, 512
auto l20 = l15 + relu( Conv2D( 512, {3, 3}, "same" )( l19 ) ); // or concatenate instead of '+'
auto l21 = relu( Conv2D( 512, {3, 3}, "same" )( l20 ) ); // 16, 16, 512
auto l22 = relu( Conv2D( 512, {3, 3}, "same" )( l21 ) ); // 16, 16, 512
auto l23 = up_sampling_2d(2)( l22 ); // 32, 32, 512
auto l24 = l11 + relu( Conv2D( 512, {3, 3}, "same" )( l23 ) ); // 32, 32, 512
auto l25 = relu( Conv2D( 512, {3, 3}, "same" )( l24 ) );
auto l26 = relu( Conv2D( 512, {3, 3}, "same" )( l25 ) );
auto l27 = up_sampling_2d(2)( l26 ); // 64, 64, 512
auto l28 = l7 + relu( Conv2D( 256, {3, 3}, "same" )( l27 ) ); //64, 64, 256
auto l29 = relu( Conv2D( 256, {3, 3}, "same" )( l28 ) ); // 64, 64, 256
auto l30 = relu( Conv2D( 256, {3, 3}, "same" )( l29 ) ); // 64, 64, 256
auto l31 = up_sampling_2d(2)( l30 ); // 128, 128, 256
auto l32 = l3 + relu( Conv2D( 128, {3, 3}, "same" )( l31 ) );
auto l33 = relu( Conv2D( 128, {3, 3}, "same" )( l32 ) );
auto l34 = relu( Conv2D( 128, {3, 3}, "same" )( l33 ) ); // 128, 128, 128
auto l35 = up_sampling_2d(2)( l34 ); // 256, 256, 128
auto l36 = relu( Conv2D( 64, {3, 3}, "same" )( l35 ) );
auto l37 = sigmoid( Conv2D( 3, {3, 3}, "same" )( l36 ) );
auto output = l37;
auto m = model{ input, output }; // define a model
m.summary( "./examples/unet/unet.dot" );
//training code ommited.
return 0;
}