@@ -28,8 +28,8 @@ import (
28
28
"github.com/firebase/genkit/go/internal/atype"
29
29
)
30
30
31
- // A ModelAction is used to generate content from an AI model.
32
- type ModelAction = core.Action [* GenerateRequest , * GenerateResponse , * GenerateResponseChunk ]
31
+ // A Model is used to generate content from an AI model.
32
+ type Model core.Action [* GenerateRequest , * GenerateResponse , * GenerateResponseChunk ]
33
33
34
34
// ModelStreamingCallback is the type for the streaming callback of a model.
35
35
type ModelStreamingCallback = func (context.Context , * GenerateResponseChunk ) error
@@ -50,7 +50,7 @@ type ModelMetadata struct {
50
50
51
51
// DefineModel registers the given generate function as an action, and returns a
52
52
// [ModelAction] that runs it.
53
- func DefineModel (provider , name string , metadata * ModelMetadata , generate func (context.Context , * GenerateRequest , ModelStreamingCallback ) (* GenerateResponse , error )) * ModelAction {
53
+ func DefineModel (provider , name string , metadata * ModelMetadata , generate func (context.Context , * GenerateRequest , ModelStreamingCallback ) (* GenerateResponse , error )) * Model {
54
54
metadataMap := map [string ]any {}
55
55
if metadata != nil {
56
56
if metadata .Label != "" {
@@ -64,25 +64,26 @@ func DefineModel(provider, name string, metadata *ModelMetadata, generate func(c
64
64
}
65
65
metadataMap ["supports" ] = supports
66
66
}
67
- return core .DefineStreamingAction (provider , name , atype .Model , map [string ]any {
67
+ return ( * Model )( core .DefineStreamingAction (provider , name , atype .Model , map [string ]any {
68
68
"model" : metadataMap ,
69
- }, generate )
69
+ }, generate ))
70
70
}
71
71
72
72
// LookupModel looks up a [ModelAction] registered by [DefineModel].
73
73
// It returns nil if the model was not defined.
74
- func LookupModel (provider , name string ) * ModelAction {
75
- return core .LookupActionFor [* GenerateRequest , * GenerateResponse , * GenerateResponseChunk ](atype .Model , provider , name )
74
+ func LookupModel (provider , name string ) * Model {
75
+ return ( * Model )( core .LookupActionFor [* GenerateRequest , * GenerateResponse , * GenerateResponseChunk ](atype .Model , provider , name ) )
76
76
}
77
77
78
- // Generate applies a [ModelAction ] to some input, handling tool requests.
79
- func Generate (ctx context.Context , g * ModelAction , req * GenerateRequest , cb ModelStreamingCallback ) (* GenerateResponse , error ) {
78
+ // Generate applies the [Model ] to some input, handling tool requests.
79
+ func ( m * Model ) Generate (ctx context.Context , req * GenerateRequest , cb ModelStreamingCallback ) (* GenerateResponse , error ) {
80
80
if err := conformOutput (req ); err != nil {
81
81
return nil , err
82
82
}
83
83
84
+ a := (* core.Action [* GenerateRequest , * GenerateResponse , * GenerateResponseChunk ])(m )
84
85
for {
85
- resp , err := g .Run (ctx , req , cb )
86
+ resp , err := a .Run (ctx , req , cb )
86
87
if err != nil {
87
88
return nil , err
88
89
}
0 commit comments