From c6a6d95440cfa483bfaf3c10c82aa9327e0eec94 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 24 Jun 2024 06:21:47 -0400 Subject: [PATCH] [Go] ai.Embedder is a separate type from its action See #402 and #458. --- go/ai/embedder.go | 19 ++++++++++--------- go/internal/fakeembedder/fakeembedder.go | 2 +- go/internal/fakeembedder/fakeembedder_test.go | 6 +++--- go/plugins/googleai/googleai.go | 8 ++++---- go/plugins/googleai/googleai_test.go | 2 +- go/plugins/localvec/localvec.go | 10 +++++----- go/plugins/pinecone/genkit.go | 8 ++++---- go/plugins/vertexai/vertexai.go | 6 +++--- go/plugins/vertexai/vertexai_test.go | 2 +- 9 files changed, 32 insertions(+), 31 deletions(-) diff --git a/go/ai/embedder.go b/go/ai/embedder.go index fca80fa43..df612bd58 100644 --- a/go/ai/embedder.go +++ b/go/ai/embedder.go @@ -21,9 +21,9 @@ import ( "github.com/firebase/genkit/go/internal/atype" ) -// EmbedderAction is used to convert a document to a +// An Embedder is used to convert a document to a // multidimensional vector. -type EmbedderAction = core.Action[*EmbedRequest, []float32, struct{}] +type Embedder core.Action[*EmbedRequest, []float32, struct{}] // EmbedRequest is the data we pass to convert a document // to a multidimensional vector. @@ -34,21 +34,22 @@ type EmbedRequest struct { // DefineEmbedder registers the given embed function as an action, and returns an // [EmbedderAction] that runs it. -func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) ([]float32, error)) *EmbedderAction { - return core.DefineAction(provider, name, atype.Embedder, nil, embed) +func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) ([]float32, error)) *Embedder { + return (*Embedder)(core.DefineAction(provider, name, atype.Embedder, nil, embed)) } // LookupEmbedder looks up an [EmbedderAction] registered by [DefineEmbedder]. // It returns nil if the embedder was not defined. -func LookupEmbedder(provider, name string) *EmbedderAction { +func LookupEmbedder(provider, name string) *Embedder { action := core.LookupActionFor[*EmbedRequest, []float32, struct{}](atype.Embedder, provider, name) if action == nil { return nil } - return action + return (*Embedder)(action) } -// Embed runs the given [EmbedderAction]. -func Embed(ctx context.Context, emb *EmbedderAction, req *EmbedRequest) ([]float32, error) { - return emb.Run(ctx, req, nil) +// Embed runs the given [Embedder]. +func (e *Embedder) Embed(ctx context.Context, req *EmbedRequest) ([]float32, error) { + a := (*core.Action[*EmbedRequest, []float32, struct{}])(e) + return a.Run(ctx, req, nil) } diff --git a/go/internal/fakeembedder/fakeembedder.go b/go/internal/fakeembedder/fakeembedder.go index 47c2e06b6..083a5f887 100644 --- a/go/internal/fakeembedder/fakeembedder.go +++ b/go/internal/fakeembedder/fakeembedder.go @@ -26,7 +26,7 @@ import ( "github.com/firebase/genkit/go/ai" ) -// Embedder is a fake implementation of genkit.Embedder. +// Embedder is a fake implementation of an Embedder. type Embedder struct { registry map[*ai.Document][]float32 } diff --git a/go/internal/fakeembedder/fakeembedder_test.go b/go/internal/fakeembedder/fakeembedder_test.go index 85319054f..b157463dd 100644 --- a/go/internal/fakeembedder/fakeembedder_test.go +++ b/go/internal/fakeembedder/fakeembedder_test.go @@ -24,7 +24,7 @@ import ( func TestFakeEmbedder(t *testing.T) { embed := New() - embedAction := ai.DefineEmbedder("fake", "embed", embed.Embed) + emb := ai.DefineEmbedder("fake", "embed", embed.Embed) d := ai.DocumentFromText("fakeembedder test", nil) vals := []float32{1, 2} @@ -34,7 +34,7 @@ func TestFakeEmbedder(t *testing.T) { Document: d, } ctx := context.Background() - got, err := ai.Embed(ctx, embedAction, req) + got, err := emb.Embed(ctx, req) if err != nil { t.Fatal(err) } @@ -43,7 +43,7 @@ func TestFakeEmbedder(t *testing.T) { } req.Document = ai.DocumentFromText("missing document", nil) - if _, err = ai.Embed(ctx, embedAction, req); err == nil { + if _, err = emb.Embed(ctx, req); err == nil { t.Error("embedding unknown document succeeded unexpectedly") } } diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 04b5b30c2..c31114797 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -120,7 +120,7 @@ func defineModel(name string, caps ai.ModelCapabilities) *ai.Model { } // DefineEmbedder defines an embedder with a given name. -func DefineEmbedder(name string) *ai.EmbedderAction { +func DefineEmbedder(name string) *ai.Embedder { state.mu.Lock() defer state.mu.Unlock() if !state.initted { @@ -130,7 +130,7 @@ func DefineEmbedder(name string) *ai.EmbedderAction { } // requires state.mu -func defineEmbedder(name string) *ai.EmbedderAction { +func defineEmbedder(name string) *ai.Embedder { return ai.DefineEmbedder(provider, name, func(ctx context.Context, input *ai.EmbedRequest) ([]float32, error) { em := state.client.EmbeddingModel(name) parts, err := convertParts(input.Document.Content) @@ -151,9 +151,9 @@ func Model(name string) *ai.Model { return ai.LookupModel(provider, name) } -// Embedder returns the [ai.EmbedderAction] with the given name. +// Embedder returns the [ai.Embedder] with the given name. // It returns nil if the embedder was not configured. -func Embedder(name string) *ai.EmbedderAction { +func Embedder(name string) *ai.Embedder { return ai.LookupEmbedder(provider, name) } diff --git a/go/plugins/googleai/googleai_test.go b/go/plugins/googleai/googleai_test.go index 5998bd5cb..41bd03ed2 100644 --- a/go/plugins/googleai/googleai_test.go +++ b/go/plugins/googleai/googleai_test.go @@ -85,7 +85,7 @@ func TestLive(t *testing.T) { }, ) t.Run("embedder", func(t *testing.T) { - out, err := ai.Embed(ctx, embedder, &ai.EmbedRequest{ + out, err := embedder.Embed(ctx, &ai.EmbedRequest{ Document: ai.DocumentFromText("yellow banana", nil), }) if err != nil { diff --git a/go/plugins/localvec/localvec.go b/go/plugins/localvec/localvec.go index 75091ba90..ded43584f 100644 --- a/go/plugins/localvec/localvec.go +++ b/go/plugins/localvec/localvec.go @@ -39,7 +39,7 @@ const provider = "devLocalVectorStore" type Config struct { // Where to store the data. Defaults to os.TempDir. Dir string - Embedder *ai.EmbedderAction + Embedder *ai.Embedder EmbedderOptions any } @@ -73,7 +73,7 @@ func Retriever(name string) *ai.Retriever { // This is based on js/plugins/dev-local-vectorstore/src/index.ts. type docStore struct { filename string - embedder *ai.EmbedderAction + embedder *ai.Embedder embedderOptions any data map[string]dbValue } @@ -85,7 +85,7 @@ type dbValue struct { } // newDocStore returns a new ai.DocumentStore to register. -func newDocStore(dir, name string, embedder *ai.EmbedderAction, embedderOptions any) (*docStore, error) { +func newDocStore(dir, name string, embedder *ai.Embedder, embedderOptions any) (*docStore, error) { if dir == "" { dir = os.TempDir() } @@ -124,7 +124,7 @@ func (ds *docStore) index(ctx context.Context, req *ai.IndexerRequest) error { Document: doc, Options: ds.embedderOptions, } - vals, err := ai.Embed(ctx, ds.embedder, ereq) + vals, err := ds.embedder.Embed(ctx, ereq) if err != nil { return fmt.Errorf("localvec index embedding failed: %v", err) } @@ -186,7 +186,7 @@ func (ds *docStore) retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai Document: req.Document, Options: ds.embedderOptions, } - vals, err := ai.Embed(ctx, ds.embedder, ereq) + vals, err := ds.embedder.Embed(ctx, ereq) if err != nil { return nil, fmt.Errorf("localvec retrieve embedding failed: %v", err) } diff --git a/go/plugins/pinecone/genkit.go b/go/plugins/pinecone/genkit.go index 4ba5cb1df..ada31bbdd 100644 --- a/go/plugins/pinecone/genkit.go +++ b/go/plugins/pinecone/genkit.go @@ -75,7 +75,7 @@ type Config struct { // The index ID to use. IndexID string // Embedder to use. Required. - Embedder *ai.EmbedderAction + Embedder *ai.Embedder EmbedderOptions any // The metadata key to use to store document text // in Pinecone; the default is "_content". @@ -160,7 +160,7 @@ type RetrieverOptions struct { // docStore implements the genkit [ai.DocumentStore] interface. type docStore struct { index *index - embedder *ai.EmbedderAction + embedder *ai.Embedder embedderOptions any textKey string } @@ -190,7 +190,7 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error { Document: doc, Options: ds.embedderOptions, } - vals, err := ai.Embed(ctx, ds.embedder, ereq) + vals, err := ds.embedder.Embed(ctx, ereq) if err != nil { return fmt.Errorf("pinecone index embedding failed: %v", err) } @@ -285,7 +285,7 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai Document: req.Document, Options: ds.embedderOptions, } - vals, err := ai.Embed(ctx, ds.embedder, ereq) + vals, err := ds.embedder.Embed(ctx, ereq) if err != nil { return nil, fmt.Errorf("pinecone retrieve embedding failed: %v", err) } diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index d9e9026f4..cc08dc83d 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -89,7 +89,7 @@ func DefineModel(name string) *ai.Model { } // DefineModel defines an embedder with the given name. -func DefineEmbedder(name string) *ai.EmbedderAction { +func DefineEmbedder(name string) *ai.Embedder { state.mu.Lock() defer state.mu.Unlock() if !state.initted { @@ -107,9 +107,9 @@ func Model(name string) *ai.Model { return ai.LookupModel(provider, name) } -// Embedder returns the [ai.EmbedderAction] with the given name. +// Embedder returns the [ai.Embedder] with the given name. // It returns nil if the embedder was not configured. -func Embedder(name string) *ai.EmbedderAction { +func Embedder(name string) *ai.Embedder { return ai.LookupEmbedder(provider, name) } diff --git a/go/plugins/vertexai/vertexai_test.go b/go/plugins/vertexai/vertexai_test.go index c26f10606..c2900be4f 100644 --- a/go/plugins/vertexai/vertexai_test.go +++ b/go/plugins/vertexai/vertexai_test.go @@ -173,7 +173,7 @@ func TestLive(t *testing.T) { } }) t.Run("embedder", func(t *testing.T) { - out, err := ai.Embed(ctx, embedder, &ai.EmbedRequest{ + out, err := embedder.Embed(ctx, &ai.EmbedRequest{ Document: ai.DocumentFromText("time flies like an arrow", nil), }) if err != nil {