-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathstate.go
296 lines (255 loc) · 7.59 KB
/
state.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
package pymlstate
import (
"encoding/binary"
"errors"
"fmt"
"github.com/ugorji/go/codec"
"gopkg.in/sensorbee/py.v0/pystate"
"gopkg.in/sensorbee/sensorbee.v0/core"
"gopkg.in/sensorbee/sensorbee.v0/data"
"io"
"sync"
)
var (
	// datPath is the precompiled path that Write uses to extract the "data"
	// field from an incoming tuple.
	datPath = data.MustCompilePath("data")
)
// State is a Python instance specialized for multiple layer classification.
// The Python instance and this struct must not be copied directly by an
// assignment statement because doing so doesn't increase the reference count
// of the instance.
type State struct {
	// base is the underlying pystate instance that Python calls are
	// delegated to.
	base *pystate.Base
	// params holds the pymlstate specific parameters such as the batch size.
	params MLParams
	// bucket accumulates the values received by Write until a batch is
	// passed to "fit".
	bucket []data.Value
	// rwm is acquired by the methods of State before they touch the other
	// fields.
	rwm sync.RWMutex
}
// MLParams holds the parameters that pymlstate defines in addition to those
// of pystate. These parameters come from a WITH clause of a CREATE STATE
// statement.
type MLParams struct {
	// BatchSize is the number of tuples in a single batch training. The Write
	// method, which is usually called by an INSERT INTO statement via a uds
	// Sink, stores tuples without training until it has as many tuples as
	// batch_train_size.
	// This is an optional parameter and its default value is 10.
	BatchSize int `codec:"batch_train_size"`
}
// New creates a State, which is used as a `core.SharedState`, for multiple
// layer classification.
//
// baseParams and params are forwarded to pystate.NewBase to create the
// underlying Python instance. mlParams supplies the pymlstate specific
// parameters and is copied into the returned State. New returns an error when
// mlParams is nil or when pystate.NewBase fails.
func New(baseParams *pystate.BaseParams, mlParams *MLParams, params data.Map) (*State, error) {
	// Validate before creating the Python instance so that nothing is left
	// behind when the arguments are unusable.
	if mlParams == nil {
		return nil, errors.New("mlParams must not be nil")
	}

	b, err := pystate.NewBase(baseParams, params)
	if err != nil {
		return nil, err
	}

	// make panics on a negative capacity. Write treats every batch size
	// smaller than 2 as "train on each tuple", so no preallocation is needed
	// for such values.
	capacity := mlParams.BatchSize
	if capacity < 0 {
		capacity = 0
	}

	return &State{
		base:   b,
		params: *mlParams,
		bucket: make([]data.Value, 0, capacity),
	}, nil
}
// Terminate shuts down the underlying pystate instance while holding the
// write lock. The bucket is released only when the termination succeeded.
func (s *State) Terminate(ctx *core.Context) error {
	s.rwm.Lock()
	defer s.rwm.Unlock()

	err := s.base.Terminate(ctx)
	if err == nil {
		// s.base is intentionally kept (not set to nil) because it's used
		// for the termination detection.
		s.bucket = nil
	}
	return err
}
// Write trains the model with the "data" field of the given tuple.
//
// When batch_train_size is greater than 1, the value is appended to the
// bucket and "fit" is called only once the bucket holds batch_train_size
// values. Otherwise "fit" is called immediately: an array value is passed as
// the batch itself and any other value is passed as a single-element batch.
//
// The bucket is emptied after every call to "fit", whether it succeeded or
// not. Write returns an error when the state has been terminated, when the
// tuple doesn't have the "data" field, or when "fit" fails.
func (s *State) Write(ctx *core.Context, t *core.Tuple) error {
	s.rwm.Lock()
	defer s.rwm.Unlock()

	if err := s.base.CheckTermination(); err != nil {
		return err
	}

	dataSet, err := t.Data.Get(datPath)
	if err != nil {
		return err
	}

	// batch is what gets passed to "fit". In the non-batch mode it is kept in
	// a local variable rather than in s.bucket: storing the tuple's own array
	// in s.bucket would let a later append (e.g. after Load switched the
	// state to batch mode) overwrite the backing storage of the tuple's data.
	var batch []data.Value
	switch {
	case s.params.BatchSize > 1:
		s.bucket = append(s.bucket, dataSet)
		if len(s.bucket) < s.params.BatchSize {
			return nil
		}
		batch = s.bucket
	case dataSet.Type() == data.TypeArray:
		arr, err := data.AsArray(dataSet)
		if err != nil {
			return err
		}
		batch = arr
	default:
		batch = []data.Value{dataSet}
	}

	trained := len(batch)
	_, err = s.fit(ctx, batch)
	s.bucket = s.bucket[:0] // clear slice but keep capacity
	if err != nil {
		ctx.ErrLog(err).WithField("bucket_size", trained).
			Error("pymlstate's training via Write (INSERT INTO) failed")
		return err
	}
	return nil
}
// Fit passes bucket to the `fit` method of the Python instance while holding
// the read lock and returns whatever that call returns. Although bucket is
// typed as a slice of data.Value, it is assumed that every element is a
// data.Map.
func (s *State) Fit(ctx *core.Context, bucket []data.Value) (data.Value, error) {
	s.rwm.RLock()
	defer s.rwm.RUnlock()

	result, err := s.fit(ctx, bucket)
	return result, err
}
// fit is the lock-free core of Fit: it wraps bucket in a data.Array and calls
// the Python "fit" function through s.base. It neither acquires s.rwm nor
// performs a termination check by itself, so callers are responsible for
// both.
//
// Holding a read lock is sufficient for callers because this method doesn't
// modify any field of State. According to the original author's note, the
// model that is updated by the data is protected by Python's GIL, so a write
// lock isn't required.
func (s *State) fit(ctx *core.Context, bucket []data.Value) (data.Value, error) {
	batch := data.Array(bucket)
	return s.base.Call("fit", batch)
}
// Predict calls the Python "predict" function with dt while holding the read
// lock and returns the value and the error of that call unchanged.
func (s *State) Predict(ctx *core.Context, dt data.Value) (data.Value, error) {
	s.rwm.RLock()
	defer s.rwm.RUnlock()

	prediction, err := s.base.Call("predict", dt)
	return prediction, err
}
// Save writes the state to w. The container header and MLParams are written
// first by saveState, and then the saving of the model is delegated to the
// underlying pystate instance. Save fails without writing anything when the
// state has already been terminated.
func (s *State) Save(ctx *core.Context, w io.Writer, params data.Map) error {
	s.rwm.RLock()
	defer s.rwm.RUnlock()

	err := s.base.CheckTermination()
	if err == nil {
		err = s.saveState(w)
	}
	if err != nil {
		return err
	}
	return s.base.Save(ctx, w, params)
}
const (
	// pyMLStateFormatVersion is the version byte that saveState writes at the
	// head of the saved data.
	pyMLStateFormatVersion uint8 = 1
)
// saveState writes the pymlstate specific part of the saved data to w in the
// following order:
//
//  1. the format version (1 byte)
//  2. the size of the encoded MLParams (uint32, little endian)
//  3. MLParams encoded in msgpack
//
// The model of the Python instance is not written by this method.
func (s *State) saveState(w io.Writer) error {
	version := []byte{pyMLStateFormatVersion}
	if _, err := w.Write(version); err != nil {
		return err
	}

	// Encode the parameters of State before the Python model is saved.
	var encoded []byte
	encoder := codec.NewEncoderBytes(&encoded, &codec.MsgpackHandle{})
	if err := encoder.Encode(&s.params); err != nil {
		return err
	}

	// The size comes first so that a reader knows how many bytes to consume.
	if err := binary.Write(w, binary.LittleEndian, uint32(len(encoded))); err != nil {
		return err
	}

	written, err := w.Write(encoded)
	switch {
	case err != nil:
		return err
	case written < len(encoded):
		return errors.New("cannot save the MLParams data")
	}
	return nil
}
// Load restores the state from r while holding the write lock. It returns an
// error without reading anything when the state has already been terminated;
// otherwise the actual work is done by load.
func (s *State) Load(ctx *core.Context, r io.Reader, params data.Map) error {
	s.rwm.Lock()
	defer s.rwm.Unlock()

	if terminated := s.base.CheckTermination(); terminated != nil {
		return terminated
	}
	return s.load(ctx, r, params)
}
// load reads the one byte format version from r and dispatches to the loader
// of that version. Only version 1 is supported; any other version results in
// an error.
func (s *State) load(ctx *core.Context, r io.Reader, params data.Map) error {
	var version uint8
	if err := binary.Read(r, binary.LittleEndian, &version); err != nil {
		return err
	}

	// TODO: remove MLParams specific parameters from params
	if version == 1 {
		return s.loadMLParamsAndDataV1(ctx, r, params)
	}
	return fmt.Errorf("unsupported format version of State container: %v", version)
}
// loadMLParamsAndDataV1 reads the version 1 payload from r: the size of the
// encoded MLParams (uint32, little endian), MLParams encoded in msgpack, and
// then the data consumed by pystate to restore the Python model.
//
// s.params is replaced with the loaded MLParams only after the model has been
// loaded successfully. An error is returned when the size is 0, when r
// doesn't provide as many bytes as the size announces, when MLParams cannot
// be decoded, or when pystate fails to load the model.
func (s *State) loadMLParamsAndDataV1(ctx *core.Context, r io.Reader, params data.Map) error {
	var dataSize uint32
	if err := binary.Read(r, binary.LittleEndian, &dataSize); err != nil {
		return err
	}
	if dataSize == 0 {
		return errors.New("size of MLParams must be greater than 0")
	}

	// Read MLParams from reader. A single Read call is allowed to return
	// fewer bytes than requested without reporting an error, so io.ReadFull
	// is used to keep reading until the buffer is filled.
	buf := make([]byte, dataSize)
	if _, err := io.ReadFull(r, buf); err != nil {
		if err == io.ErrUnexpectedEOF {
			// The input ended in the middle of MLParams.
			return errors.New("read size is different from the size of MLParams")
		}
		return err
	}

	// Deserialize MLParams
	var saved MLParams
	dec := codec.NewDecoderBytes(buf, &codec.MsgpackHandle{})
	if err := dec.Decode(&saved); err != nil {
		return err
	}

	if s.base == nil { // loading for the first time
		var err error
		if s.base, err = pystate.LoadBase(ctx, r, params); err != nil {
			return err
		}
	} else if err := s.base.Load(ctx, r, params); err != nil {
		return err
	}

	s.params = saved
	return nil
}
// Fit looks up the State registered as stateName and trains its model with
// all values in bucket at once. What is returned on success is decided by the
// implementation of the Python UDS. An error is returned when the lookup
// fails.
func Fit(ctx *core.Context, stateName string, bucket []data.Value) (data.Value, error) {
	state, err := lookupState(ctx, stateName)
	if err == nil {
		return state.Fit(ctx, bucket)
	}
	return nil, err
}
// Predict looks up the State registered as stateName and applies its model to
// dt. The format of the returned value is decided by each Python UDS. An
// error is returned when the lookup fails.
func Predict(ctx *core.Context, stateName string, dt data.Value) (data.Value, error) {
	state, err := lookupState(ctx, stateName)
	if err == nil {
		return state.Predict(ctx, dt)
	}
	return nil, err
}
// Flush discards all values pending in the bucket of the State registered as
// stateName without training on them. The capacity of the bucket is kept.
//
// The returned data.Value is always nil. An error is returned only when the
// state cannot be looked up.
func Flush(ctx *core.Context, stateName string) (data.Value, error) {
	s, err := lookupState(ctx, stateName)
	if err != nil {
		return nil, err
	}

	// The bucket is also modified by Write, which holds the write lock while
	// doing so. The same lock has to be taken here to avoid a data race.
	s.rwm.Lock()
	defer s.rwm.Unlock()
	s.bucket = s.bucket[:0]
	return nil, nil
}
// lookupState fetches the shared state registered as stateName from ctx and
// returns it as a *State. It fails when no such state can be obtained or when
// the state found is of another type.
func lookupState(ctx *core.Context, stateName string) (*State, error) {
	shared, err := ctx.SharedStates.Get(stateName)
	if err != nil {
		return nil, err
	}

	state, ok := shared.(*State)
	if !ok {
		return nil, fmt.Errorf("state '%v' isn't a State", stateName)
	}
	return state, nil
}