Skip to content

Commit

Permalink
[Go] replace IsKnownModel with IsDefinedModel (#615)
Browse files Browse the repository at this point in the history
IsDefinedModel is more general. You can call it right after Init
to determine if a model is known, or any time to see if you've already
defined a model, and avoid the panic from DefineModel.

Also remove KnownModels. It's unclear if it's useful. If it is, we can
always add it later.

Do the same for embedders.
  • Loading branch information
jba authored Jul 15, 2024
1 parent b19ae18 commit 8481d30
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 18 deletions.
5 changes: 5 additions & 0 deletions go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedReq
return (*Embedder)(core.DefineAction(provider, name, atype.Embedder, nil, embed))
}

// IsDefinedEmbedder reports whether an embedder is defined.
func IsDefinedEmbedder(provider, name string) bool {
return LookupEmbedder(provider, name) != nil
}

// LookupEmbedder looks up an [EmbedderAction] registered by [DefineEmbedder].
// It returns nil if the embedder was not defined.
func LookupEmbedder(provider, name string) *Embedder {
Expand Down
9 changes: 7 additions & 2 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ type ModelMetadata struct {
}

// DefineModel registers the given generate function as an action, and returns a
// [ModelAction] that runs it.
// [Model] that runs it.
func DefineModel(provider, name string, metadata *ModelMetadata, generate func(context.Context, *GenerateRequest, ModelStreamingCallback) (*GenerateResponse, error)) *Model {
metadataMap := map[string]any{}
if metadata == nil {
Expand All @@ -75,7 +75,12 @@ func DefineModel(provider, name string, metadata *ModelMetadata, generate func(c
}, generate))
}

// LookupModel looks up a [ModelAction] registered by [DefineModel].
// IsDefinedModel reports whether a model is defined.
func IsDefinedModel(provider, name string) bool {
return LookupModel(provider, name) != nil
}

// LookupModel looks up a [Model] registered by [DefineModel].
// It returns nil if the model was not defined.
func LookupModel(provider, name string) *Model {
return (*Model)(core.LookupActionFor[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk](atype.Model, provider, name))
Expand Down
26 changes: 10 additions & 16 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,15 @@ func Init(ctx context.Context, cfg *Config) (err error) {
return nil
}

// IsKnownModel reports whether a model is known to this plugin.
func IsKnownModel(name string) bool {
_, ok := knownCaps[name]
return ok
// IsDefinedModel reports whether a model is defined in this plugin.
func IsDefinedModel(name string) bool {
return ai.IsDefinedModel(provider, name)
}

// DefineModel defines an unknown model with the given name.
// The second argument describes the capability of the model.
// Use [IsKnownModel] to determine if a model is known.
// Use [IsDefinedModel] to determine if a model is already defined.
// After [Init] is called, only the known models are defined.
func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) {
state.mu.Lock()
defer state.mu.Unlock()
Expand All @@ -138,17 +138,6 @@ func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) {
return defineModel(name, mc), nil
}

// KnownModels returns a slice of all known model names.
func KnownModels() []string {
keys := make([]string, len(knownCaps))
i := 0
for k := range knownCaps {
keys[i] = k
i++
}
return keys
}

// requires state.mu
func defineModel(name string, caps ai.ModelCapabilities) *ai.Model {
meta := &ai.ModelMetadata{
Expand All @@ -169,6 +158,11 @@ func DefineEmbedder(name string) *ai.Embedder {
return defineEmbedder(name)
}

// IsDefinedEmbedder reports whether a model is defined in this plugin.
func IsDefinedEmbedder(name string) bool {
return ai.IsDefinedEmbedder(provider, name)
}

// requires state.mu
func defineEmbedder(name string) *ai.Embedder {
return ai.DefineEmbedder(provider, name, func(ctx context.Context, input *ai.EmbedRequest) (*ai.EmbedResponse, error) {
Expand Down

0 comments on commit 8481d30

Please sign in to comment.