Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Go] add localvec to correspond to TypeScript dev-local-vectorstore #124

Merged
merged 2 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions go/genkit/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,8 @@ func logger(ctx context.Context) *slog.Logger {
}
return slog.Default()
}

// DebugLog is a helper function for plugins to log debugging info.
func DebugLog(ctx context.Context, msg string, args ...any) {
logger(ctx).Debug(msg, args...)
}
248 changes: 248 additions & 0 deletions go/plugins/localvec/localvec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package localvec is a local vector database for development and testing.
// The database is stored in a file in the local file system.
// Production code should use a real vector database.
package localvec

import (
"cmp"
"context"
"crypto/md5"
"encoding/json"
"errors"
"fmt"
"io/fs"
"math"
"os"
"path/filepath"
"slices"

"github.com/google/genkit/go/ai"
"github.com/google/genkit/go/genkit"
)

// New returns a new local vector database. This will register a new
// retriever with genkit, and also return it.
// This retriever may only be used by a single goroutine at a time.
ianlancetaylor marked this conversation as resolved.
Show resolved Hide resolved
// This is based on js/plugins/dev-local-vectorstore/src/index.ts.
func New(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOptions any) (ai.Retriever, error) {
r, err := newRetriever(ctx, dir, name, embedder, embedderOptions)
if err != nil {
return nil, err
}
ai.RegisterRetriever("devLocalVectorStore/"+name, r)
return r, nil
}

// retriever implements the [ai.Retriever] interface
// for a local vector database.
type retriever struct {
filename string
embedder ai.Embedder
embedderOptions any
data map[string]dbValue
}

// dbValue is the type of a document stored in the database.
type dbValue struct {
Doc *ai.Document `json:"doc"`
Embedding []float32 `json:"embedding"`
}

// newRetriever returns a new ai.Retriever to register.
func newRetriever(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOptions any) (ai.Retriever, error) {
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, err
}
dbname := "__db_" + name + ".json"
filename := filepath.Join(dir, dbname)
f, err := os.Open(filename)
var data map[string]dbValue
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
return nil, err
}
} else {
defer f.Close()
decoder := json.NewDecoder(f)
if err := decoder.Decode(&data); err != nil {
return nil, err
}
}

r := &retriever{
filename: filename,
embedder: embedder,
embedderOptions: embedderOptions,
data: data,
}
return r, nil
}

// Index implements the genkit [ai.Retriever.Index] method.
func (r *retriever) Index(ctx context.Context, req *ai.IndexerRequest) error {
for _, doc := range req.Documents {
ereq := &ai.EmbedRequest{
Document: doc,
Options: r.embedderOptions,
}
vals, err := r.embedder.Embed(ctx, ereq)
if err != nil {
return fmt.Errorf("localvec index embedding failed: %v", err)
}

id, err := docID(doc)
if err != nil {
return err
}

if _, ok := r.data[id]; ok {
genkit.DebugLog(ctx, "localvec skipping document because already present", "id", id)
continue
}

if r.data == nil {
r.data = make(map[string]dbValue)
}

r.data[id] = dbValue{
Doc: doc,
Embedding: vals,
}
}

// Update the file every time we add documents.
tmpname := r.filename + ".tmp"
f, err := os.Create(tmpname)
if err != nil {
return err
}
encoder := json.NewEncoder(f)
if err := encoder.Encode(r.data); err != nil {
return err
}
if err := f.Close(); err != nil {
return err
}

return nil
}

// RetrieverOptions may be passed in the Options field
// of [ai.RetrieverRequest] to pass options to the retriever.
// The Options field should be either nil or a value of type *RetrieverOptions.
type RetrieverOptions struct {
K int `json:"k,omitempty"` // number of entries to return
}

// Retrieve implements the genkit [ai.Retriever.Retrieve] method.
func (r *retriever) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
// Use the embedder to convert the document we want to
// retrieve into a vector.
ereq := &ai.EmbedRequest{
Document: req.Document,
Options: r.embedderOptions,
}
vals, err := r.embedder.Embed(ctx, ereq)
if err != nil {
return nil, fmt.Errorf("localvec retrieve embedding failed: %v", err)
}

type scoredDoc struct {
score float64
doc *ai.Document
}
scoredDocs := make([]scoredDoc, 0, len(r.data))
for _, dbv := range r.data {
score := similarity(vals, dbv.Embedding)
scoredDocs = append(scoredDocs, scoredDoc{
score: score,
doc: dbv.Doc,
})
}

