Skip to content

Commit 05a6069

Browse files
[Go] add localvec to correspond to TypeScript dev-local-vectorstore (#124)
Add a genkit.DebugLog function so that plugins can log debug info.
1 parent 0aeb66d commit 05a6069

File tree

3 files changed

+353
-0
lines changed

3 files changed

+353
-0
lines changed

go/genkit/logging.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,8 @@ func logger(ctx context.Context) *slog.Logger {
3939
}
4040
return slog.Default()
4141
}
42+
43+
// DebugLog is a helper function for plugins to log debugging info.
44+
func DebugLog(ctx context.Context, msg string, args ...any) {
45+
logger(ctx).Debug(msg, args...)
46+
}

go/plugins/localvec/localvec.go

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
// Package localvec is a local vector database for development and testing.
16+
// The database is stored in a file in the local file system.
17+
// Production code should use a real vector database.
18+
package localvec
19+
20+
import (
21+
"cmp"
22+
"context"
23+
"crypto/md5"
24+
"encoding/json"
25+
"errors"
26+
"fmt"
27+
"io/fs"
28+
"math"
29+
"os"
30+
"path/filepath"
31+
"slices"
32+
33+
"github.com/google/genkit/go/ai"
34+
"github.com/google/genkit/go/genkit"
35+
)
36+
37+
// New returns a new local vector database. This will register a new
38+
// retriever with genkit, and also return it.
39+
// This retriever may only be used by a single goroutine at a time.
40+
// This is based on js/plugins/dev-local-vectorstore/src/index.ts.
41+
func New(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOptions any) (ai.Retriever, error) {
42+
r, err := newRetriever(ctx, dir, name, embedder, embedderOptions)
43+
if err != nil {
44+
return nil, err
45+
}
46+
ai.RegisterRetriever("devLocalVectorStore/"+name, r)
47+
return r, nil
48+
}
49+
50+
// retriever implements the [ai.Retriever] interface
51+
// for a local vector database.
52+
type retriever struct {
53+
filename string
54+
embedder ai.Embedder
55+
embedderOptions any
56+
data map[string]dbValue
57+
}
58+
59+
// dbValue is the type of a document stored in the database.
60+
type dbValue struct {
61+
Doc *ai.Document `json:"doc"`
62+
Embedding []float32 `json:"embedding"`
63+
}
64+
65+
// newRetriever returns a new ai.Retriever to register.
66+
func newRetriever(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOptions any) (ai.Retriever, error) {
67+
if err := os.MkdirAll(dir, 0o755); err != nil {
68+
return nil, err
69+
}
70+
dbname := "__db_" + name + ".json"
71+
filename := filepath.Join(dir, dbname)
72+
f, err := os.Open(filename)
73+
var data map[string]dbValue
74+
if err != nil {
75+
if !errors.Is(err, fs.ErrNotExist) {
76+
return nil, err
77+
}
78+
} else {
79+
defer f.Close()
80+
decoder := json.NewDecoder(f)
81+
if err := decoder.Decode(&data); err != nil {
82+
return nil, err
83+
}
84+
}
85+
86+
r := &retriever{
87+
filename: filename,
88+
embedder: embedder,
89+
embedderOptions: embedderOptions,
90+
data: data,
91+
}
92+
return r, nil
93+
}
94+
95+
// Index implements the genkit [ai.Retriever.Index] method.
96+
func (r *retriever) Index(ctx context.Context, req *ai.IndexerRequest) error {
97+
for _, doc := range req.Documents {
98+
ereq := &ai.EmbedRequest{
99+
Document: doc,
100+
Options: r.embedderOptions,
101+
}
102+
vals, err := r.embedder.Embed(ctx, ereq)
103+
if err != nil {
104+
return fmt.Errorf("localvec index embedding failed: %v", err)
105+
}
106+
107+
id, err := docID(doc)
108+
if err != nil {
109+
return err
110+
}
111+
112+
if _, ok := r.data[id]; ok {
113+
genkit.DebugLog(ctx, "localvec skipping document because already present", "id", id)
114+
continue
115+
}
116+
117+
if r.data == nil {
118+
r.data = make(map[string]dbValue)
119+
}
120+
121+
r.data[id] = dbValue{
122+
Doc: doc,
123+
Embedding: vals,
124+
}
125+
}
126+
127+
// Update the file every time we add documents.
128+
tmpname := r.filename + ".tmp"
129+
f, err := os.Create(tmpname)
130+
if err != nil {
131+
return err
132+
}
133+
encoder := json.NewEncoder(f)
134+
if err := encoder.Encode(r.data); err != nil {
135+
return err
136+
}
137+
if err := f.Close(); err != nil {
138+
return err
139+
}
140+
141+
return nil
142+
}
143+
144+
// RetrieverOptions holds the options understood by this retriever.
// A *RetrieverOptions may be supplied in the Options field of
// [ai.RetrieverRequest]; that field should be either nil or a value
// of type *RetrieverOptions.
type RetrieverOptions struct {
	// K is the number of entries to return.
	K int `json:"k,omitempty"`
}
150+
151+
// Retrieve implements the genkit [ai.Retriever.Retrieve] method.
152+
func (r *retriever) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
153+
// Use the embedder to convert the document we want to
154+
// retrieve into a vector.
155+
ereq := &ai.EmbedRequest{
156+
Document: req.Document,
157+
Options: r.embedderOptions,
158+
}
159+
vals, err := r.embedder.Embed(ctx, ereq)
160+
if err != nil {
161+
return nil, fmt.Errorf("localvec retrieve embedding failed: %v", err)
162+
}
163+
164+
type scoredDoc struct {
165+
score float64
166+
doc *ai.Document
167+
}
168+
scoredDocs := make([]scoredDoc, 0, len(r.data))
169+
for _, dbv := range r.data {
170+
score := similarity(vals, dbv.Embedding)
171+
scoredDocs = append(scoredDocs, scoredDoc{
172+
score: score,
173+
doc: dbv.Doc,
174+
})
175+
}
176+
177+
slices.SortFunc(scoredDocs, func(a, b scoredDoc) int {
178+
// We want to sort by descending score,
179+
// so pass b.score first to reverse the default ordering.
180+
return cmp.Compare(b.score, a.score)
181+
})
182+
183+
k := 3
184+
if options, _ := req.Options.(*RetrieverOptions); options != nil {
185+
k = options.K
186+
}
187+
k = min(k, len(scoredDocs))
188+
189+
docs := make([]*ai.Document, 0, k)
190+
for i := 0; i < k; i++ {
191+
docs = append(docs, scoredDocs[i].doc)
192+
}
193+
194+
resp := &ai.RetrieverResponse{
195+
Documents: docs,
196+
}
197+
return resp, nil
198+
}
199+
200+
// similarity computes the [cosine similarity] between two vectors.
//
// If the vectors have different lengths, only their common prefix is
// compared, so a length mismatch never causes an out-of-range panic.
// If either vector has zero magnitude over the compared elements
// (including the case of empty input), the cosine is undefined and
// similarity returns 0 rather than NaN.
//
// [cosine similarity]: https://en.wikipedia.org/wiki/Cosine_similarity
func similarity(vals1, vals2 []float32) float64 {
	// l2norm folds one more element v into a running Euclidean norm
	// kept in scaled form: the norm so far is t*sqrt(s), where t is the
	// largest magnitude seen and s is the sum of squares of the elements
	// divided by t. Working with ratios keeps the intermediate values
	// bounded.
	l2norm := func(v float64, s, t float64) (float64, float64) {
		if v == 0 {
			return s, t
		}
		a := math.Abs(v)
		if a > t {
			// New largest magnitude: rescale the accumulated sum.
			r := t / a
			s = 1 + s*r*r
			t = a
		} else {
			r := a / t
			s += r * r
		}
		return s, t
	}

	// Restrict the comparison to the elements both vectors have.
	if len(vals2) < len(vals1) {
		vals1 = vals1[:len(vals2)]
	}

	dot := float64(0)
	s1 := float64(1)
	t1 := float64(0)
	s2 := float64(1)
	t2 := float64(0)

	for i, v1f := range vals1 {
		v1 := float64(v1f)
		v2 := float64(vals2[i])
		dot += v1 * v2
		s1, t1 = l2norm(v1, s1, t1)
		s2, t2 = l2norm(v2, s2, t2)
	}

	l1 := t1 * math.Sqrt(s1)
	l2 := t2 * math.Sqrt(s2)

	// A zero-magnitude vector would make this 0/0.
	if l1 == 0 || l2 == 0 {
		return 0
	}

	return dot / (l1 * l2)
}
239+
240+
// docID returns the ID to use for a Document.
241+
// This is intended to be the same as the genkit Typescript computation.
242+
func docID(doc *ai.Document) (string, error) {
243+
b, err := json.Marshal(doc)
244+
if err != nil {
245+
return "", fmt.Errorf("localvec: error marshaling document: %v", err)
246+
}
247+
return fmt.Sprintf("%02x", md5.Sum(b)), nil
248+
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package localvec
16+
17+
import (
18+
"context"
19+
"math"
20+
"strings"
21+
"testing"
22+
23+
"github.com/google/genkit/go/ai"
24+
"github.com/google/genkit/go/internal/fakeembedder"
25+
)
26+
27+
func TestLocalVec(t *testing.T) {
28+
ctx := context.Background()
29+
30+
// Make two very similar vectors and one different vector.
31+
// Arrange for a fake embedder to return those vector
32+
// when provided with documents.
33+
34+
const dim = 32
35+
v1 := make([]float32, dim)
36+
v2 := make([]float32, dim)
37+
v3 := make([]float32, dim)
38+
for i := range v1 {
39+
v1[i] = float32(i)
40+
v2[i] = float32(i)
41+
v3[i] = float32(dim - i)
42+
}
43+
v2[0] = 1
44+
45+
d1 := ai.DocumentFromText("hello1", nil)
46+
d2 := ai.DocumentFromText("hello2", nil)
47+
d3 := ai.DocumentFromText("goodbye", nil)
48+
49+
embedder := fakeembedder.New()
50+
embedder.Register(d1, v1)
51+
embedder.Register(d2, v2)
52+
embedder.Register(d3, v3)
53+
54+
r, err := newRetriever(ctx, t.TempDir(), "testLocalVec", embedder, nil)
55+
if err != nil {
56+
t.Fatal(err)
57+
}
58+
59+
indexerReq := &ai.IndexerRequest{
60+
Documents: []*ai.Document{d1, d2, d3},
61+
}
62+
err = r.Index(ctx, indexerReq)
63+
if err != nil {
64+
t.Fatalf("Index operation failed: %v", err)
65+
}
66+
67+
retrieverOptions := &RetrieverOptions{
68+
K: 2,
69+
}
70+
71+
retrieverReq := &ai.RetrieverRequest{
72+
Document: d1,
73+
Options: retrieverOptions,
74+
}
75+
retrieverResp, err := r.Retrieve(ctx, retrieverReq)
76+
if err != nil {
77+
t.Fatalf("Retrieve operation failed: %v", err)
78+
}
79+
80+
docs := retrieverResp.Documents
81+
if len(docs) != 2 {
82+
t.Errorf("got %d results, expected 2", len(docs))
83+
}
84+
for _, d := range docs {
85+
text := d.Content[0].Text()
86+
if !strings.HasPrefix(text, "hello") {
87+
t.Errorf("returned doc text %q does not start with %q", text, "hello")
88+
}
89+
}
90+
}
91+
92+
func TestSimilarity(t *testing.T) {
93+
x := []float32{5, 23, 2, 5, 9}
94+
y := []float32{3, 21, 2, 5, 14}
95+
got := similarity(x, y)
96+
want := 0.975
97+
if math.Abs(got-want) > 0.001 {
98+
t.Errorf("got %f, want %f", got, want)
99+
}
100+
}

0 commit comments

Comments
 (0)