Skip to content

Commit

Permalink
[Go] Indexer and Retriever types (#402)
Browse files Browse the repository at this point in the history
Calling Retrieve on a retriever as a method is more natural
in Go than calling a top-level function, as we currently have.

This PR changes Indexer and Retriever from being aliases of Actions
to being their own types, allowing methods.

Subsequent PRs will do the same for Model, Embedder and possibly Tool.
  • Loading branch information
jba authored Jun 23, 2024
1 parent 9da5724 commit 07ea151
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 44 deletions.
53 changes: 31 additions & 22 deletions go/ai/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,15 @@ import (
)

type (
// An IndexerAction is used to index documents in a store.
IndexerAction = core.Action[*IndexerRequest, struct{}, struct{}]
// A RetrieverAction is used to retrieve indexed documents.
RetrieverAction = core.Action[*RetrieverRequest, *RetrieverResponse, struct{}]
// An Indexer is used to index documents in a store.
Indexer core.Action[*IndexerRequest, struct{}, struct{}]
// A Retriever is used to retrieve indexed documents.
Retriever core.Action[*RetrieverRequest, *RetrieverResponse, struct{}]
)

type (
indexerAction = core.Action[*IndexerRequest, struct{}, struct{}]
retrieverAction = core.Action[*RetrieverRequest, *RetrieverResponse, struct{}]
)

// IndexerRequest is the data we pass to add documents to the database.
Expand All @@ -48,39 +53,43 @@ type RetrieverResponse struct {
}

// DefineIndexer registers the given index function as an action, and returns an
// [IndexerAction] that runs it.
func DefineIndexer(provider, name string, index func(context.Context, *IndexerRequest) error) *IndexerAction {
// [Indexer] that runs it.
func DefineIndexer(provider, name string, index func(context.Context, *IndexerRequest) error) *Indexer {
f := func(ctx context.Context, req *IndexerRequest) (struct{}, error) {
return struct{}{}, index(ctx, req)
}
return core.DefineAction(provider, name, atype.Indexer, nil, f)
return (*Indexer)(core.DefineAction(provider, name, atype.Indexer, nil, f))
}

// LookupIndexer looks up a [IndexerAction] registered by [DefineIndexer].
// LookupIndexer looks up a [Indexer] registered by [DefineIndexer].
// It returns nil if the model was not defined.
func LookupIndexer(provider, name string) *IndexerAction {
return core.LookupActionFor[*IndexerRequest, struct{}, struct{}](atype.Indexer, provider, name)
func LookupIndexer(provider, name string) *Indexer {
return (*Indexer)(core.LookupActionFor[*IndexerRequest, struct{}, struct{}](atype.Indexer, provider, name))
}

// DefineRetriever registers the given retrieve function as an action, and returns a
// [RetrieverAction] that runs it.
func DefineRetriever(provider, name string, ret func(context.Context, *RetrieverRequest) (*RetrieverResponse, error)) *RetrieverAction {
return core.DefineAction(provider, name, atype.Retriever, nil, ret)
// [Retriever] that runs it.
func DefineRetriever(provider, name string, ret func(context.Context, *RetrieverRequest) (*RetrieverResponse, error)) *Retriever {
return (*Retriever)(core.DefineAction(provider, name, atype.Retriever, nil, ret))
}

// LookupRetriever looks up a [RetrieverAction] registered by [DefineRetriever].
// LookupRetriever looks up a [Retriever] registered by [DefineRetriever].
// It returns nil if the model was not defined.
func LookupRetriever(provider, name string) *RetrieverAction {
return core.LookupActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](atype.Retriever, provider, name)
func LookupRetriever(provider, name string) *Retriever {
return (*Retriever)(core.LookupActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](atype.Retriever, provider, name))
}

// Index runs the given [IndexerAction].
func Index(ctx context.Context, indexer *IndexerAction, req *IndexerRequest) error {
_, err := indexer.Run(ctx, req, nil)
// Index runs the given [Indexer].
func (i *Indexer) Index(ctx context.Context, req *IndexerRequest) error {
_, err := (*indexerAction)(i).Run(ctx, req, nil)
return err
}

// Retrieve runs the given [RetrieverAction].
func Retrieve(ctx context.Context, retriever *RetrieverAction, req *RetrieverRequest) (*RetrieverResponse, error) {
return retriever.Run(ctx, req, nil)
// Retrieve runs the given [Retriever].
func (r *Retriever) Retrieve(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) {
return (*retrieverAction)(r).Run(ctx, req, nil)
}

func (i *Indexer) Name() string { return (*indexerAction)(i).Name() }

func (r *Retriever) Name() string { return (*retrieverAction)(r).Name() }
6 changes: 3 additions & 3 deletions go/plugins/localvec/localvec.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func Init() error { return nil }

// DefineStore defines an indexer and retriever that share the same underlying storage.
// The name uniquely identifies the the indexer and retriever in the registry.
func DefineStore(name string, cfg Config) (*ai.IndexerAction, *ai.RetrieverAction, error) {
func DefineStore(name string, cfg Config) (*ai.Indexer, *ai.Retriever, error) {
ds, err := newDocStore(cfg.Dir, name, cfg.Embedder, cfg.EmbedderOptions)
if err != nil {
return nil, nil, err
Expand All @@ -59,13 +59,13 @@ func DefineStore(name string, cfg Config) (*ai.IndexerAction, *ai.RetrieverActio
}

// Indexer returns the registered indexer with the given name.
func Indexer(name string) *ai.IndexerAction {
func Indexer(name string) *ai.Indexer {
return ai.LookupIndexer(provider, name)
}

// Retriever returns the retriever with the given name.
// The name must match the [Config.Name] value passed to [Init].
func Retriever(name string) *ai.RetrieverAction {
func Retriever(name string) *ai.Retriever {
return ai.LookupRetriever(provider, name)
}

Expand Down
8 changes: 0 additions & 8 deletions go/plugins/localvec/localvec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,3 @@ func TestInit(t *testing.T) {
t.Errorf("got %q, want %q", g, want)
}
}

func names[T interface{ Name() string }](xs []T) []string {
var ns []string
for _, x := range xs {
ns = append(ns, x.Name())
}
return ns
}
8 changes: 4 additions & 4 deletions go/plugins/pinecone/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ type Config struct {
TextKey string
}

func DefineIndexer(ctx context.Context, cfg Config) (*ai.IndexerAction, error) {
func DefineIndexer(ctx context.Context, cfg Config) (*ai.Indexer, error) {
ds, err := newDocStore(ctx, cfg)
if err != nil {
return nil, err
}
return ai.DefineIndexer(provider, cfg.IndexID, ds.Index), nil
}

func DefineRetriever(ctx context.Context, cfg Config) (*ai.RetrieverAction, error) {
func DefineRetriever(ctx context.Context, cfg Config) (*ai.Retriever, error) {
ds, err := newDocStore(ctx, cfg)
if err != nil {
return nil, err
Expand Down Expand Up @@ -131,12 +131,12 @@ func newDocStore(ctx context.Context, cfg Config) (*docStore, error) {
}

// Indexer returns the indexer with the given index name.
func Indexer(name string) *ai.IndexerAction {
func Indexer(name string) *ai.Indexer {
return ai.LookupIndexer(provider, name)
}

// Retriever returns the retriever with the given index name.
func Retriever(name string) *ai.RetrieverAction {
func Retriever(name string) *ai.Retriever {
return ai.LookupRetriever(provider, name)
}

Expand Down
4 changes: 2 additions & 2 deletions go/plugins/pinecone/genkit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestGenkit(t *testing.T) {
Options: indexerOptions,
}
t.Logf("index flag = %q, indexData.Host = %q", *testIndex, indexData.Host)
err = ai.Index(ctx, indexer, indexerReq)
err = indexer.Index(ctx, indexerReq)
if err != nil {
t.Fatalf("Index operation failed: %v", err)
}
Expand Down Expand Up @@ -134,7 +134,7 @@ func TestGenkit(t *testing.T) {
Document: d1,
Options: retrieverOptions,
}
retrieverResp, err := ai.Retrieve(ctx, retriever, retrieverReq)
retrieverResp, err := retriever.Retrieve(ctx, retrieverReq)
if err != nil {
t.Fatalf("Retrieve operation failed: %v", err)
}
Expand Down
6 changes: 3 additions & 3 deletions go/samples/menu/s04.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
"github.com/firebase/genkit/go/plugins/localvec"
)

func setup04(ctx context.Context, indexer *ai.IndexerAction, retriever *ai.RetrieverAction, model *ai.ModelAction) error {
func setup04(ctx context.Context, indexer *ai.Indexer, retriever *ai.Retriever, model *ai.ModelAction) error {
ragDataMenuPrompt, err := dotprompt.Define("s04_ragDataMenu",
`
You are acting as Walt, a helpful AI assistant here at the restaurant.
Expand Down Expand Up @@ -70,7 +70,7 @@ func setup04(ctx context.Context, indexer *ai.IndexerAction, retriever *ai.Retri
req := &ai.IndexerRequest{
Documents: docs,
}
if err := ai.Index(ctx, indexer, req); err != nil {
if err := indexer.Index(ctx, req); err != nil {
return nil, err
}

Expand All @@ -89,7 +89,7 @@ func setup04(ctx context.Context, indexer *ai.IndexerAction, retriever *ai.Retri
K: 3,
},
}
resp, err := ai.Retrieve(ctx, retriever, req)
resp, err := retriever.Retrieve(ctx, req)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions go/samples/rag/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func main() {
indexerReq := &ai.IndexerRequest{
Documents: []*ai.Document{d1, d2, d3},
}
err := ai.Index(ctx, indexer, indexerReq)
err := indexer.Index(ctx, indexerReq)
if err != nil {
return "", err
}
Expand All @@ -118,7 +118,7 @@ func main() {
retrieverReq := &ai.RetrieverRequest{
Document: dRequest,
}
response, err := ai.Retrieve(ctx, retriever, retrieverReq)
response, err := retriever.Retrieve(ctx, retrieverReq)
if err != nil {
return "", err
}
Expand Down

0 comments on commit 07ea151

Please sign in to comment.