slices.SortFunc(scoredDocs, func(a, b scoredDoc) int {
// We want to sort by descending score,
// so pass b.score first to reverse the default ordering.
return cmp.Compare(b.score, a.score)
})

k := 3
if options, _ := req.Options.(*RetrieverOptions); options != nil {
k = options.K
}
k = min(k, len(scoredDocs))

docs := make([]*ai.Document, 0, k)
for i := 0; i < k; i++ {
docs = append(docs, scoredDocs[i].doc)
}

resp := &ai.RetrieverResponse{
Documents: docs,
}
return resp, nil
}

// similarity computes the [cosine similarity] between two vectors.
//
// [cosine similarity]: https://en.wikipedia.org/wiki/Cosine_similarity
func similarity(vals1, vals2 []float32) float64 {
l2norm := func(v float64, s, t float64) (float64, float64) {
if v == 0 {
return s, t
}
a := math.Abs(v)
if a > t {
r := t / v
s = 1 + s*r*r
t = a
} else {
r := v / t
s = s + r*r
}
return s, t
}

dot := float64(0)
s1 := float64(1)
t1 := float64(0)
s2 := float64(1)
t2 := float64(0)

for i, v1f := range vals1 {
v1 := float64(v1f)
v2 := float64(vals2[i])
dot += v1 * v2
s1, t1 = l2norm(v1, s1, t1)
s2, t2 = l2norm(v2, s2, t2)
}

l1 := t1 * math.Sqrt(s1)
l2 := t2 * math.Sqrt(s2)

return dot / (l1 * l2)
}

// docID returns the ID to use for a Document.
// This is intended to be the same as the genkit Typescript computation.
func docID(doc *ai.Document) (string, error) {
b, err := json.Marshal(doc)
if err != nil {
return "", fmt.Errorf("localvec: error marshaling document: %v", err)
}
return fmt.Sprintf("%02x", md5.Sum(b)), nil
}
100 changes: 100 additions & 0 deletions go/plugins/localvec/localvec_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package localvec

import (
"context"
"math"
"strings"
"testing"

"github.com/google/genkit/go/ai"
"github.com/google/genkit/go/internal/fakeembedder"
)

func TestLocalVec(t *testing.T) {
ctx := context.Background()

// Make two very similar vectors and one different vector.
// Arrange for a fake embedder to return those vector
// when provided with documents.

const dim = 32
v1 := make([]float32, dim)
v2 := make([]float32, dim)
v3 := make([]float32, dim)
for i := range v1 {
v1[i] = float32(i)
v2[i] = float32(i)
v3[i] = float32(dim - i)
}
v2[0] = 1

d1 := ai.DocumentFromText("hello1", nil)
d2 := ai.DocumentFromText("hello2", nil)
d3 := ai.DocumentFromText("goodbye", nil)

embedder := fakeembedder.New()
embedder.Register(d1, v1)
embedder.Register(d2, v2)
embedder.Register(d3, v3)

r, err := newRetriever(ctx, t.TempDir(), "testLocalVec", embedder, nil)
if err != nil {
t.Fatal(err)
}

indexerReq := &ai.IndexerRequest{
Documents: []*ai.Document{d1, d2, d3},
}
err = r.Index(ctx, indexerReq)
if err != nil {
t.Fatalf("Index operation failed: %v", err)
}

retrieverOptions := &RetrieverOptions{
K: 2,
}

retrieverReq := &ai.RetrieverRequest{
Document: d1,
Options: retrieverOptions,
}
retrieverResp, err := r.Retrieve(ctx, retrieverReq)
if err != nil {
t.Fatalf("Retrieve operation failed: %v", err)
}

docs := retrieverResp.Documents
if len(docs) != 2 {
t.Errorf("got %d results, expected 2", len(docs))
}
for _, d := range docs {
text := d.Content[0].Text()
if !strings.HasPrefix(text, "hello") {
t.Errorf("returned doc text %q does not start with %q", text, "hello")
}
}
}

func TestSimilarity(t *testing.T) {
x := []float32{5, 23, 2, 5, 9}
y := []float32{3, 21, 2, 5, 14}
got := similarity(x, y)
want := 0.975
if math.Abs(got-want) > 0.001 {
t.Errorf("got %f, want %f", got, want)
}
}
Loading