# model.py
# Forked from janzd/EAST.
import keras
from keras.applications.resnet50 import ResNet50
from keras.models import Model
from keras.layers import Conv2D, concatenate, BatchNormalization, Lambda, Input, multiply, add, ZeroPadding2D, Activation, Layer, MaxPooling2D, Dropout
from keras import regularizers
import keras.backend as K
import tensorflow as tf
import numpy as np
# Spatial up-sampling factor applied by each feature-merging stage.
RESIZE_FACTOR = 2


def resize_bilinear(x):
    """Bilinearly upsample a feature tensor by ``RESIZE_FACTOR``.

    Intended as the function of a Keras ``Lambda`` layer. The target size is
    computed from the dynamic shape (``K.shape``), so it does not require the
    height/width to be known when the graph is built.

    Args:
        x: A rank-4 tensor; axes 1 and 2 are scaled, i.e. channels-last
            ``(batch, height, width, channels)`` as required by
            ``tf.image.resize_bilinear``.

    Returns:
        The resized tensor, with height and width multiplied by
        ``RESIZE_FACTOR``.
    """
    new_height = K.shape(x)[1] * RESIZE_FACTOR
    new_width = K.shape(x)[2] * RESIZE_FACTOR
    return tf.image.resize_bilinear(x, size=[new_height, new_width])


def resize_output_shape(input_shape):
    """Return the static output shape of :func:`resize_bilinear`.

    Suitable as the ``output_shape`` argument of a Keras ``Lambda`` layer.

    Args:
        input_shape: A rank-4 shape ``(batch, height, width, channels)``;
            any entry may be ``None`` (unknown).

    Returns:
        tuple: ``input_shape`` with height and width multiplied by
        ``RESIZE_FACTOR``. Unknown (``None``) dimensions stay ``None``.

    Raises:
        ValueError: If ``input_shape`` is not rank 4.
    """
    shape = list(input_shape)
    # Explicit raise instead of ``assert``: asserts are stripped under -O.
    if len(shape) != 4:
        raise ValueError(
            'resize_output_shape expects a rank-4 shape, got %r' % (input_shape,))
    for axis in (1, 2):
        # Keras reports dynamic spatial dims as None (the model's inputs are
        # declared with shape (None, None, C)); ``None * int`` would raise.
        if shape[axis] is not None:
            shape[axis] *= RESIZE_FACTOR
    return tuple(shape)
