Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Go] ai.Embedder is a separate type from its action #461

Merged
merged 1 commit into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ import (
"github.com/firebase/genkit/go/internal/atype"
)

// EmbedderAction is used to convert a document to a
// An Embedder is used to convert a document to a
// multidimensional vector.
type EmbedderAction = core.Action[*EmbedRequest, []float32, struct{}]
type Embedder core.Action[*EmbedRequest, []float32, struct{}]

// EmbedRequest is the data we pass to convert a document
// to a multidimensional vector.
Expand All @@ -34,21 +34,22 @@ type EmbedRequest struct {

// DefineEmbedder registers the given embed function as an action, and returns an
// [EmbedderAction] that runs it.
func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) ([]float32, error)) *EmbedderAction {
return core.DefineAction(provider, name, atype.Embedder, nil, embed)
func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) ([]float32, error)) *Embedder {
return (*Embedder)(core.DefineAction(provider, name, atype.Embedder, nil, embed))
}

// LookupEmbedder looks up an [EmbedderAction] registered by [DefineEmbedder].
// It returns nil if the embedder was not defined.
func LookupEmbedder(provider, name string) *EmbedderAction {
func LookupEmbedder(provider, name string) *Embedder {
action := core.LookupActionFor[*EmbedRequest, []float32, struct{}](atype.Embedder, provider, name)
if action == nil {
return nil
}
return action
return (*Embedder)(action)
}

// Embed runs the given [EmbedderAction].
func Embed(ctx context.Context, emb *EmbedderAction, req *EmbedRequest) ([]float32, error) {
return emb.Run(ctx, req, nil)
// Embed runs the given [Embedder].
func (e *Embedder) Embed(ctx context.Context, req *EmbedRequest) ([]float32, error) {
a := (*core.Action[*EmbedRequest, []float32, struct{}])(e)
return a.Run(ctx, req, nil)
}
2 changes: 1 addition & 1 deletion go/internal/fakeembedder/fakeembedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
"github.com/firebase/genkit/go/ai"
)

// Embedder is a fake implementation of genkit.Embedder.
// Embedder is a fake implementation of an Embedder.
type Embedder struct {
registry map[*ai.Document][]float32
}
Expand Down
6 changes: 3 additions & 3 deletions go/internal/fakeembedder/fakeembedder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (

func TestFakeEmbedder(t *testing.T) {
embed := New()
embedAction := ai.DefineEmbedder("fake", "embed", embed.Embed)
emb := ai.DefineEmbedder("fake", "embed", embed.Embed)
d := ai.DocumentFromText("fakeembedder test", nil)

vals := []float32{1, 2}
Expand All @@ -34,7 +34,7 @@ func TestFakeEmbedder(t *testing.T) {
Document: d,
}
ctx := context.Background()
got, err := ai.Embed(ctx, embedAction, req)
got, err := emb.Embed(ctx, req)
if err != nil {
t.Fatal(err)
}
Expand All @@ -43,7 +43,7 @@ func TestFakeEmbedder(t *testing.T) {
}

req.Document = ai.DocumentFromText("missing document", nil)
if _, err = ai.Embed(ctx, embedAction, req); err == nil {
if _, err = emb.Embed(ctx, req); err == nil {
t.Error("embedding unknown document succeeded unexpectedly")
}
}
8 changes: 4 additions & 4 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func defineModel(name string, caps ai.ModelCapabilities) *ai.Model {
}

// DefineEmbedder defines an embedder with a given name.
func DefineEmbedder(name string) *ai.EmbedderAction {
func DefineEmbedder(name string) *ai.Embedder {
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
Expand All @@ -130,7 +130,7 @@ func DefineEmbedder(name string) *ai.EmbedderAction {
}

// requires state.mu
func defineEmbedder(name string) *ai.EmbedderAction {
func defineEmbedder(name string) *ai.Embedder {
return ai.DefineEmbedder(provider, name, func(ctx context.Context, input *ai.EmbedRequest) ([]float32, error) {
em := state.client.EmbeddingModel(name)
parts, err := convertParts(input.Document.Content)
Expand All @@ -151,9 +151,9 @@ func Model(name string) *ai.Model {
return ai.LookupModel(provider, name)
}

// Embedder returns the [ai.EmbedderAction] with the given name.
// Embedder returns the [ai.Embedder] with the given name.
// It returns nil if the embedder was not configured.
func Embedder(name string) *ai.EmbedderAction {
func Embedder(name string) *ai.Embedder {
return ai.LookupEmbedder(provider, name)
}

Expand Down
2 changes: 1 addition & 1 deletion go/plugins/googleai/googleai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func TestLive(t *testing.T) {
},
)
t.Run("embedder", func(t *testing.T) {
out, err := ai.Embed(ctx, embedder, &ai.EmbedRequest{
out, err := embedder.Embed(ctx, &ai.EmbedRequest{
Document: ai.DocumentFromText("yellow banana", nil),
})
if err != nil {
Expand Down
10 changes: 5 additions & 5 deletions go/plugins/localvec/localvec.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ const provider = "devLocalVectorStore"
type Config struct {
// Where to store the data. Defaults to os.TempDir.
Dir string
Embedder *ai.EmbedderAction
Embedder *ai.Embedder
EmbedderOptions any
}

Expand Down Expand Up @@ -73,7 +73,7 @@ func Retriever(name string) *ai.Retriever {
// This is based on js/plugins/dev-local-vectorstore/src/index.ts.
type docStore struct {
filename string
embedder *ai.EmbedderAction
embedder *ai.Embedder
embedderOptions any
data map[string]dbValue
}
Expand All @@ -85,7 +85,7 @@ type dbValue struct {
}

// newDocStore returns a new ai.DocumentStore to register.
func newDocStore(dir, name string, embedder *ai.EmbedderAction, embedderOptions any) (*docStore, error) {
func newDocStore(dir, name string, embedder *ai.Embedder, embedderOptions any) (*docStore, error) {
if dir == "" {
dir = os.TempDir()
}
Expand Down Expand Up @@ -124,7 +124,7 @@ func (ds *docStore) index(ctx context.Context, req *ai.IndexerRequest) error {
Document: doc,
Options: ds.embedderOptions,
}
vals, err := ai.Embed(ctx, ds.embedder, ereq)
vals, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return fmt.Errorf("localvec index embedding failed: %v", err)
}
Expand Down Expand Up @@ -186,7 +186,7 @@ func (ds *docStore) retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
Document: req.Document,
Options: ds.embedderOptions,
}
vals, err := ai.Embed(ctx, ds.embedder, ereq)
vals, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return nil, fmt.Errorf("localvec retrieve embedding failed: %v", err)
}
Expand Down
8 changes: 4 additions & 4 deletions go/plugins/pinecone/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ type Config struct {
// The index ID to use.
IndexID string
// Embedder to use. Required.
Embedder *ai.EmbedderAction
Embedder *ai.Embedder
EmbedderOptions any
// The metadata key to use to store document text
// in Pinecone; the default is "_content".
Expand Down Expand Up @@ -160,7 +160,7 @@ type RetrieverOptions struct {
// docStore implements the genkit [ai.DocumentStore] interface.
type docStore struct {
index *index
embedder *ai.EmbedderAction
embedder *ai.Embedder
embedderOptions any
textKey string
}
Expand Down Expand Up @@ -190,7 +190,7 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
Document: doc,
Options: ds.embedderOptions,
}
vals, err := ai.Embed(ctx, ds.embedder, ereq)
vals, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return fmt.Errorf("pinecone index embedding failed: %v", err)
}
Expand Down Expand Up @@ -285,7 +285,7 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
Document: req.Document,
Options: ds.embedderOptions,
}
vals, err := ai.Embed(ctx, ds.embedder, ereq)
vals, err := ds.embedder.Embed(ctx, ereq)
if err != nil {
return nil, fmt.Errorf("pinecone retrieve embedding failed: %v", err)
}
Expand Down
6 changes: 3 additions & 3 deletions go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func DefineModel(name string) *ai.Model {
}

// DefineModel defines an embedder with the given name.
func DefineEmbedder(name string) *ai.EmbedderAction {
func DefineEmbedder(name string) *ai.Embedder {
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
Expand All @@ -107,9 +107,9 @@ func Model(name string) *ai.Model {
return ai.LookupModel(provider, name)
}

// Embedder returns the [ai.EmbedderAction] with the given name.
// Embedder returns the [ai.Embedder] with the given name.
// It returns nil if the embedder was not configured.
func Embedder(name string) *ai.EmbedderAction {
func Embedder(name string) *ai.Embedder {
return ai.LookupEmbedder(provider, name)
}

Expand Down
2 changes: 1 addition & 1 deletion go/plugins/vertexai/vertexai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func TestLive(t *testing.T) {
}
})
t.Run("embedder", func(t *testing.T) {
out, err := ai.Embed(ctx, embedder, &ai.EmbedRequest{
out, err := embedder.Embed(ctx, &ai.EmbedRequest{
Document: ai.DocumentFromText("time flies like an arrow", nil),
})
if err != nil {
Expand Down
Loading