From a50d98bb0a385fc0ca2e29b4f9a651cf409fb30d Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 13 Jun 2024 04:59:06 -0400 Subject: [PATCH] [Go] Indexer and Retriever types Calling Retrieve on a retriever as a method is more natural in Go than calling a top-level function, as we currently have. This PR changes Indexer and Retriever from being aliases of Actions to being their own types, allowing methods. Subsequent PRs will do the same for Model, Embedder and possibly Tool. --- go/ai/retriever.go | 53 ++++++++++++++++------------ go/plugins/localvec/localvec.go | 6 ++-- go/plugins/localvec/localvec_test.go | 8 ----- go/plugins/pinecone/genkit.go | 8 ++--- go/plugins/pinecone/genkit_test.go | 4 +-- go/samples/menu/s04.go | 6 ++-- go/samples/rag/main.go | 4 +-- 7 files changed, 45 insertions(+), 44 deletions(-) diff --git a/go/ai/retriever.go b/go/ai/retriever.go index 50b405ffd..3e70b788e 100644 --- a/go/ai/retriever.go +++ b/go/ai/retriever.go @@ -22,10 +22,15 @@ import ( ) type ( - // An IndexerAction is used to index documents in a store. - IndexerAction = core.Action[*IndexerRequest, struct{}, struct{}] - // A RetrieverAction is used to retrieve indexed documents. - RetrieverAction = core.Action[*RetrieverRequest, *RetrieverResponse, struct{}] + // An Indexer is used to index documents in a store. + Indexer core.Action[*IndexerRequest, struct{}, struct{}] + // A Retriever is used to retrieve indexed documents. + Retriever core.Action[*RetrieverRequest, *RetrieverResponse, struct{}] +) + +type ( + indexerAction = core.Action[*IndexerRequest, struct{}, struct{}] + retrieverAction = core.Action[*RetrieverRequest, *RetrieverResponse, struct{}] ) // IndexerRequest is the data we pass to add documents to the database. @@ -48,39 +53,43 @@ type RetrieverResponse struct { } // DefineIndexer registers the given index function as an action, and returns an -// [IndexerAction] that runs it. -func DefineIndexer(provider, name string, index func(context.Context, *IndexerRequest) error) *IndexerAction { +// [Indexer] that runs it. +func DefineIndexer(provider, name string, index func(context.Context, *IndexerRequest) error) *Indexer { f := func(ctx context.Context, req *IndexerRequest) (struct{}, error) { return struct{}{}, index(ctx, req) } - return core.DefineAction(provider, name, atype.Indexer, nil, f) + return (*Indexer)(core.DefineAction(provider, name, atype.Indexer, nil, f)) } -// LookupIndexer looks up a [IndexerAction] registered by [DefineIndexer]. +// LookupIndexer looks up a [Indexer] registered by [DefineIndexer]. // It returns nil if the model was not defined. -func LookupIndexer(provider, name string) *IndexerAction { - return core.LookupActionFor[*IndexerRequest, struct{}, struct{}](atype.Indexer, provider, name) +func LookupIndexer(provider, name string) *Indexer { + return (*Indexer)(core.LookupActionFor[*IndexerRequest, struct{}, struct{}](atype.Indexer, provider, name)) } // DefineRetriever registers the given retrieve function as an action, and returns a -// [RetrieverAction] that runs it. -func DefineRetriever(provider, name string, ret func(context.Context, *RetrieverRequest) (*RetrieverResponse, error)) *RetrieverAction { - return core.DefineAction(provider, name, atype.Retriever, nil, ret) +// [Retriever] that runs it. +func DefineRetriever(provider, name string, ret func(context.Context, *RetrieverRequest) (*RetrieverResponse, error)) *Retriever { + return (*Retriever)(core.DefineAction(provider, name, atype.Retriever, nil, ret)) } -// LookupRetriever looks up a [RetrieverAction] registered by [DefineRetriever]. +// LookupRetriever looks up a [Retriever] registered by [DefineRetriever]. // It returns nil if the model was not defined. -func LookupRetriever(provider, name string) *RetrieverAction { - return core.LookupActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](atype.Retriever, provider, name) +func LookupRetriever(provider, name string) *Retriever { + return (*Retriever)(core.LookupActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](atype.Retriever, provider, name)) } -// Index runs the given [IndexerAction]. -func Index(ctx context.Context, indexer *IndexerAction, req *IndexerRequest) error { - _, err := indexer.Run(ctx, req, nil) +// Index runs the given [Indexer]. +func (i *Indexer) Index(ctx context.Context, req *IndexerRequest) error { + _, err := (*indexerAction)(i).Run(ctx, req, nil) return err } -// Retrieve runs the given [RetrieverAction]. -func Retrieve(ctx context.Context, retriever *RetrieverAction, req *RetrieverRequest) (*RetrieverResponse, error) { - return retriever.Run(ctx, req, nil) +// Retrieve runs the given [Retriever]. +func (r *Retriever) Retrieve(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { + return (*retrieverAction)(r).Run(ctx, req, nil) } + +func (i *Indexer) Name() string { return (*indexerAction)(i).Name() } + +func (r *Retriever) Name() string { return (*retrieverAction)(r).Name() } diff --git a/go/plugins/localvec/localvec.go b/go/plugins/localvec/localvec.go index 3258af9ae..1ff15b219 100644 --- a/go/plugins/localvec/localvec.go +++ b/go/plugins/localvec/localvec.go @@ -48,7 +48,7 @@ func Init() error { return nil } // DefineStore defines an indexer and retriever that share the same underlying storage. // The name uniquely identifies the the indexer and retriever in the registry. -func DefineStore(name string, cfg Config) (*ai.IndexerAction, *ai.RetrieverAction, error) { +func DefineStore(name string, cfg Config) (*ai.Indexer, *ai.Retriever, error) { ds, err := newDocStore(cfg.Dir, name, cfg.Embedder, cfg.EmbedderOptions) if err != nil { return nil, nil, err @@ -59,13 +59,13 @@ func DefineStore(name string, cfg Config) (*ai.IndexerAction, *ai.RetrieverActio } // Indexer returns the registered indexer with the given name. -func Indexer(name string) *ai.IndexerAction { +func Indexer(name string) *ai.Indexer { return ai.LookupIndexer(provider, name) } // Retriever returns the retriever with the given name. // The name must match the [Config.Name] value passed to [Init]. -func Retriever(name string) *ai.RetrieverAction { +func Retriever(name string) *ai.Retriever { return ai.LookupRetriever(provider, name) } diff --git a/go/plugins/localvec/localvec_test.go b/go/plugins/localvec/localvec_test.go index 206bc73d0..276fdfd6a 100644 --- a/go/plugins/localvec/localvec_test.go +++ b/go/plugins/localvec/localvec_test.go @@ -205,11 +205,3 @@ func TestInit(t *testing.T) { t.Errorf("got %q, want %q", g, want) } } - -func names[T interface{ Name() string }](xs []T) []string { - var ns []string - for _, x := range xs { - ns = append(ns, x.Name()) - } - return ns -} diff --git a/go/plugins/pinecone/genkit.go b/go/plugins/pinecone/genkit.go index a6f23919f..fb054698a 100644 --- a/go/plugins/pinecone/genkit.go +++ b/go/plugins/pinecone/genkit.go @@ -82,7 +82,7 @@ type Config struct { TextKey string } -func DefineIndexer(ctx context.Context, cfg Config) (*ai.IndexerAction, error) { +func DefineIndexer(ctx context.Context, cfg Config) (*ai.Indexer, error) { ds, err := newDocStore(ctx, cfg) if err != nil { return nil, err @@ -90,7 +90,7 @@ func DefineIndexer(ctx context.Context, cfg Config) (*ai.IndexerAction, error) { return ai.DefineIndexer(provider, cfg.IndexID, ds.Index), nil } -func DefineRetriever(ctx context.Context, cfg Config) (*ai.RetrieverAction, error) { +func DefineRetriever(ctx context.Context, cfg Config) (*ai.Retriever, error) { ds, err := newDocStore(ctx, cfg) if err != nil { return nil, err @@ -131,12 +131,12 @@ func newDocStore(ctx context.Context, cfg Config) (*docStore, error) { } // Indexer returns the indexer with the given index name. -func Indexer(name string) *ai.IndexerAction { +func Indexer(name string) *ai.Indexer { return ai.LookupIndexer(provider, name) } // Retriever returns the retriever with the given index name. -func Retriever(name string) *ai.RetrieverAction { +func Retriever(name string) *ai.Retriever { return ai.LookupRetriever(provider, name) } diff --git a/go/plugins/pinecone/genkit_test.go b/go/plugins/pinecone/genkit_test.go index 7597d3c9a..b14c95703 100644 --- a/go/plugins/pinecone/genkit_test.go +++ b/go/plugins/pinecone/genkit_test.go @@ -97,7 +97,7 @@ func TestGenkit(t *testing.T) { Options: indexerOptions, } t.Logf("index flag = %q, indexData.Host = %q", *testIndex, indexData.Host) - err = ai.Index(ctx, indexer, indexerReq) + err = indexer.Index(ctx, indexerReq) if err != nil { t.Fatalf("Index operation failed: %v", err) } @@ -134,7 +134,7 @@ func TestGenkit(t *testing.T) { Document: d1, Options: retrieverOptions, } - retrieverResp, err := ai.Retrieve(ctx, retriever, retrieverReq) + retrieverResp, err := retriever.Retrieve(ctx, retrieverReq) if err != nil { t.Fatalf("Retrieve operation failed: %v", err) } diff --git a/go/samples/menu/s04.go b/go/samples/menu/s04.go index e2e404ee7..6b3beff8a 100644 --- a/go/samples/menu/s04.go +++ b/go/samples/menu/s04.go @@ -24,7 +24,7 @@ import ( "github.com/firebase/genkit/go/plugins/localvec" ) -func setup04(ctx context.Context, indexer *ai.IndexerAction, retriever *ai.RetrieverAction, model *ai.ModelAction) error { +func setup04(ctx context.Context, indexer *ai.Indexer, retriever *ai.Retriever, model *ai.ModelAction) error { ragDataMenuPrompt, err := dotprompt.Define("s04_ragDataMenu", ` You are acting as Walt, a helpful AI assistant here at the restaurant. @@ -70,7 +70,7 @@ func setup04(ctx context.Context, indexer *ai.IndexerAction, retriever *ai.Retri req := &ai.IndexerRequest{ Documents: docs, } - if err := ai.Index(ctx, indexer, req); err != nil { + if err := indexer.Index(ctx, req); err != nil { return nil, err } @@ -89,7 +89,7 @@ func setup04(ctx context.Context, indexer *ai.IndexerAction, retriever *ai.Retri K: 3, }, } - resp, err := ai.Retrieve(ctx, retriever, req) + resp, err := retriever.Retrieve(ctx, req) if err != nil { return nil, err } diff --git a/go/samples/rag/main.go b/go/samples/rag/main.go index f9ae4f5b5..0e7e40392 100644 --- a/go/samples/rag/main.go +++ b/go/samples/rag/main.go @@ -109,7 +109,7 @@ func main() { indexerReq := &ai.IndexerRequest{ Documents: []*ai.Document{d1, d2, d3}, } - err := ai.Index(ctx, indexer, indexerReq) + err := indexer.Index(ctx, indexerReq) if err != nil { return "", err } @@ -118,7 +118,7 @@ func main() { retrieverReq := &ai.RetrieverRequest{ Document: dRequest, } - response, err := ai.Retrieve(ctx, retriever, retrieverReq) + response, err := retriever.Retrieve(ctx, retrieverReq) if err != nil { return "", err }