class EAST_model:
    """EAST text detector: ResNet50 backbone with a U-shaped feature merge.

    Builds a Keras ``Model`` with two outputs: a 1-channel score map and a
    5-channel geometry map (4 sigmoid channels scaled by ``input_size`` plus
    1 angle channel).

    Args:
        input_size: Multiplier applied to the 4 sigmoid geometry channels, so
            they lie in ``(0, input_size)``. Presumably the training image
            side length in pixels -- TODO confirm against the data pipeline.

    Attributes:
        model: The assembled ``keras.models.Model``. Its inputs are
            ``[input_image, overly_small_text_region_training_mask,
            text_region_boundary_training_mask, target_score_map]`` and its
            outputs ``[pred_score_map, pred_geo_map]``. Only ``input_image``
            feeds the layers built here; the other three inputs are presumably
            consumed by a loss defined elsewhere (verify in the training code).
        input_image, overly_small_text_region_training_mask,
        text_region_boundary_training_mask, target_score_map: Input tensors.
        pred_score_map: Sigmoid score output (1 channel, values in (0, 1)).
        pred_geo_map: Geometry output (5 channels, concatenated on axis 3).
    """

    # 0-based positions, among ResNet50's Activation layers in network order,
    # of the ReLUs that end backbone stages 2, 3, 4 and 5 (ResNet50 has
    # 1 + 16 blocks * 3 = 49 of them). In a fresh Keras session these are the
    # layers auto-named 'activation_10', '_22', '_40' and '_49'.
    _STAGE_OUTPUT_ACTIVATIONS = (9, 21, 39, 48)
    _NUM_BACKBONE_ACTIVATIONS = 49

    def __init__(self, input_size=512):
        input_image = Input(shape=(None, None, 3), name='input_image')
        overly_small_text_region_training_mask = Input(
            shape=(None, None, 1), name='overly_small_text_region_training_mask')
        text_region_boundary_training_mask = Input(
            shape=(None, None, 1), name='text_region_boundary_training_mask')
        target_score_map = Input(shape=(None, None, 1), name='target_score_map')

        resnet = ResNet50(input_tensor=input_image, weights='imagenet',
                          include_top=False, pooling=None)
        c2, c3, c4, c5 = self._backbone_stage_outputs(resnet)

        # Feature merging: repeatedly upsample the deeper map and fuse it with
        # the next shallower backbone stage.
        x = self._upsample_and_merge(c5, c4, 128, 'resize_1')
        x = self._upsample_and_merge(x, c3, 64, 'resize_2')
        # Pad one row at the top and one column at the left of the stage-2
        # map so its size matches the upsampled map. Presumably the backbone's
        # stage-2 output is one pixel smaller (e.g. 127 vs 128 for a 512
        # input) -- TODO confirm for other input sizes.
        c2 = ZeroPadding2D(((1, 0), (1, 0)))(c2)
        x = self._upsample_and_merge(x, c2, 32, 'resize_3')
        x = self._conv_bn_relu(x, 32, (3, 3))

        pred_score_map = Conv2D(1, (1, 1), activation=tf.nn.sigmoid,
                                name='pred_score_map')(x)
        # Sigmoid output in (0, 1) rescaled to (0, input_size); presumably the
        # distances to the 4 box edges of EAST's RBOX geometry.
        rbox_geo_map = Conv2D(4, (1, 1), activation=tf.nn.sigmoid,
                              name='rbox_geo_map')(x)
        rbox_geo_map = Lambda(lambda t: t * input_size)(rbox_geo_map)
        # Sigmoid output in (0, 1) mapped to an angle in (-pi/4, pi/4).
        angle_map = Conv2D(1, (1, 1), activation=tf.nn.sigmoid,
                           name='rbox_angle_map')(x)
        angle_map = Lambda(lambda t: (t - 0.5) * np.pi / 2)(angle_map)
        pred_geo_map = concatenate([rbox_geo_map, angle_map], axis=3,
                                   name='pred_geo_map')

        model = Model(inputs=[input_image,
                              overly_small_text_region_training_mask,
                              text_region_boundary_training_mask,
                              target_score_map],
                      outputs=[pred_score_map, pred_geo_map])

        self.model = model
        self.input_image = input_image
        self.overly_small_text_region_training_mask = overly_small_text_region_training_mask
        self.text_region_boundary_training_mask = text_region_boundary_training_mask
        self.target_score_map = target_score_map
        self.pred_score_map = pred_score_map
        self.pred_geo_map = pred_geo_map

    @classmethod
    def _backbone_stage_outputs(cls, resnet):
        """Return the output tensors ending ResNet50 stages 2, 3, 4 and 5.

        Layers are selected by position rather than by auto-generated name:
        Keras numbers unnamed layers with a per-session counter, so a name
        like 'activation_49' only matches when ResNet50 is the first model
        built in the session. Every ResNet50 activation lies on the single
        main path (each one feeds the next), so their order in
        ``resnet.layers`` is their order in the network.

        Raises:
            ValueError: If the backbone does not contain the expected number
                of ``Activation`` layers.
        """
        activations = [layer for layer in resnet.layers
                       if isinstance(layer, Activation)]
        if len(activations) != cls._NUM_BACKBONE_ACTIVATIONS:
            raise ValueError(
                'Expected %d Activation layers in the ResNet50 backbone, '
                'found %d' % (cls._NUM_BACKBONE_ACTIVATIONS, len(activations)))
        return tuple(activations[i].output for i in cls._STAGE_OUTPUT_ACTIVATIONS)

    @classmethod
    def _upsample_and_merge(cls, x, skip, filters, resize_name):
        """Upsample ``x``, concatenate ``skip`` on axis 3, then 1x1 and 3x3 conv-BN-ReLU."""
        x = Lambda(resize_bilinear, name=resize_name)(x)
        x = concatenate([x, skip], axis=3)
        x = cls._conv_bn_relu(x, filters, (1, 1))
        return cls._conv_bn_relu(x, filters, (3, 3))

    @staticmethod
    def _conv_bn_relu(x, filters, kernel_size):
        """Apply an L2(1e-5)-regularised 'same' Conv2D, BatchNorm and ReLU."""
        x = Conv2D(filters, kernel_size, padding='same',
                   kernel_regularizer=regularizers.l2(1e-5))(x)
        x = BatchNormalization(momentum=0.997, epsilon=1e-5, scale=True)(x)
        return Activation('relu')(x)