# Made with tutorial from: https://medium.com/@divyanshuraj.6815/learn-to-code-in-tensorflow2-part2-b1c448abbf1e
import tensorflow as tf
class CommonLayers(tf.keras.Model):
    """Conv -> BatchNorm -> ReLU when num_filters > 0, else just BatchNorm -> ReLU."""

    def __init__(self, num_filters: int = 0):
        super().__init__()
        self.num_filters = num_filters
        if num_filters != 0:
            self.conv = tf.keras.layers.Conv2D(
                filters=num_filters, kernel_size=3, padding="SAME",
                kernel_initializer='glorot_uniform', use_bias=False
            )
        # BatchNorm is used on both paths, so it is created unconditionally.
        self.batch_norm = tf.keras.layers.BatchNormalization()

    def call(self, inputs):
        if self.num_filters == 0:
            return tf.nn.relu(self.batch_norm(inputs))
        return tf.nn.relu(self.batch_norm(self.conv(inputs)))
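
# Shape sketch for CommonLayers (illustrative values, not from the tutorial):
# with num_filters=0 the block is BN -> ReLU and the channel count is
# unchanged; with num_filters=64 it is Conv -> BN -> ReLU and channels become 64.
#
#   x = tf.random.normal((2, 32, 32, 3))
#   CommonLayers()(x).shape    # TensorShape([2, 32, 32, 3])
#   CommonLayers(64)(x).shape  # TensorShape([2, 32, 32, 64])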
class ResBlock(tf.keras.Model):
    """Two 3x3 convs with an identity (or strided 1x1 projection) shortcut."""

    def __init__(self, c_out, flag_reshape, flag_common):
        super().__init__()
        self.common1 = CommonLayers()
        self.common2 = CommonLayers()
        if not flag_reshape:
            self.conv1 = tf.keras.layers.Conv2D(filters=c_out, kernel_size=3, padding="SAME",
                                                kernel_initializer='glorot_uniform', use_bias=False)
        else:
            # Strided first conv halves the spatial size; the strided 1x1 conv
            # ("pool") projects the shortcut to the same shape.
            self.conv1 = tf.keras.layers.Conv2D(filters=c_out, kernel_size=3, padding="SAME",
                                                kernel_initializer='glorot_uniform', strides=2, use_bias=False)
            self.pool = tf.keras.layers.Conv2D(filters=c_out, kernel_size=1, strides=2,
                                               kernel_initializer='glorot_uniform', use_bias=False)
        self.conv2 = tf.keras.layers.Conv2D(filters=c_out, kernel_size=3, padding="SAME",
                                            kernel_initializer='glorot_uniform', use_bias=False)
        self.maxpool = tf.keras.layers.MaxPool2D((4, 4))
        self.avgpool = tf.keras.layers.AveragePooling2D((4, 4))
        self.flag_reshape = flag_reshape
        self.flag_common = flag_common
        self.c_out = c_out

    def call(self, inputs):
        h = self.conv2(self.common1(self.conv1(inputs)))
        # Reshape blocks shrink the spatial size / grow the channel count, so
        # the shortcut must be projected before the residual addition.
        if self.flag_reshape:
            h = h + self.pool(inputs)
        else:
            h = h + inputs
        # The last block skips the trailing BN+ReLU and instead concatenates
        # max- and average-pooled features along the channel axis.
        if not self.flag_common:
            return tf.concat([self.maxpool(h), self.avgpool(h)], axis=-1)
        return self.common2(h)
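
# Shape sketch for ResBlock (illustrative values): a reshape block halves the
# spatial size and projects the shortcut; the final block (flag_common=False)
# pools 4x4 feature maps down to 1x1 and doubles channels via the concat.
#
#   h = tf.random.normal((2, 32, 32, 64))
#   ResBlock(64, False, True)(h).shape     # TensorShape([2, 32, 32, 64])
#   ResBlock(128, True, True)(h).shape     # TensorShape([2, 16, 16, 128])
#   h4 = tf.random.normal((2, 4, 4, 512))
#   ResBlock(512, False, False)(h4).shape  # TensorShape([2, 1, 1, 1024])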
class ResNet18(tf.keras.Model):
    """ResNet-18-style network for 10-class classification."""

    def __init__(self, c=64):
        super().__init__()
        self.common = CommonLayers(c)
        self.blk1_1 = ResBlock(c, False, True)
        self.blk1_2 = ResBlock(c, False, True)
        self.blk2_1 = ResBlock(c * 2, True, True)
        self.blk2_2 = ResBlock(c * 2, False, True)
        self.blk3_1 = ResBlock(c * 4, True, True)
        self.blk3_2 = ResBlock(c * 4, False, True)
        self.blk4_1 = ResBlock(c * 8, True, True)
        self.blk4_2 = ResBlock(c * 8, False, False)
        # A 1x1 conv acts as the final linear classifier over pooled features.
        self.linear = tf.keras.layers.Conv2D(filters=10, kernel_size=1,
                                             kernel_initializer='glorot_uniform', use_bias=False)
        self.flat = tf.keras.layers.Flatten()

    def call(self, x, y):
        h = self.common(x)
        h = self.blk1_2(self.blk1_1(h))
        h = self.blk2_2(self.blk2_1(h))
        h = self.blk3_2(self.blk3_1(h))
        h = self.blk4_2(self.blk4_1(h))
        # Keep raw logits: sparse_softmax_cross_entropy_with_logits applies
        # softmax internally, so applying a softmax layer first (as the
        # original code did) computes the loss on already-normalized values.
        # Softmax is monotonic, so argmax over logits gives the same accuracy.
        logits = self.flat(self.linear(h))
        ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y)
        loss = tf.reduce_sum(ce)
        correct = tf.reduce_sum(
            tf.cast(tf.math.equal(tf.argmax(logits, axis=1), tf.cast(y, tf.int64)), tf.float32))
        return loss, correct
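

if __name__ == "__main__":
    # A minimal smoke test, assuming 32x32 RGB inputs (e.g. CIFAR-10) and
    # integer class labels in [0, 10). The batch of random data below is a
    # placeholder; swap in a real dataset and an optimizer step to train.
    model = ResNet18()
    images = tf.random.normal((2, 32, 32, 3))
    labels = tf.constant([3, 7], dtype=tf.int64)
    with tf.GradientTape() as tape:
        loss, correct = model(images, labels)
    grads = tape.gradient(loss, model.trainable_variables)
    print("loss:", float(loss), "correct:", int(correct),
          "trainable tensors:", len(grads))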