-
Notifications
You must be signed in to change notification settings - Fork 0
/
Model.py
56 lines (44 loc) · 1.97 KB
/
Model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from tensorflow import keras
from keras_bert import AdamWarmup, calc_train_steps, get_custom_objects, load_trained_model_from_checkpoint, backend
class BertModel():
    """Fine-tuning wrapper around a pre-trained BERT checkpoint.

    Loads a trainable BERT from a TF checkpoint (via ``keras_bert``) and
    builds a softmax classification head on top of the ``NSP-Dense`` layer.
    """

    def __init__(self, config_path, checkpoint_path, seq_len, batch_size, epochs, lr):
        """Store training hyper-parameters and load the pre-trained BERT.

        :param config_path: path to the BERT JSON config file
        :param checkpoint_path: path to the pre-trained TF checkpoint
        :param seq_len: maximum input sequence length
        :param batch_size: training batch size (used for warmup scheduling)
        :param epochs: number of training epochs (used for warmup scheduling)
        :param lr: peak learning rate for AdamWarmup
        """
        self.config_path = config_path
        self.checkpoint_path = checkpoint_path
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.epochs = epochs
        self.lr = lr
        self.model = None  # set by compile_model() or load()
        self.load_pre_trained_model()

    def load_pre_trained_model(self):
        """Load the pre-trained BERT with all layers trainable for fine-tuning."""
        self.pretrained_model = load_trained_model_from_checkpoint(
            self.config_path,
            self.checkpoint_path,
            training=True,
            trainable=True,
            seq_len=self.seq_len,
        )

    def compile_model(self, data_size, loss_fn, metrics, num_classes=2):
        """Build and compile the classification model.

        :param data_size: number of training examples (for warmup/decay steps)
        :param loss_fn: loss function passed to ``model.compile``
        :param metrics: a single metric or a list/tuple of metrics
        :param num_classes: number of output classes (default 2, matching
            the original hard-coded binary head)
        :return: the compiled ``keras.models.Model``
        """
        # Token and segment inputs only; drop the masked-LM/NSP extras.
        inputs = self.pretrained_model.inputs[:2]
        dense = self.pretrained_model.get_layer('NSP-Dense').output
        outputs = keras.layers.Dense(units=num_classes, activation='softmax')(dense)
        decay_steps, warmup_steps = calc_train_steps(
            data_size,
            batch_size=self.batch_size,
            epochs=self.epochs,
        )
        model = keras.models.Model(inputs, outputs)
        model.compile(
            AdamWarmup(decay_steps=decay_steps, warmup_steps=warmup_steps, lr=self.lr),
            loss=loss_fn,
            # Accept either one metric or an already-built list; the original
            # always wrapped, which nested lists when a list was passed.
            metrics=list(metrics) if isinstance(metrics, (list, tuple)) else [metrics],
        )
        self.model = model
        # summary() prints to stdout and returns None; calling it directly
        # avoids the stray "None" that print(self.model.summary()) emitted.
        self.model.summary()
        return self.model

    def load(self, checkpoint_path):
        """Load a previously saved model from an H5 file.

        :param checkpoint_path: file path to the saved ``.h5`` model
        :return: None (the loaded model is stored on ``self.model``)
        """
        # keras_bert custom layers must be registered for deserialization.
        self.model = keras.models.load_model(checkpoint_path, custom_objects=get_custom_objects())