-
Notifications
You must be signed in to change notification settings - Fork 3
/
program.go
98 lines (88 loc) · 2.25 KB
/
program.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
package xgp
import (
"encoding/json"
"errors"
"github.com/MaxHalford/xgp/metrics"
"github.com/MaxHalford/xgp/op"
"github.com/gonum/floats"
)
// A Program is a thin layer on top of an Operator.
//
// The embedded *GP supplies configuration such as LossMetric. classification
// consults it to decide whether Predict transforms the raw operator output.
// A nil GP (or a nil LossMetric) makes Predict return the raw output
// unchanged, i.e. the Program behaves as a regressor.
type Program struct {
	*GP
	// Op is the operator that Predict evaluates and String formats.
	Op op.Operator
}
// String returns the textual representation of the Program, which is simply
// the string form of its Operator.
func (prog Program) String() string {
	operator := prog.Op
	return operator.String()
}
// classification reports whether the Program performs classification rather
// than regression. It delegates to the GP's LossMetric and reports false when
// either the GP or its LossMetric is nil.
func (prog Program) classification() bool {
	gp := prog.GP
	return gp != nil && gp.LossMetric != nil && gp.LossMetric.Classification()
}
// Predict evaluates the Program's Operator on X and returns the predictions.
//
// An error is returned if the raw operator output contains any NaN.
//
// When the Program is not doing classification (see classification), the raw
// operator output is returned as is. Otherwise each value is passed through
// sigmoid when proba is true, or through binary when proba is false.
//
// The transformation is applied in place to the slice returned by Op.Eval.
func (prog Program) Predict(X [][]float64, proba bool) ([]float64, error) {
	yPred := prog.Op.Eval(X)

	// Reject NaNs before any transformation so callers never receive them.
	if floats.HasNaN(yPred) {
		return nil, errors.New("yPred contains NaNs")
	}

	// Regression: return the raw output.
	if !prog.classification() {
		return yPred, nil
	}

	// Classification: pick the output transformation once with a plain
	// branch (no per-call map allocation), then apply it element-wise.
	var transform func(float64) float64
	if proba {
		transform = sigmoid
	} else {
		transform = binary
	}
	for i, y := range yPred {
		yPred[i] = transform(y)
	}
	return yPred, nil
}
// PredictPartial is a convenience function on top of Predict to make a
// prediction for a single instance x.
//
// Each feature value x[i] is wrapped in its own one-element slice, so the
// matrix handed to Predict has len(x) entries of length one.
// NOTE(review): this presumes Predict/Op.Eval index X by feature first —
// confirm against op.Operator implementations.
//
// Any error from Predict is returned unchanged. An error is also returned if
// Predict produces no values, instead of panicking on the index below.
func (prog Program) PredictPartial(x []float64, proba bool) (float64, error) {
	X := make([][]float64, len(x))
	for i, xi := range x {
		X[i] = []float64{xi}
	}
	yPred, err := prog.Predict(X, proba)
	if err != nil {
		return 0, err
	}
	// Guard the index: an empty result would otherwise panic.
	if len(yPred) == 0 {
		return 0, errors.New("no prediction was produced")
	}
	return yPred[0], nil
}
// serialProgram is the JSON representation of a Program written by
// MarshalJSON and read back by UnmarshalJSON.
type serialProgram struct {
	// Op is the serialized form of the Program's Operator (op.SerializeOp).
	Op op.SerialOp `json:"op"`
	// LossMetric is the String() form of the GP's loss metric; it is parsed
	// back with metrics.ParseMetric.
	LossMetric string `json:"loss_metric"`
}
// MarshalJSON serializes a Program as a JSON object holding its serialized
// Operator ("op") and the name of its GP's loss metric ("loss_metric").
//
// The loss metric name is required: UnmarshalJSON parses it with
// metrics.ParseMetric and fails if that does not succeed. A Program whose GP
// or GP.LossMetric is nil therefore cannot be serialized meaningfully. In
// that case an error is returned rather than panicking on a nil dereference.
func (prog Program) MarshalJSON() ([]byte, error) {
	if prog.GP == nil || prog.GP.LossMetric == nil {
		return nil, errors.New("cannot serialize a Program without a loss metric")
	}
	return json.Marshal(&serialProgram{
		Op:         op.SerializeOp(prog.Op),
		LossMetric: prog.GP.LossMetric.String(),
	})
}
// UnmarshalJSON rebuilds a Program from the JSON form written by MarshalJSON.
//
// Decoding happens in three steps:
//  1. the raw JSON is decoded,
//  2. the loss metric name is parsed with metrics.ParseMetric,
//  3. the operator is parsed with op.ParseOp.
//
// The first error encountered is returned as is, and prog is left untouched.
// On success prog.Op is replaced and prog.GP is set to a new GP that holds
// only the parsed loss metric; any previously attached GP is not modified.
func (prog *Program) UnmarshalJSON(bytes []byte) error {
	var sp serialProgram
	if err := json.Unmarshal(bytes, &sp); err != nil {
		return err
	}

	// NOTE(review): the meaning of ParseMetric's second argument (1) is not
	// visible here — confirm against the metrics package.
	metric, err := metrics.ParseMetric(sp.LossMetric, 1)
	if err != nil {
		return err
	}

	parsedOp, err := op.ParseOp(sp.Op)
	if err != nil {
		return err
	}

	prog.GP = &GP{LossMetric: metric}
	prog.Op = parsedOp
	return nil
}