Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Go] replace Embedder interface with EmbedderAction #349

Merged
1 commit merged on Jun 7, 2024
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 13 additions & 18 deletions go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,9 @@ import (
"github.com/firebase/genkit/go/core"
)

// Embedder is the interface used to convert a document to a
// multidimensional vector. A [DocumentStore] will use a value of this type.
type Embedder interface {
Embed(context.Context, *EmbedRequest) ([]float32, error)
}
// EmbedderAction is used to convert a document to a
// multidimensional vector.
type EmbedderAction = core.Action[*EmbedRequest, []float32, struct{}]

// EmbedRequest is the data we pass to convert a document
// to a multidimensional vector.
Expand All @@ -34,25 +32,22 @@ type EmbedRequest struct {
}

// DefineEmbedder registers the given embed function as an action, and returns an
// [Embedder] whose Embed method runs it.
func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) ([]float32, error)) Embedder {
return embedder{core.DefineAction(provider, name, core.ActionTypeEmbedder, nil, embed)}
// [EmbedderAction] that runs it.
func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) ([]float32, error)) *EmbedderAction {
return core.DefineAction(provider, name, core.ActionTypeEmbedder, nil, embed)
}

// LookupEmbedder looks up an [Embedder] registered by [DefineEmbedder].
// It returns nil if the Embedder was not defined.
func LookupEmbedder(provider, name string) Embedder {
// LookupEmbedder looks up an [EmbedderAction] registered by [DefineEmbedder].
// It returns nil if the embedder was not defined.
func LookupEmbedder(provider, name string) *EmbedderAction {
action := core.LookupActionFor[*EmbedRequest, []float32, struct{}](core.ActionTypeEmbedder, provider, name)
if action == nil {
return nil
}
return embedder{action}
}

type embedder struct {
embedAction *core.Action[*EmbedRequest, []float32, struct{}]
return action
}

func (e embedder) Embed(ctx context.Context, req *EmbedRequest) ([]float32, error) {
return e.embedAction.Run(ctx, req, nil)
// Embed runs the given [EmbedderAction].
func Embed(ctx context.Context, emb *EmbedderAction, req *EmbedRequest) ([]float32, error) {
return emb.Run(ctx, req, nil)
}
1 change: 0 additions & 1 deletion go/internal/fakeembedder/fakeembedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ func (e *Embedder) Register(d *ai.Document, vals []float32) {
e.registry[d] = vals
}

// Embed implements genkit.Embedder.
func (e *Embedder) Embed(ctx context.Context, req *ai.EmbedRequest) ([]float32, error) {
vals, ok := e.registry[req.Document]
if !ok {
Expand Down
8 changes: 3 additions & 5 deletions go/internal/fakeembedder/fakeembedder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,17 @@ import (

func TestFakeEmbedder(t *testing.T) {
embed := New()

embedAction := ai.DefineEmbedder("fake", "embed", embed.Embed)
d := ai.DocumentFromText("fakeembedder test", nil)

vals := []float32{1, 2}
embed.Register(d, vals)

var genkitEmbedder ai.Embedder
genkitEmbedder = embed
req := &ai.EmbedRequest{
Document: d,
}
ctx := context.Background()
got, err := genkitEmbedder.Embed(ctx, req)
got, err := ai.Embed(ctx, embedAction, req)
if err != nil {
t.Fatal(err)
}
Expand All @@ -45,7 +43,7 @@ func TestFakeEmbedder(t *testing.T) {
}

req.Document = ai.DocumentFromText("missing document", nil)
if _, err = genkitEmbedder.Embed(ctx, req); err == nil {
if _, err = ai.Embed(ctx, embedAction, req); err == nil {
t.Error("embedding unknown document succeeded unexpectedly")
}
}
4 changes: 2 additions & 2 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ func Model(name string) *ai.ModelAction {
return ai.LookupModel(provider, name)
}

// Embedder returns the embedder with the given name.
// Embedder returns the [ai.EmbedderAction] with the given name.
// It returns nil if the embedder was not configured.
func Embedder(name string) ai.Embedder {
func Embedder(name string) *ai.EmbedderAction {
return ai.LookupEmbedder(provider, name)
}

Expand Down
2 changes: 1 addition & 1 deletion go/plugins/googleai/googleai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func TestLive(t *testing.T) {
},
)
t.Run("embedder", func(t *testing.T) {
out, err := googleai.Embedder(embeddingModel).Embed(ctx, &ai.EmbedRequest{
out, err := ai.Embed(ctx, googleai.Embedder(embeddingModel), &ai.EmbedRequest{
Document: ai.DocumentFromText("yellow banana", nil),
})
if err != nil {
Expand Down
10 changes: 5 additions & 5 deletions go/plugins/localvec/localvec.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import (
// retriever with genkit, and also return it.
// This retriever may only be used by a single goroutine at a time.
// This is based on js/plugins/dev-local-vectorstore/src/index.ts.
func New(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOptions any) (ai.DocumentStore, error) {
func New(ctx context.Context, dir, name string, embedder *ai.EmbedderAction, embedderOptions any) (ai.DocumentStore, error) {
r, err := newDocStore(ctx, dir, name, embedder, embedderOptions)
if err != nil {
return nil, err
Expand All @@ -50,7 +50,7 @@ func New(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOp
// for a local vector database.
type docStore struct {
filename string
embedder ai.Embedder
embedder *ai.EmbedderAction
embedderOptions any
data map[string]dbValue
}
Expand All @@ -62,7 +62,7 @@ type dbValue struct {
}

// newDocStore returns a new ai.DocumentStore to register.
func newDocStore(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOptions any) (ai.DocumentStore, error) {
func newDocStore(ctx context.Context, dir, name string, embedder *ai.EmbedderAction, embedderOptions any) (ai.DocumentStore, error) {
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, err
}
Expand Down Expand Up @@ -98,7 +98,7 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
Document: doc,
Options: ds.embedderOptions,
}
vals, err := ds.embedder.Embed(ctx, ereq)
vals, err := ai.Embed(ctx, ds.embedder, ereq)
if err != nil {
return fmt.Errorf("localvec index embedding failed: %v", err)
}
Expand Down Expand Up @@ -160,7 +160,7 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
Document: req.Document,
Options: ds.embedderOptions,
}
vals, err := ds.embedder.Embed(ctx, ereq)
vals, err := ai.Embed(ctx, ds.embedder, ereq)
if err != nil {
return nil, fmt.Errorf("localvec retrieve embedding failed: %v", err)
}
Expand Down
9 changes: 5 additions & 4 deletions go/plugins/localvec/localvec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ func TestLocalVec(t *testing.T) {
embedder.Register(d1, v1)
embedder.Register(d2, v2)
embedder.Register(d3, v3)

ds, err := newDocStore(ctx, t.TempDir(), "testLocalVec", embedder, nil)
embedAction := ai.DefineEmbedder("fake", "embedder1", embedder.Embed)
ds, err := newDocStore(ctx, t.TempDir(), "testLocalVec", embedAction, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -110,10 +110,11 @@ func TestPersistentIndexing(t *testing.T) {
embedder.Register(d1, v1)
embedder.Register(d2, v2)
embedder.Register(d3, v3)
embedAction := ai.DefineEmbedder("fake", "embedder2", embedder.Embed)

tDir := t.TempDir()

ds, err := newDocStore(ctx, tDir, "testLocalVec", embedder, nil)
ds, err := newDocStore(ctx, tDir, "testLocalVec", embedAction, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -144,7 +145,7 @@ func TestPersistentIndexing(t *testing.T) {
t.Errorf("got %d results, expected 2", len(docs))
}

dsAnother, err := newDocStore(ctx, tDir, "testLocalVec", embedder, nil)
dsAnother, err := newDocStore(ctx, tDir, "testLocalVec", embedAction, nil)
if err != nil {
t.Fatal(err)
}
Expand Down
8 changes: 4 additions & 4 deletions go/plugins/pinecone/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ const defaultTextKey = "_content"
//
// The textKey parameter is the metadata key to use to store document text
// in Pinecone; the default is "_content".
func New(ctx context.Context, apiKey, host string, embedder ai.Embedder, embedderOptions any, textKey string) (ai.DocumentStore, error) {
func New(ctx context.Context, apiKey, host string, embedder *ai.EmbedderAction, embedderOptions any, textKey string) (ai.DocumentStore, error) {
client, err := NewClient(ctx, apiKey)
if err != nil {
return nil, err
Expand Down Expand Up @@ -90,7 +90,7 @@ type RetrieverOptions struct {
// docStore implements the genkit [ai.DocumentStore] interface.
type docStore struct {
index *Index
embedder ai.Embedder
embedder *ai.EmbedderAction
embedderOptions any
textKey string
}
Expand Down Expand Up @@ -120,7 +120,7 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
Document: doc,
Options: ds.embedderOptions,
}
vals, err := ds.embedder.Embed(ctx, ereq)
vals, err := ai.Embed(ctx, ds.embedder, ereq)
if err != nil {
return fmt.Errorf("pinecone index embedding failed: %v", err)
}
Expand Down Expand Up @@ -215,7 +215,7 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
Document: req.Document,
Options: ds.embedderOptions,
}
vals, err := ds.embedder.Embed(ctx, ereq)
vals, err := ai.Embed(ctx, ds.embedder, ereq)
if err != nil {
return nil, fmt.Errorf("pinecone retrieve embedding failed: %v", err)
}
Expand Down
3 changes: 2 additions & 1 deletion go/plugins/pinecone/genkit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,9 @@ func TestGenkit(t *testing.T) {
embedder.Register(d1, v1)
embedder.Register(d2, v2)
embedder.Register(d3, v3)
embedAction := ai.DefineEmbedder("fake", "embedder3", embedder.Embed)

r, err := New(ctx, *testAPIKey, indexData.Host, embedder, nil, "")
r, err := New(ctx, *testAPIKey, indexData.Host, embedAction, nil, "")
if err != nil {
t.Fatal(err)
}
Expand Down
4 changes: 2 additions & 2 deletions go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ func Model(name string) *ai.ModelAction {
return ai.LookupModel(provider, name)
}

// Embedder returns the embedder with the given name.
// Embedder returns the [ai.EmbedderAction] with the given name.
// It returns nil if the embedder was not configured.
func Embedder(name string) ai.Embedder {
func Embedder(name string) *ai.EmbedderAction {
return ai.LookupEmbedder(provider, name)
}

Expand Down
2 changes: 1 addition & 1 deletion go/plugins/vertexai/vertexai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func TestLive(t *testing.T) {
}
})
t.Run("embedder", func(t *testing.T) {
out, err := vertexai.Embedder(embedderName).Embed(ctx, &ai.EmbedRequest{
out, err := ai.Embed(ctx, vertexai.Embedder(embedderName), &ai.EmbedRequest{
Document: ai.DocumentFromText("time flies like an arrow", nil),
})
if err != nil {
Expand Down
Loading