-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathroute.go
172 lines (162 loc) · 4.33 KB
/
route.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
package semanticrouter
import (
"context"
"fmt"
"golang.org/x/sync/errgroup"
"gonum.org/v1/gonum/mat"
)
// Router represents a semantic router.
//
// A Router holds the candidate Routes, the Encoder used to turn utterances
// into vectors, and the Store that embeddings are read from and written to.
//
// Match can be called on a Router to find the best route for a given utterance.
type Router struct {
	// Routes is the slice of candidate routes that Match iterates over.
	Routes []Route
	// Encoder encodes utterances into vectors.
	Encoder Encoder
	// Storage is the Store holding the utterances; Match looks embeddings up
	// in it by utterance text.
	Storage Store
	// biFuncCoeffs pairs each scoring function with the coefficient its
	// result is multiplied by when computeScore sums the overall score.
	biFuncCoeffs []biFuncCoefficient
	// workers is the limit handed to the errgroup that runs the scoring
	// functions in computeScore.
	workers int
}
// WithWorkers returns an Option that records, on the Router it is applied
// to, how many workers may be used when computing similarity scores.
func WithWorkers(workers int) Option {
	applyWorkers := func(router *Router) {
		router.workers = workers
	}
	return applyWorkers
}
// Route represents a route in the semantic router.
//
// A route is identified by its Name and described by the example Utterances
// that belong to it.
type Route struct {
	// Name is the name of the route.
	Name string
	// Utterances is the slice of Utterances associated with this route.
	Utterances []Utterance
}
// biFuncCoefficient is a struct that pairs a scoring function with its
// coefficient.
type biFuncCoefficient struct {
	// handler is invoked with the query vector and an index vector and
	// yields a score (or an error).
	handler handler
	// coefficient is the weight the handler's score is multiplied by.
	coefficient float64
}
// NewRouter creates a new semantic router.
//
// routes are the candidate routes, encoder turns utterances into embeddings,
// and store persists those embeddings. opts configure the scoring functions
// and the worker limit; when no option is given, every built-in scoring
// function is enabled with a coefficient of 1.0 and a single worker is used.
// When options are given the defaults are NOT applied, so the caller must
// supply at least one scoring option for Match to produce non-zero scores.
//
// Every utterance of every route that cannot be fetched from store is encoded
// and stored. Any error from store.Get is treated as "not stored yet".
//
// The returned router is the one the options were applied to. The routes,
// encoder and store arguments are assigned after the options have run, so
// they always take precedence over anything an option may have set.
//
// An error is returned if encoding or storing any utterance fails; the
// returned router is nil in that case.
func NewRouter(
	routes []Route,
	encoder Encoder,
	store Store,
	opts ...Option,
) (router *Router, err error) {
	if len(opts) == 0 {
		opts = []Option{
			WithSimilarityDotMatrix(1.0),
			WithEuclideanDistance(1.0),
			WithManhattanDistance(1.0),
			WithJaccardSimilarity(1.0),
			WithPearsonCorrelation(1.0),
			WithWorkers(1),
		}
	}

	router = &Router{}
	for _, opt := range opts {
		opt(router)
	}
	router.Routes = routes
	router.Encoder = encoder
	router.Storage = store
	// A zero worker count means the caller supplied options but no
	// WithWorkers. An errgroup limit of zero would block every goroutine
	// forever, so fall back to the same single worker the defaults use.
	if router.workers == 0 {
		router.workers = 1
	}

	ctx := context.Background()
	for i := range routes {
		for _, utter := range routes[i].Utterances {
			// Already stored: nothing to encode.
			if _, getErr := store.Get(ctx, utter.Utterance); getErr == nil {
				continue
			}
			embedding, encErr := encoder.Encode(ctx, utter.Utterance)
			if encErr != nil {
				return nil, fmt.Errorf("error encoding utterance: %w", encErr)
			}
			// utter is a copy of the slice element; the embedding is
			// persisted through the store, not written back into routes.
			utter.Embed = embedding
			if setErr := store.Set(ctx, utter); setErr != nil {
				return nil, fmt.Errorf(
					"error storing utterance: %s: %w",
					utter.Utterance,
					setErr,
				)
			}
		}
	}
	return router, nil
}
// Match returns the route that best matches the given utterance.
//
// The utterance is encoded with the router's Encoder and compared against the
// stored embedding of every utterance of every route. bestScore is the highest
// score produced by computeScore, and bestRoute points to a copy of the route
// that produced it.
//
// Only scores strictly greater than zero can win: when no utterance scores
// above zero, bestRoute is nil and bestScore is 0. Stored embeddings whose
// length differs from the query embedding are skipped.
//
// Errors:
//   - ErrEncoding if the encoder fails or returns an empty embedding.
//   - ErrGetEmbedding if a stored embedding cannot be fetched.
//   - any error returned by a scoring function.
//
// ctx is forwarded to the encoder and the store; a cancellation surfaces
// through whichever of those calls reports it, with the cause included in the
// returned error's message.
func (r *Router) Match(
	ctx context.Context,
	utterance string,
) (bestRoute *Route, bestScore float64, err error) {
	encoding, err := r.Encoder.Encode(ctx, utterance)
	if err != nil {
		return nil, 0.0, ErrEncoding{
			Message: fmt.Sprintf(
				"error encoding utterance: %s: %v",
				utterance,
				err,
			),
		}
	}
	// mat.NewVecDense panics on a zero length, so reject an empty
	// embedding explicitly instead of crashing the caller.
	if len(encoding) == 0 {
		return nil, 0.0, ErrEncoding{
			Message: fmt.Sprintf(
				"error encoding utterance: %s: empty embedding",
				utterance,
			),
		}
	}
	queryVec := mat.NewVecDense(len(encoding), encoding)

	// Track the winner by index rather than by the address of the range
	// variable: before Go 1.22 that variable is shared by all iterations, so
	// a pointer to it would end up referring to the last route visited.
	bestIdx := -1
	for i := range r.Routes {
		for _, ut := range r.Routes[i].Utterances {
			em, getErr := r.Storage.Get(ctx, ut.Utterance)
			if getErr != nil {
				return nil, 0.0, ErrGetEmbedding{
					Message: fmt.Sprintf(
						"error getting embedding: %s: %v",
						ut.Utterance,
						getErr,
					),
				}
			}
			// Vectors of different dimension cannot be compared.
			if len(em) != queryVec.Len() {
				continue
			}
			indexVec := mat.NewVecDense(len(em), em)
			simScore, scoreErr := r.computeScore(queryVec, indexVec)
			if scoreErr != nil {
				return nil, 0.0, scoreErr
			}
			if simScore > bestScore {
				bestScore = simScore
				bestIdx = i
			}
		}
	}
	if bestIdx < 0 {
		return nil, bestScore, nil
	}
	matched := r.Routes[bestIdx]
	return &matched, bestScore, nil
}
// computeScore computes the weighted score between a query vector and an
// index vector.
//
// Every entry of the router's biFuncCoeffs is evaluated on the two vectors,
// its result is multiplied by the entry's coefficient, and the products are
// summed. The evaluations run in an errgroup limited to r.workers concurrent
// goroutines.
//
// It returns the summed score, or 0 and the first error reported by any
// scoring function.
func (r *Router) computeScore(
	queryVec *mat.VecDense,
	indexVec *mat.VecDense,
) (float64, error) {
	var eg errgroup.Group
	// errgroup treats a limit of zero as "no goroutine may start", which
	// would make eg.Go block forever; leave the group unlimited instead.
	if r.workers != 0 {
		eg.SetLimit(r.workers)
	}

	// Each goroutine writes only its own slot, so no locking is needed and
	// the final summation order is deterministic.
	partials := make([]float64, len(r.biFuncCoeffs))
	for i, fn := range r.biFuncCoeffs {
		i, fn := i, fn // per-iteration copies for pre-Go 1.22 loop semantics
		eg.Go(func() error {
			interScore, err := fn.handler(queryVec, indexVec)
			if err != nil {
				return err
			}
			partials[i] = fn.coefficient * interScore
			return nil
		})
	}
	// Wait before reading any result: the goroutines must have finished
	// for the partial scores to be complete and visible.
	if err := eg.Wait(); err != nil {
		return 0, err
	}

	score := 0.0
	for _, partial := range partials {
		score += partial
	}
	return score, nil
}