diff --git a/go/ai/retriever.go b/go/ai/retriever.go index 50b405ffd..3e70b788e 100644 --- a/go/ai/retriever.go +++ b/go/ai/retriever.go @@ -22,10 +22,15 @@ import ( ) type ( - // An IndexerAction is used to index documents in a store. - IndexerAction = core.Action[*IndexerRequest, struct{}, struct{}] - // A RetrieverAction is used to retrieve indexed documents. - RetrieverAction = core.Action[*RetrieverRequest, *RetrieverResponse, struct{}] + // An Indexer is used to index documents in a store. + Indexer core.Action[*IndexerRequest, struct{}, struct{}] + // A Retriever is used to retrieve indexed documents. + Retriever core.Action[*RetrieverRequest, *RetrieverResponse, struct{}] +) + +type ( + indexerAction = core.Action[*IndexerRequest, struct{}, struct{}] + retrieverAction = core.Action[*RetrieverRequest, *RetrieverResponse, struct{}] ) // IndexerRequest is the data we pass to add documents to the database. @@ -48,39 +53,43 @@ type RetrieverResponse struct { } // DefineIndexer registers the given index function as an action, and returns an -// [IndexerAction] that runs it. -func DefineIndexer(provider, name string, index func(context.Context, *IndexerRequest) error) *IndexerAction { +// [Indexer] that runs it. +func DefineIndexer(provider, name string, index func(context.Context, *IndexerRequest) error) *Indexer { f := func(ctx context.Context, req *IndexerRequest) (struct{}, error) { return struct{}{}, index(ctx, req) } - return core.DefineAction(provider, name, atype.Indexer, nil, f) + return (*Indexer)(core.DefineAction(provider, name, atype.Indexer, nil, f)) } -// LookupIndexer looks up a [IndexerAction] registered by [DefineIndexer]. +// LookupIndexer looks up a [Indexer] registered by [DefineIndexer]. // It returns nil if the model was not defined. -func LookupIndexer(provider, name string) *IndexerAction { - return core.LookupActionFor[*IndexerRequest, struct{}, struct{}](atype.Indexer, provider, name) +func LookupIndexer(provider, name string) *Indexer { + return (*Indexer)(core.LookupActionFor[*IndexerRequest, struct{}, struct{}](atype.Indexer, provider, name)) } // DefineRetriever registers the given retrieve function as an action, and returns a -// [RetrieverAction] that runs it. -func DefineRetriever(provider, name string, ret func(context.Context, *RetrieverRequest) (*RetrieverResponse, error)) *RetrieverAction { - return core.DefineAction(provider, name, atype.Retriever, nil, ret) +// [Retriever] that runs it. +func DefineRetriever(provider, name string, ret func(context.Context, *RetrieverRequest) (*RetrieverResponse, error)) *Retriever { + return (*Retriever)(core.DefineAction(provider, name, atype.Retriever, nil, ret)) } -// LookupRetriever looks up a [RetrieverAction] registered by [DefineRetriever]. +// LookupRetriever looks up a [Retriever] registered by [DefineRetriever]. // It returns nil if the model was not defined. -func LookupRetriever(provider, name string) *RetrieverAction { - return core.LookupActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](atype.Retriever, provider, name) +func LookupRetriever(provider, name string) *Retriever { + return (*Retriever)(core.LookupActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](atype.Retriever, provider, name)) } -// Index runs the given [IndexerAction]. -func Index(ctx context.Context, indexer *IndexerAction, req *IndexerRequest) error { - _, err := indexer.Run(ctx, req, nil) +// Index runs the given [Indexer]. +func (i *Indexer) Index(ctx context.Context, req *IndexerRequest) error { + _, err := (*indexerAction)(i).Run(ctx, req, nil) return err } -// Retrieve runs the given [RetrieverAction]. -func Retrieve(ctx context.Context, retriever *RetrieverAction, req *RetrieverRequest) (*RetrieverResponse, error) { - return retriever.Run(ctx, req, nil) +// Retrieve runs the given [Retriever]. +func (r *Retriever) Retrieve(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { + return (*retrieverAction)(r).Run(ctx, req, nil) } + +func (i *Indexer) Name() string { return (*indexerAction)(i).Name() } + +func (r *Retriever) Name() string { return (*retrieverAction)(r).Name() } diff --git a/go/plugins/localvec/localvec.go b/go/plugins/localvec/localvec.go index 3258af9ae..1ff15b219 100644 --- a/go/plugins/localvec/localvec.go +++ b/go/plugins/localvec/localvec.go @@ -48,7 +48,7 @@ func Init() error { return nil } // DefineStore defines an indexer and retriever that share the same underlying storage. // The name uniquely identifies the the indexer and retriever in the registry. -func DefineStore(name string, cfg Config) (*ai.IndexerAction, *ai.RetrieverAction, error) { +func DefineStore(name string, cfg Config) (*ai.Indexer, *ai.Retriever, error) { ds, err := newDocStore(cfg.Dir, name, cfg.Embedder, cfg.EmbedderOptions) if err != nil { return nil, nil, err @@ -59,13 +59,13 @@ func DefineStore(name string, cfg Config) (*ai.IndexerAction, *ai.RetrieverActio } // Indexer returns the registered indexer with the given name. -func Indexer(name string) *ai.IndexerAction { +func Indexer(name string) *ai.Indexer { return ai.LookupIndexer(provider, name) } // Retriever returns the retriever with the given name. // The name must match the [Config.Name] value passed to [Init]. -func Retriever(name string) *ai.RetrieverAction { +func Retriever(name string) *ai.Retriever { return ai.LookupRetriever(provider, name) } diff --git a/go/plugins/localvec/localvec_test.go b/go/plugins/localvec/localvec_test.go index 206bc73d0..276fdfd6a 100644 --- a/go/plugins/localvec/localvec_test.go +++ b/go/plugins/localvec/localvec_test.go @@ -205,11 +205,3 @@ func TestInit(t *testing.T) { t.Errorf("got %q, want %q", g, want) } } - -func names[T interface{ Name() string }](xs []T) []string { - var ns []string - for _, x := range xs { - ns = append(ns, x.Name()) - } - return ns -} diff --git a/go/plugins/pinecone/genkit.go b/go/plugins/pinecone/genkit.go index a6f23919f..fb054698a 100644 --- a/go/plugins/pinecone/genkit.go +++ b/go/plugins/pinecone/genkit.go @@ -82,7 +82,7 @@ type Config struct { TextKey string } -func DefineIndexer(ctx context.Context, cfg Config) (*ai.IndexerAction, error) { +func DefineIndexer(ctx context.Context, cfg Config) (*ai.Indexer, error) { ds, err := newDocStore(ctx, cfg) if err != nil { return nil, err @@ -90,7 +90,7 @@ func DefineIndexer(ctx context.Context, cfg Config) (*ai.IndexerAction, error) { return ai.DefineIndexer(provider, cfg.IndexID, ds.Index), nil } -func DefineRetriever(ctx context.Context, cfg Config) (*ai.RetrieverAction, error) { +func DefineRetriever(ctx context.Context, cfg Config) (*ai.Retriever, error) { ds, err := newDocStore(ctx, cfg) if err != nil { return nil, err @@ -131,12 +131,12 @@ func newDocStore(ctx context.Context, cfg Config) (*docStore, error) { } // Indexer returns the indexer with the given index name. -func Indexer(name string) *ai.IndexerAction { +func Indexer(name string) *ai.Indexer { return ai.LookupIndexer(provider, name) } // Retriever returns the retriever with the given index name. -func Retriever(name string) *ai.RetrieverAction { +func Retriever(name string) *ai.Retriever { return ai.LookupRetriever(provider, name) } diff --git a/go/plugins/pinecone/genkit_test.go b/go/plugins/pinecone/genkit_test.go index 7597d3c9a..b14c95703 100644 --- a/go/plugins/pinecone/genkit_test.go +++ b/go/plugins/pinecone/genkit_test.go @@ -97,7 +97,7 @@ func TestGenkit(t *testing.T) { Options: indexerOptions, } t.Logf("index flag = %q, indexData.Host = %q", *testIndex, indexData.Host) - err = ai.Index(ctx, indexer, indexerReq) + err = indexer.Index(ctx, indexerReq) if err != nil { t.Fatalf("Index operation failed: %v", err) } @@ -134,7 +134,7 @@ func TestGenkit(t *testing.T) { Document: d1, Options: retrieverOptions, } - retrieverResp, err := ai.Retrieve(ctx, retriever, retrieverReq) + retrieverResp, err := retriever.Retrieve(ctx, retrieverReq) if err != nil { t.Fatalf("Retrieve operation failed: %v", err) } diff --git a/go/samples/menu/s04.go b/go/samples/menu/s04.go index e2e404ee7..6b3beff8a 100644 --- a/go/samples/menu/s04.go +++ b/go/samples/menu/s04.go @@ -24,7 +24,7 @@ import ( "github.com/firebase/genkit/go/plugins/localvec" ) -func setup04(ctx context.Context, indexer *ai.IndexerAction, retriever *ai.RetrieverAction, model *ai.ModelAction) error { +func setup04(ctx context.Context, indexer *ai.Indexer, retriever *ai.Retriever, model *ai.ModelAction) error { ragDataMenuPrompt, err := dotprompt.Define("s04_ragDataMenu", ` You are acting as Walt, a helpful AI assistant here at the restaurant. @@ -70,7 +70,7 @@ func setup04(ctx context.Context, indexer *ai.IndexerAction, retriever *ai.Retri req := &ai.IndexerRequest{ Documents: docs, } - if err := ai.Index(ctx, indexer, req); err != nil { + if err := indexer.Index(ctx, req); err != nil { return nil, err } @@ -89,7 +89,7 @@ func setup04(ctx context.Context, indexer *ai.IndexerAction, retriever *ai.Retri K: 3, }, } - resp, err := ai.Retrieve(ctx, retriever, req) + resp, err := retriever.Retrieve(ctx, req) if err != nil { return nil, err } diff --git a/go/samples/rag/main.go b/go/samples/rag/main.go index f9ae4f5b5..0e7e40392 100644 --- a/go/samples/rag/main.go +++ b/go/samples/rag/main.go @@ -109,7 +109,7 @@ func main() { indexerReq := &ai.IndexerRequest{ Documents: []*ai.Document{d1, d2, d3}, } - err := ai.Index(ctx, indexer, indexerReq) + err := indexer.Index(ctx, indexerReq) if err != nil { return "", err } @@ -118,7 +118,7 @@ func main() { retrieverReq := &ai.RetrieverRequest{ Document: dRequest, } - response, err := ai.Retrieve(ctx, retriever, retrieverReq) + response, err := retriever.Retrieve(ctx, retrieverReq) if err != nil { return "", err }