-
Notifications
You must be signed in to change notification settings - Fork 2
/
interface.py
executable file
·50 lines (43 loc) · 1.33 KB
/
interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import pickle
from collections import defaultdict
from skgmm import GMMSet
from features import get_feature
import time
class ModelInterface:
    """Collect features per enrolled name and fit one model per name.

    Attributes (set in ``__init__``):
        features: ``defaultdict(list)`` mapping each enrolled name to the
            items of every sequence ``get_feature`` returned for that name.
        gmmset: the ``GMMSet`` holding the fitted models; replaced by a
            fresh instance on every ``train()`` call.
    """

    def __init__(self):
        self.features = defaultdict(list)
        self.gmmset = GMMSet()

    def enroll(self, name, fs, signal):
        """Extract features from *signal* and add them to *name*'s data.

        Repeated calls for the same name accumulate features. Nothing is
        fitted until ``train()`` is called.

        Args:
            name: label the features are stored under.
            fs: passed through to ``get_feature`` (presumably the sample
                rate of *signal* -- confirm against ``features.get_feature``).
            signal: passed through to ``get_feature``.
        """
        feat = get_feature(fs, signal)
        self.features[name].extend(feat)

    def train(self):
        """Fit one model per enrolled name into a brand-new ``GMMSet``.

        Any previously trained models are discarded. Fitting is
        best-effort: a name whose ``fit_new`` call raises is reported
        (with the exception text) and skipped, so the remaining names
        still get trained. The elapsed wall-clock time is printed at the
        end.
        """
        self.gmmset = GMMSet()
        start_time = time.time()
        for name, feats in self.features.items():
            try:
                self.gmmset.fit_new(feats, name)
            except Exception as e:
                # Deliberately best-effort, but say *why* it failed.
                print("%s failed: %s" % (name, e))
        print(time.time() - start_time, " seconds")

    def dump(self, fname):
        """Pickle this whole interface (features and models) to *fname*.

        ``gmmset.before_pickle()`` is called before pickling and
        ``gmmset.after_pickle()`` afterwards. The latter runs in a
        ``finally`` so the pair always matches, even if opening the file
        or pickling raises. The original exception still propagates.
        """
        self.gmmset.before_pickle()
        try:
            with open(fname, 'wb') as f:
                # Protocol -1 selects pickle.HIGHEST_PROTOCOL.
                pickle.dump(self, f, -1)
        finally:
            self.gmmset.after_pickle()

    def predict(self, fs, signal):
        """Return the label (enrolled name) that ``gmmset`` predicts.

        Args:
            fs: passed through to ``get_feature``.
            signal: passed through to ``get_feature``.

        Raises:
            Whatever ``get_feature`` raises. The error is printed first
            and then re-raised.
        """
        try:
            feat = get_feature(fs, signal)
        except Exception as e:
            print(e)
            # Without features there is nothing to classify; propagate the
            # real error instead of failing later on an unbound ``feat``.
            raise
        return self.gmmset.predict_one(feat)

    @staticmethod
    def load(fname):
        """Load an object previously written by ``dump()``.

        WARNING: ``pickle.load`` can execute arbitrary code while loading.
        Only load model files from trusted sources.
        """
        with open(fname, 'rb') as f:
            model = pickle.load(f)
        # dump() pickled the gmmset after before_pickle(); undo that here.
        model.gmmset.after_pickle()
        return model