-
Notifications
You must be signed in to change notification settings - Fork 2
/
prediction.py
87 lines (77 loc) · 2.95 KB
/
prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import json
from operator import itemgetter
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, List
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.text import tokenizer_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences
@dataclass
class AIModel:
    """Bundle a trained Keras text classifier with its tokenizer and metadata.

    All three artifacts are loaded eagerly in ``__post_init__``; construction
    fails fast with ``ValueError`` if any path is missing or does not exist.

    Init fields:
        modelPath: Path to a saved Keras model (required, must exist).
        tokenizerPath: Path to a tokenizer JSON export (required in practice:
            ``None`` raises in ``__post_init__`` despite the Optional hint).
        metadataPath: Path to a JSON file; must contain at least
            ``max_seq_length`` and ``label_legend_inverted``
            (str(index) -> label name) — assumed from usage below.
    """

    modelPath: Path
    tokenizerPath: Optional[Path] = None
    metadataPath: Optional[Path] = None

    # Loaded artifacts — populated by __post_init__, never supplied by callers
    # (deliberately un-annotated so @dataclass does not treat them as fields).
    model = None
    tokenizer = None
    metadata = None

    def __post_init__(self):
        """Load model, tokenizer and metadata, raising ValueError on any miss."""
        if self.modelPath.exists():
            self.model = load_model(self.modelPath)
        else:
            raise ValueError('Could not load model data')
        #
        if self.tokenizerPath and self.tokenizerPath.exists():
            tokenizerText = self.tokenizerPath.read_text()
            self.tokenizer = tokenizer_from_json(tokenizerText)
        else:
            raise ValueError('Could not load tokenizer data')
        #
        if self.metadataPath and self.metadataPath.exists():
            self.metadata = json.loads(self.metadataPath.read_text())
        else:
            raise ValueError('Could not load metadata')

    def getPaddedSequencesFromTexts(self, texts: List[str]):
        """Tokenize ``texts`` and pad each sequence to the trained length.

        Returns a 2-D array shaped (len(texts), metadata['max_seq_length']).
        """
        sequences = self.tokenizer.texts_to_sequences(texts)
        maxSeqLength = self.metadata['max_seq_length']
        return pad_sequences(sequences, maxlen=maxSeqLength)

    def getLabelName(self, labelIndex):
        """Map a numeric output index to its human-readable label name."""
        # Legend keys are strings in the JSON metadata, hence str(...).
        return self.metadata['label_legend_inverted'][str(labelIndex)]

    def getTopPrediction(self, predictionDict):
        """Return ``{'label', 'value'}`` for the highest-scoring entry.

        Returns None for an empty dict. On ties the first-inserted key wins
        (same tie-breaking as the previous stable-sort implementation).
        """
        if not predictionDict:
            return None
        # max() is O(n); no need to sort the whole dict to take one element.
        topK, topV = max(predictionDict.items(), key=itemgetter(1))
        return {
            'label': topK,
            'value': topV,
        }

    def _convertFloat(self, standardTypes, fVal):
        """Utility method to get rid of numpy numeric types.

        When ``standardTypes`` is true, coerce to a plain ``float`` so results
        are JSON-serializable; otherwise return the value untouched.
        """
        return float(fVal) if standardTypes else fVal

    def predict(self, texts: List[str], standardTypes=True, echoInput=False,
                batchSize=10):
        """Classify ``texts`` and return one result dict per input.

        Each result has 'prediction' (label -> score dict), 'res' (the top
        label/value pair, or None), and — when ``echoInput`` — 'input' with
        the original text. ``batchSize`` (default 10, matching the previous
        hard-coded value) is forwarded to ``model.predict``.
        """
        xInput = self.getPaddedSequencesFromTexts(texts)
        predictions = self.model.predict(xInput, batch_size=batchSize)
        labeledPredictions = [
            {
                self.getLabelName(predIndex): self._convertFloat(standardTypes,
                                                                 predValue)
                for predIndex, predValue in enumerate(list(prediction))
            }
            for prediction in predictions
        ]
        results = [
            {
                **{
                    'prediction': labeledPrediction,
                    'res': self.getTopPrediction(labeledPrediction),
                },
                **({'input': inputText} if echoInput else {}),
            }
            for labeledPrediction, inputText in zip(labeledPredictions, texts)
        ]
        return results