Skip to content

Commit

Permalink
[Go] return an error if a nil Model, Embedder, etc is used. (#613)
Browse files Browse the repository at this point in the history
  • Loading branch information
jba authored Jul 15, 2024
1 parent 65e7814 commit 3816b43
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 0 deletions.
4 changes: 4 additions & 0 deletions go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package ai

import (
"context"
"errors"

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/atype"
Expand Down Expand Up @@ -61,6 +62,9 @@ func LookupEmbedder(provider, name string) *Embedder {

// Embed runs the given [Embedder].
func (e *Embedder) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
if e == nil {
return nil, errors.New("Embed called on a nil Embedder; check that all embedders are defined")
}
a := (*core.Action[*EmbedRequest, *EmbedResponse, struct{}])(e)
return a.Run(ctx, req, nil)
}
3 changes: 3 additions & 0 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ func LookupModel(provider, name string) *Model {

// Generate applies the [Model] to some input, handling tool requests.
func (m *Model) Generate(ctx context.Context, req *GenerateRequest, cb ModelStreamingCallback) (*GenerateResponse, error) {
if m == nil {
return nil, errors.New("Generate called on a nil Model; check that all models are defined")
}
if err := conformOutput(req); err != nil {
return nil, err
}
Expand Down
4 changes: 4 additions & 0 deletions go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package ai

import (
"context"
"errors"
"maps"

"github.com/firebase/genkit/go/core"
Expand Down Expand Up @@ -49,5 +50,8 @@ func LookupPrompt(provider, name string) *Prompt {

// Render renders the [Prompt] with some input data.
func (p *Prompt) Render(ctx context.Context, input any) (*GenerateRequest, error) {
if p == nil {
return nil, errors.New("Render called on a nil Prompt; check that all prompts are defined")
}
return (*core.Action[any, *GenerateRequest, struct{}])(p).Run(ctx, input, nil)
}
7 changes: 7 additions & 0 deletions go/ai/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package ai

import (
"context"
"errors"

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/atype"
Expand Down Expand Up @@ -81,12 +82,18 @@ func LookupRetriever(provider, name string) *Retriever {

// Index runs the given [Indexer].
func (i *Indexer) Index(ctx context.Context, req *IndexerRequest) error {
if i == nil {
return errors.New("Index called on a nil Indexer; check that all indexers are defined")
}
_, err := (*indexerAction)(i).Run(ctx, req, nil)
return err
}

// Retrieve runs the given [Retriever].
func (r *Retriever) Retrieve(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) {
if r == nil {
return nil, errors.New("Retriever called on a nil Retriever; check that all retrievers are defined")
}
return (*retrieverAction)(r).Run(ctx, req, nil)
}

Expand Down

0 comments on commit 3816b43

Please sign in to comment.