diff --git a/go/ai/gen.go b/go/ai/gen.go
index 927b279fc..409e4f096 100644
--- a/go/ai/gen.go
+++ b/go/ai/gen.go
@@ -16,59 +16,16 @@
 
 package ai
 
-// A Candidate is one of several possible generated responses from a generation
-// request. It contains a single generated message along with additional
-// metadata about its generation. A generation request may result in multiple Candidates.
-type Candidate struct {
-	Custom        any              `json:"custom,omitempty"`
-	FinishMessage string           `json:"finishMessage,omitempty"`
-	FinishReason  FinishReason     `json:"finishReason,omitempty"`
-	Index         int              `json:"index"`
-	Message       *Message         `json:"message,omitempty"`
-	Usage         *GenerationUsage `json:"usage,omitempty"`
-}
-
-// FinishReason is the reason why a model stopped generating tokens.
-type FinishReason string
-
-const (
-	FinishReasonStop    FinishReason = "stop"
-	FinishReasonLength  FinishReason = "length"
-	FinishReasonBlocked FinishReason = "blocked"
-	FinishReasonOther   FinishReason = "other"
-	FinishReasonUnknown FinishReason = "unknown"
-)
-
 type dataPart struct {
 	Data     any            `json:"data,omitempty"`
 	Metadata map[string]any `json:"metadata,omitempty"`
 }
 
-// A GenerateRequest is a request to generate completions from a model.
-type GenerateRequest struct {
-	// Candidates indicates the number of responses the model should generate.
-	// Normally this would be set to 1.
-	Candidates int   `json:"candidates,omitempty"`
-	Config     any   `json:"config,omitempty"`
-	Context    []any `json:"context,omitempty"`
-	// Messages is a list of messages to pass to the model. The first n-1 Messages
-	// are treated as history. The last Message is the current request.
-	Messages []*Message `json:"messages,omitempty"`
-	// Output describes the desired response format.
-	Output *GenerateRequestOutput `json:"output,omitempty"`
-	// Tools lists the available tools that the model can ask the client to run.
-	Tools []*ToolDefinition `json:"tools,omitempty"`
-}
-
-// GenerateRequestOutput describes the structure that the model's output
-// should conform to. If Format is [OutputFormatJSON], then Schema
-// can describe the desired form of the generated JSON.
-type GenerateRequestOutput struct {
+// ModelRequestOutput describes the structure that the model's output
+// should conform to. If Format is [OutputFormatJSON], then Schema
+// can describe the desired form of the generated JSON.
+type ModelRequestOutput struct {
 	Format OutputFormat   `json:"format,omitempty"`
 	Schema map[string]any `json:"schema,omitempty"`
 }
 
 // OutputFormat is the format that the model's output should produce.
 type OutputFormat string
 
 const (
@@ -77,28 +34,6 @@ const (
 	OutputFormatMedia OutputFormat = "media"
 )
 
-// A GenerateResponse is a model's response to a [GenerateRequest].
-type GenerateResponse struct {
-	// Candidates are the requested responses from the model. The length of this
-	// slice will be equal to [GenerateRequest.Candidates].
-	Candidates []*Candidate `json:"candidates,omitempty"`
-	Custom     any          `json:"custom,omitempty"`
-	// LatencyMs is the time the request took in milliseconds.
-	LatencyMs float64 `json:"latencyMs,omitempty"`
-	// Request is the [GenerateRequest] struct used to trigger this response.
-	Request *GenerateRequest `json:"request,omitempty"`
-	// Usage describes how many resources were used by this generation request.
-	Usage *GenerationUsage `json:"usage,omitempty"`
-}
-
-// A GenerateResponseChunk is the portion of the [GenerateResponse]
-// that is passed to a streaming callback.
-type GenerateResponseChunk struct {
-	Content []*Part `json:"content,omitempty"`
-	Custom  any     `json:"custom,omitempty"`
-	Index   int     `json:"index,omitempty"`
-}
-
 // GenerationCommonConfig holds configuration for generation.
 type GenerationCommonConfig struct {
 	MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
@@ -112,12 +47,16 @@ type GenerationCommonConfig struct {
 
 // GenerationUsage provides information about the generation process.
 type GenerationUsage struct {
 	Custom           map[string]float64 `json:"custom,omitempty"`
+	InputAudioFiles  float64            `json:"inputAudioFiles,omitempty"`
 	InputCharacters  int                `json:"inputCharacters,omitempty"`
 	InputImages      int                `json:"inputImages,omitempty"`
 	InputTokens      int                `json:"inputTokens,omitempty"`
+	InputVideos      float64            `json:"inputVideos,omitempty"`
+	OutputAudioFiles float64            `json:"outputAudioFiles,omitempty"`
 	OutputCharacters int                `json:"outputCharacters,omitempty"`
 	OutputImages     int                `json:"outputImages,omitempty"`
 	OutputTokens     int                `json:"outputTokens,omitempty"`
+	OutputVideos     float64            `json:"outputVideos,omitempty"`
 	TotalTokens      int                `json:"totalTokens,omitempty"`
 }
 
@@ -154,6 +93,49 @@ type ModelInfoSupports struct {
 	Tools bool `json:"tools,omitempty"`
 }
 
+// A ModelRequest is a request to generate completions from a model.
+type ModelRequest struct {
+	Config   any        `json:"config,omitempty"`
+	Context  []any      `json:"context,omitempty"`
+	Messages []*Message `json:"messages,omitempty"`
+	// Output describes the desired response format.
+	Output *ModelRequestOutput `json:"output,omitempty"`
+	// Tools lists the available tools that the model can ask the client to run.
+	Tools []*ToolDefinition `json:"tools,omitempty"`
+}
+
+// A ModelResponse is a model's response to a [ModelRequest].
+type ModelResponse struct {
+	Custom        any          `json:"custom,omitempty"`
+	FinishMessage string       `json:"finishMessage,omitempty"`
+	FinishReason  FinishReason `json:"finishReason,omitempty"`
+	// LatencyMs is the time the request took in milliseconds.
+	LatencyMs float64  `json:"latencyMs,omitempty"`
+	Message   *Message `json:"message,omitempty"`
+	// Request is the [ModelRequest] struct used to trigger this response.
+	Request *ModelRequest `json:"request,omitempty"`
+	// Usage describes how many resources were used by this generation request.
+	Usage *GenerationUsage `json:"usage,omitempty"`
+}
+
+// A ModelResponseChunk is the portion of the [ModelResponse]
+// that is passed to a streaming callback.
+type ModelResponseChunk struct {
+	Aggregated bool    `json:"aggregated,omitempty"`
+	Content    []*Part `json:"content,omitempty"`
+	Custom     any     `json:"custom,omitempty"`
+}
+
+// FinishReason is the reason why a model stopped generating tokens.
+type FinishReason string
+
+const (
+	FinishReasonStop    FinishReason = "stop"
+	FinishReasonLength  FinishReason = "length"
+	FinishReasonBlocked FinishReason = "blocked"
+	FinishReasonOther   FinishReason = "other"
+	FinishReasonUnknown FinishReason = "unknown"
+)
+
 // Role indicates which entity is responsible for the content of a message.
 type Role string
diff --git a/go/ai/generate.go b/go/ai/generate.go
index f9aa6c4e1..54b14679b 100644
--- a/go/ai/generate.go
+++ b/go/ai/generate.go
@@ -34,15 +34,15 @@ type Model interface {
 	// Name returns the registry name of the model.
 	Name() string
 	// Generate applies the [Model] to the provided request, handling tool requests and streaming.
-	Generate(ctx context.Context, req *GenerateRequest, cb ModelStreamingCallback) (*GenerateResponse, error)
+	Generate(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error)
 }
 
-type modelActionDef core.Action[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk]
+type modelActionDef core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk]
 
-type modelAction = core.Action[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk]
+type modelAction = core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk]
 
 // ModelStreamingCallback is the type for the streaming callback of a model.
-type ModelStreamingCallback = func(context.Context, *GenerateResponseChunk) error
+type ModelStreamingCallback = func(context.Context, *ModelResponseChunk) error
 
 // ModelCapabilities describes various capabilities of the model.
 type ModelCapabilities struct {
@@ -60,7 +60,7 @@ type ModelMetadata struct {
 
 // DefineModel registers the given generate function as an action, and returns a
 // [Model] that runs it.
-func DefineModel(provider, name string, metadata *ModelMetadata, generate func(context.Context, *GenerateRequest, ModelStreamingCallback) (*GenerateResponse, error)) Model {
+func DefineModel(provider, name string, metadata *ModelMetadata, generate func(context.Context, *ModelRequest, ModelStreamingCallback) (*ModelResponse, error)) Model {
 	metadataMap := map[string]any{}
 	if metadata == nil {
 		// Always make sure there's at least minimal metadata.
@@ -86,13 +86,13 @@ func DefineModel(provider, name string, metadata *ModelMetadata, generate func(c
 
 // IsDefinedModel reports whether a model is defined.
 func IsDefinedModel(provider, name string) bool {
-	return core.LookupActionFor[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk](atype.Model, provider, name) != nil
+	return core.LookupActionFor[*ModelRequest, *ModelResponse, *ModelResponseChunk](atype.Model, provider, name) != nil
 }
 
 // LookupModel looks up a [Model] registered by [DefineModel].
 // It returns nil if the model was not defined.
 func LookupModel(provider, name string) Model {
-	action := core.LookupActionFor[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk](atype.Model, provider, name)
+	action := core.LookupActionFor[*ModelRequest, *ModelResponse, *ModelResponseChunk](atype.Model, provider, name)
 	if action == nil {
 		return nil
 	}
@@ -101,7 +101,7 @@ func LookupModel(provider, name string) Model {
 
 // generateParams represents various params of the Generate call.
 type generateParams struct {
-	Request      *GenerateRequest
+	Request      *ModelRequest
 	Stream       ModelStreamingCallback
 	History      []*Message
 	SystemPrompt *Message
@@ -110,7 +110,7 @@ type generateParams struct {
 
 // GenerateOption configures params of the Generate call.
 type GenerateOption func(req *generateParams) error
 
-// WithTextPrompt adds a simple text user prompt to GenerateRequest.
+// WithTextPrompt adds a simple text user prompt to ModelRequest.
 func WithTextPrompt(prompt string) GenerateOption {
 	return func(req *generateParams) error {
 		req.Request.Messages = append(req.Request.Messages, NewUserTextMessage(prompt))
@@ -118,7 +118,7 @@ func WithTextPrompt(prompt string) GenerateOption {
 	}
 }
 
-// WithSystemPrompt adds a simple text system prompt as the first message in GenerateRequest.
+// WithSystemPrompt adds a simple text system prompt as the first message in ModelRequest.
 // System prompt will always be put first in the list of messages.
 func WithSystemPrompt(prompt string) GenerateOption {
 	return func(req *generateParams) error {
@@ -130,7 +130,7 @@ func WithSystemPrompt(prompt string) GenerateOption {
 	}
 }
 
-// WithMessages adds provided messages to GenerateRequest.
+// WithMessages adds provided messages to ModelRequest.
 func WithMessages(messages ...*Message) GenerateOption {
 	return func(req *generateParams) error {
 		req.Request.Messages = append(req.Request.Messages, messages...)
@@ -138,7 +138,7 @@ func WithMessages(messages ...*Message) GenerateOption {
 	}
 }
 
-// WithHistory adds provided history messages to the begining of GenerateRequest.Messages.
+// WithHistory adds provided history messages to the beginning of ModelRequest.Messages.
 // History messages will always be put first in the list of messages, with the
 // exception of system prompt which will always be first.
 // [WithMessages] and [WithTextPrompt] will insert messages after system prompt and history.
@@ -152,7 +152,7 @@ func WithHistory(history ...*Message) GenerateOption {
 	}
 }
 
-// WithConfig adds provided config to GenerateRequest.
+// WithConfig adds provided config to ModelRequest.
 func WithConfig(config any) GenerateOption {
 	return func(req *generateParams) error {
 		if req.Request.Config != nil {
@@ -163,15 +163,7 @@ func WithConfig(config any) GenerateOption {
 	}
 }
 
-// WithCandidates adds provided candidate count to GenerateRequest.
-func WithCandidates(c int) GenerateOption {
-	return func(req *generateParams) error {
-		req.Request.Candidates = c
-		return nil
-	}
-}
-
-// WithContext adds provided context to GenerateRequest.
+// WithContext adds provided context to ModelRequest.
 func WithContext(c ...any) GenerateOption {
 	return func(req *generateParams) error {
 		req.Request.Context = append(req.Request.Context, c...)
@@ -179,7 +171,7 @@ func WithContext(c ...any) GenerateOption {
 	}
 }
 
-// WithTools adds provided tools to GenerateRequest.
+// WithTools adds provided tools to ModelRequest.
 func WithTools(tools ...Tool) GenerateOption {
 	return func(req *generateParams) error {
 		var toolDefs []*ToolDefinition
@@ -191,14 +183,14 @@ func WithTools(tools ...Tool) GenerateOption {
 	}
 }
 
-// WithOutputSchema adds provided output schema to GenerateRequest.
+// WithOutputSchema adds provided output schema to ModelRequest.
 func WithOutputSchema(schema any) GenerateOption {
 	return func(req *generateParams) error {
 		if req.Request.Output != nil && req.Request.Output.Schema != nil {
 			return errors.New("cannot set Request.Output.Schema (WithOutputSchema) more than once")
 		}
 		if req.Request.Output == nil {
-			req.Request.Output = &GenerateRequestOutput{}
+			req.Request.Output = &ModelRequestOutput{}
 			req.Request.Output.Format = OutputFormatJSON
 		}
 		req.Request.Output.Schema = base.SchemaAsMap(base.InferJSONSchemaNonReferencing(schema))
@@ -206,11 +198,11 @@ func WithOutputSchema(schema any) GenerateOption {
 	}
 }
 
-// WithOutputFormat adds provided output format to GenerateRequest.
+// WithOutputFormat adds provided output format to ModelRequest.
 func WithOutputFormat(format OutputFormat) GenerateOption {
 	return func(req *generateParams) error {
 		if req.Request.Output == nil {
-			req.Request.Output = &GenerateRequestOutput{}
+			req.Request.Output = &ModelRequestOutput{}
 		}
 		req.Request.Output.Format = format
 		return nil
@@ -228,10 +220,10 @@ func WithStreaming(cb ModelStreamingCallback) GenerateOption {
 	}
 }
 
-// Generate run generate request for this model. Returns GenerateResponse struct.
-func Generate(ctx context.Context, m Model, opts ...GenerateOption) (*GenerateResponse, error) {
+// Generate runs a generate request for this model and returns the ModelResponse.
+func Generate(ctx context.Context, m Model, opts ...GenerateOption) (*ModelResponse, error) {
 	req := &generateParams{
-		Request: &GenerateRequest{},
+		Request: &ModelRequest{},
 	}
 	for _, with := range opts {
 		err := with(req)
@@ -263,9 +255,9 @@ func GenerateText(ctx context.Context, m Model, opts ...GenerateOption) (string,
 	return res.Text(), nil
 }
 
-// Generate run generate request for this model. Returns GenerateResponse struct.
+// GenerateData runs a generate request for this model and returns the ModelResponse.
 // TODO: Stream GenerateData with partial JSON
-func GenerateData(ctx context.Context, m Model, value any, opts ...GenerateOption) (*GenerateResponse, error) {
+func GenerateData(ctx context.Context, m Model, value any, opts ...GenerateOption) (*ModelResponse, error) {
 	opts = append(opts, WithOutputSchema(value))
 	resp, err := Generate(ctx, m, opts...)
 	if err != nil {
@@ -279,7 +271,7 @@ func GenerateData(ctx context.Context, m Model, value any, opts ...GenerateOptio
 }
 
 // Generate applies the [Action] to the provided request, handling tool requests and streaming.
-func (m *modelActionDef) Generate(ctx context.Context, req *GenerateRequest, cb ModelStreamingCallback) (*GenerateResponse, error) {
+func (m *modelActionDef) Generate(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) {
 	if m == nil {
 		return nil, errors.New("Generate called on a nil Model; check that all models are defined")
 	}
@@ -287,18 +279,18 @@ func (m *modelActionDef) Generate(ctx context.Context, req *GenerateRequest, cb
 		return nil, err
 	}
 
-	a := (*core.Action[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk])(m)
+	a := (*core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk])(m)
 	for {
 		resp, err := a.Run(ctx, req, cb)
 		if err != nil {
 			return nil, err
 		}
 
-		candidates, err := validCandidates(ctx, resp)
+		msg, err := validResponse(ctx, resp)
 		if err != nil {
 			return nil, err
 		}
-		resp.Candidates = candidates
+		resp.Message = msg
 
 		newReq, err := handleToolRequest(ctx, req, resp)
 		if err != nil {
@@ -315,7 +307,7 @@ func (m *modelActionDef) Generate(ctx context.Context, req *GenerateRequest, cb
 
 func (i *modelActionDef) Name() string { return (*modelAction)(i).Name() }
 
 // conformOutput appends a message to the request indicating conformance to the expected schema.
-func conformOutput(req *GenerateRequest) error {
+func conformOutput(req *ModelRequest) error {
 	if req.Output != nil && req.Output.Format == OutputFormatJSON && len(req.Messages) > 0 {
 		jsonBytes, err := json.Marshal(req.Output.Schema)
 		if err != nil {
@@ -329,36 +321,30 @@ func conformOutput(req *GenerateRequest) error {
 	return nil
 }
 
-// validCandidates finds all candidates that match the expected schema.
+// validResponse checks that the message matches the expected schema.
 // It will strip JSON markdown delimiters from the response.
-func validCandidates(ctx context.Context, resp *GenerateResponse) ([]*Candidate, error) {
-	var candidates []*Candidate
-	for i, c := range resp.Candidates {
-		c, err := validCandidate(c, resp.Request.Output)
-		if err == nil {
-			candidates = append(candidates, c)
-		} else {
-			logger.FromContext(ctx).Debug("candidate did not match expected schema", "index", i, "error", err.Error())
-		}
-	}
-	if len(candidates) == 0 {
-		return nil, errors.New("generation resulted in no candidates matching expected schema")
+func validResponse(ctx context.Context, resp *ModelResponse) (*Message, error) {
+	msg, err := validMessage(resp.Message, resp.Request.Output)
+	if err == nil {
+		return msg, nil
+	} else {
+		logger.FromContext(ctx).Debug("message did not match expected schema", "error", err.Error())
+		return nil, errors.New("generation did not result in a message matching expected schema")
 	}
-	return candidates, nil
 }
 
-// validCandidate will validate the candidate's response against the expected schema.
-// It will return an error if it does not match, otherwise it will return a candidate with JSON content and type.
-func validCandidate(c *Candidate, output *GenerateRequestOutput) (*Candidate, error) {
+// validMessage will validate the message against the expected schema.
+// It will return an error if it does not match, otherwise it will return a message with JSON content and type.
+func validMessage(m *Message, output *ModelRequestOutput) (*Message, error) {
 	if output != nil && output.Format == OutputFormatJSON {
-		if c.Message == nil {
-			return nil, errors.New("candidate has no message")
+		if m == nil {
+			return nil, errors.New("message is empty")
 		}
-		if len(c.Message.Content) == 0 {
-			return nil, errors.New("candidate message has no content")
+		if len(m.Content) == 0 {
+			return nil, errors.New("message has no content")
 		}
-		text := base.ExtractJSONFromMarkdown(c.Text())
+		text := base.ExtractJSONFromMarkdown(m.Text())
 		var schemaBytes []byte
 		schemaBytes, err := json.Marshal(output.Schema)
 		if err != nil {
@@ -368,19 +354,16 @@ func validCandidate(c *Candidate, output *GenerateRequestOutput) (*Candidate, er
 			return nil, err
 		}
 		// TODO: Verify that it is okay to replace all content with JSON.
-		c.Message.Content = []*Part{NewJSONPart(text)}
+		m.Content = []*Part{NewJSONPart(text)}
 	}
-	return c, nil
+	return m, nil
 }
 
 // handleToolRequest checks if a tool was requested by a model.
 // If a tool was requested, this runs the tool and returns an
-// updated GenerateRequest. If no tool was requested this returns nil.
-func handleToolRequest(ctx context.Context, req *GenerateRequest, resp *GenerateResponse) (*GenerateRequest, error) {
-	if len(resp.Candidates) == 0 {
-		return nil, nil
-	}
-	msg := resp.Candidates[0].Message
+// updated ModelRequest. If no tool was requested this returns nil.
+func handleToolRequest(ctx context.Context, req *ModelRequest, resp *ModelResponse) (*ModelRequest, error) {
+	msg := resp.Message
 	if msg == nil || len(msg.Content) == 0 {
 		return nil, nil
 	}
@@ -411,7 +394,7 @@
 		Role:    RoleTool,
 	}
 
-	// Copy the GenerateRequest rather than modifying it.
+	// Copy the ModelRequest rather than modifying it.
 	rreq := *req
 	rreq.Messages = append(slices.Clip(rreq.Messages), msg, toolResp)
 
@@ -419,27 +402,27 @@
 }
 
-// Text returns the contents of the first candidate in a
It returns an empty string if there +// [ModelResponse] as a string. It returns an empty string if there // are no candidates or if the candidate has no message. -func (gr *GenerateResponse) Text() string { - if len(gr.Candidates) == 0 { +func (gr *ModelResponse) Text() string { + if gr.Message == nil { return "" } - return gr.Candidates[0].Text() + return gr.Message.Text() } // History returns messages from the request combined with the reponse message // to represent the conversation history. -func (gr *GenerateResponse) History() []*Message { - if len(gr.Candidates) == 0 { +func (gr *ModelResponse) History() []*Message { + if gr.Message == nil { return gr.Request.Messages } - return append(gr.Request.Messages, gr.Candidates[0].Message) + return append(gr.Request.Messages, gr.Message) } // UnmarshalOutput unmarshals structured JSON output into the provided // struct pointer. -func (gr *GenerateResponse) UnmarshalOutput(v any) error { +func (gr *ModelResponse) UnmarshalOutput(v any) error { j := base.ExtractJSONFromMarkdown(gr.Text()) if j == "" { return errors.New("unable to parse JSON from response text") @@ -448,10 +431,10 @@ func (gr *GenerateResponse) UnmarshalOutput(v any) error { return nil } -// Text returns the text content of the [GenerateResponseChunk] +// Text returns the text content of the [ModelResponseChunk] // as a string. It returns an error if there is no Content // in the response chunk. -func (c *GenerateResponseChunk) Text() string { +func (c *ModelResponseChunk) Text() string { if len(c.Content) == 0 { return "" } @@ -465,21 +448,20 @@ func (c *GenerateResponseChunk) Text() string { return sb.String() } -// Text returns the contents of a [Candidate] as a string. It -// returns an empty string if the candidate has no message. -func (c *Candidate) Text() string { - msg := c.Message - if msg == nil { +// Text returns the contents of a [Message] as a string. It +// returns an empty string if the message has no content. 
+func (m *Message) Text() string {
+	if m == nil {
 		return ""
 	}
-	if len(msg.Content) == 0 {
+	if len(m.Content) == 0 {
 		return ""
 	}
-	if len(msg.Content) == 1 {
-		return msg.Content[0].Text
+	if len(m.Content) == 1 {
+		return m.Content[0].Text
 	}
 	var sb strings.Builder
-	for _, p := range msg.Content {
+	for _, p := range m.Content {
 		sb.WriteString(p.Text)
 	}
 	return sb.String()
diff --git a/go/ai/generator_test.go b/go/ai/generator_test.go
index ee1034000..fdd9a7d2f 100644
--- a/go/ai/generator_test.go
+++ b/go/ai/generator_test.go
@@ -30,10 +30,9 @@ type GameCharacter struct {
 	Backstory string
 }
 
-var echoModel = DefineModel("test", "echo", nil, func(ctx context.Context, gr *GenerateRequest, msc ModelStreamingCallback) (*GenerateResponse, error) {
+var echoModel = DefineModel("test", "echo", nil, func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
 	if msc != nil {
-		msc(ctx, &GenerateResponseChunk{
-			Index:   0,
+		msc(ctx, &ModelResponseChunk{
 			Content: []*Part{NewTextPart("stream!")},
 		})
 	}
@@ -43,13 +42,9 @@
 			textResponse += m.Content[0].Text
 		}
 	}
-	return &GenerateResponse{
+	return &ModelResponse{
 		Request: gr,
-		Candidates: []*Candidate{
-			{
-				Message: NewUserTextMessage(textResponse),
-			},
-		},
+		Message: NewUserTextMessage(textResponse),
 	}, nil
 })
 
@@ -63,27 +58,25 @@ var gablorkenTool = DefineTool("gablorken", "use when need to calculate a gablor
 	},
 )
 
-func TestValidCandidate(t *testing.T) {
+func TestValidMessage(t *testing.T) {
 	t.Parallel()
 
-	t.Run("Valid candidate with text format", func(t *testing.T) {
-		candidate := &Candidate{
-			Message: &Message{
-				Content: []*Part{
-					NewTextPart("Hello, World!"),
-				},
+	t.Run("Valid message with text format", func(t *testing.T) {
+		message := &Message{
+			Content: []*Part{
+				NewTextPart("Hello, World!"),
 			},
 		}
-		outputSchema := &GenerateRequestOutput{
+		outputSchema := &ModelRequestOutput{
 			Format: OutputFormatText,
 		}
-		_, err := validCandidate(candidate, outputSchema)
+		_, err := validMessage(message, outputSchema)
 		if err != nil {
 			t.Fatal(err)
 		}
 	})
 
-	t.Run("Valid candidate with JSON format and matching schema", func(t *testing.T) {
+	t.Run("Valid message with JSON format and matching schema", func(t *testing.T) {
 		json := `{
 			"name": "John",
 			"age": 30,
@@ -93,14 +86,12 @@ func TestValidCandidate(t *testing.T) {
 				"country": "USA"
 			}
 		}`
-		candidate := &Candidate{
-			Message: &Message{
-				Content: []*Part{
-					NewTextPart(JSONMarkdown(json)),
-				},
+		message := &Message{
+			Content: []*Part{
+				NewTextPart(JSONMarkdown(json)),
 			},
 		}
-		outputSchema := &GenerateRequestOutput{
+		outputSchema := &ModelRequestOutput{
 			Format: OutputFormatJSON,
 			Schema: map[string]any{
 				"type": "object",
@@ -121,25 +112,23 @@ func TestValidCandidate(t *testing.T) {
 				},
 			},
 		}
-		candidate, err := validCandidate(candidate, outputSchema)
+		message, err := validMessage(message, outputSchema)
 		if err != nil {
 			t.Fatal(err)
 		}
-		text := candidate.Text()
+		text := message.Text()
 		if strings.TrimSpace(text) != strings.TrimSpace(json) {
 			t.Fatalf("got %q, want %q", json, text)
 		}
 	})
 
-	t.Run("Invalid candidate with JSON format and non-matching schema", func(t *testing.T) {
-		candidate := &Candidate{
-			Message: &Message{
-				Content: []*Part{
-					NewTextPart(JSONMarkdown(`{"name": "John", "age": "30"}`)),
-				},
+	t.Run("Invalid message with JSON format and non-matching schema", func(t *testing.T) {
+		message := &Message{
+			Content: []*Part{
+				NewTextPart(JSONMarkdown(`{"name": "John", "age": "30"}`)),
"30"}`)), }, } - outputSchema := &GenerateRequestOutput{ + outputSchema := &ModelRequestOutput{ Format: OutputFormatJSON, Schema: map[string]any{ "type": "object", @@ -149,54 +138,47 @@ func TestValidCandidate(t *testing.T) { }, }, } - _, err := validCandidate(candidate, outputSchema) + _, err := validMessage(message, outputSchema) errorContains(t, err, "data did not match expected schema") }) - t.Run("Candidate with invalid JSON", func(t *testing.T) { - candidate := &Candidate{ - Message: &Message{ - Content: []*Part{ - NewTextPart(JSONMarkdown(`{"name": "John", "age": 30`)), // Missing trailing }. - }, + t.Run("Message with invalid JSON", func(t *testing.T) { + message := &Message{ + Content: []*Part{ + NewTextPart(JSONMarkdown(`{"name": "John", "age": 30`)), // Missing trailing }. }, } - outputSchema := &GenerateRequestOutput{ + outputSchema := &ModelRequestOutput{ Format: OutputFormatJSON, } - _, err := validCandidate(candidate, outputSchema) + _, err := validMessage(message, outputSchema) errorContains(t, err, "data is not valid JSON") }) - t.Run("Candidate with no message", func(t *testing.T) { - candidate := &Candidate{} - outputSchema := &GenerateRequestOutput{ + t.Run("No message", func(t *testing.T) { + outputSchema := &ModelRequestOutput{ Format: OutputFormatJSON, } - _, err := validCandidate(candidate, outputSchema) - errorContains(t, err, "candidate has no message") + _, err := validMessage(nil, outputSchema) + errorContains(t, err, "message is empty") }) - t.Run("Candidate with message but no content", func(t *testing.T) { - candidate := &Candidate{ - Message: &Message{}, - } - outputSchema := &GenerateRequestOutput{ + t.Run("Empty message", func(t *testing.T) { + message := &Message{} + outputSchema := &ModelRequestOutput{ Format: OutputFormatJSON, } - _, err := validCandidate(candidate, outputSchema) - errorContains(t, err, "candidate message has no content") + _, err := validMessage(message, outputSchema) + errorContains(t, err, "message has no content") }) t.Run("Candidate contains unexpected field", func(t *testing.T) { - candidate := &Candidate{ - Message: &Message{ - Content: []*Part{ - NewTextPart(JSONMarkdown(`{"name": "John", "height": 190}`)), - }, + message := &Message{ + Content: []*Part{ + NewTextPart(JSONMarkdown(`{"name": "John", "height": 190}`)), }, } - outputSchema := &GenerateRequestOutput{ + outputSchema := &ModelRequestOutput{ Format: OutputFormatJSON, Schema: map[string]any{ "type": "object", @@ -207,25 +189,23 @@ func TestValidCandidate(t *testing.T) { "additionalProperties": false, }, } - _, err := validCandidate(candidate, outputSchema) + _, err := validMessage(message, outputSchema) errorContains(t, err, "data did not match expected schema") }) t.Run("Invalid expected schema", func(t *testing.T) { - candidate := &Candidate{ - Message: &Message{ - Content: []*Part{ - NewTextPart(JSONMarkdown(`{"name": "John", "age": 30}`)), - }, + message := &Message{ + Content: []*Part{ + NewTextPart(JSONMarkdown(`{"name": "John", "age": 30}`)), }, } - outputSchema := &GenerateRequestOutput{ + outputSchema := &ModelRequestOutput{ Format: OutputFormatJSON, Schema: map[string]any{ "type": "invalid", }, } - _, err := validCandidate(candidate, outputSchema) + _, err := validMessage(message, outputSchema) errorContains(t, err, "failed to validate data against expected schema") }) } @@ -235,7 +215,7 @@ func TestGenerate(t *testing.T) { charJSON := "{\"Name\": \"foo\", \"Backstory\": \"bar\"}" charJSONmd := "```json" + charJSON + "```" wantText := charJSON - wantRequest := 
+	wantRequest := &ModelRequest{
 		Messages: []*Message{
 			// system prompt -- always first
 			{
 				Role: RoleSystem,
@@ -274,10 +254,9 @@ func TestGenerate(t *testing.T) {
 				},
 			},
 		},
-		Config:     GenerationCommonConfig{Temperature: 1},
-		Candidates: 3,
-		Context:    []any{[]any{string("Banana")}},
-		Output: &GenerateRequestOutput{
+		Config:  GenerationCommonConfig{Temperature: 1},
+		Context: []any{[]any{string("Banana")}},
+		Output: &ModelRequestOutput{
 			Format: "json",
 			Schema: map[string]any{
 				"$id": string("https://github.com/firebase/genkit/go/ai/game-character"),
@@ -321,11 +300,10 @@ func TestGenerate(t *testing.T) {
 			Temperature: 1,
 		}),
 		WithHistory(NewUserTextMessage("banana"), NewModelTextMessage("yes, banana")),
-		WithCandidates(3),
 		WithContext([]any{"Banana"}),
 		WithOutputSchema(&GameCharacter{}),
 		WithTools(gablorkenTool),
-		WithStreaming(func(ctx context.Context, grc *GenerateResponseChunk) error {
+		WithStreaming(func(ctx context.Context, grc *ModelResponseChunk) error {
 			streamText += grc.Text()
 			return nil
 		}),
@@ -341,7 +319,7 @@ func TestGenerate(t *testing.T) {
 		t.Errorf("Text() diff (+got -want):\n%s", diff)
 	}
 	if diff := cmp.Diff(res.Request, wantRequest, test_utils.IgnoreNoisyParts([]string{
-		"{*ai.GenerateRequest}.Messages[4].Content[1].Text",
+		"{*ai.ModelRequest}.Messages[4].Content[1].Text",
 	})); diff != "" {
 		t.Errorf("Request diff (+got -want):\n%s", diff)
 	}
diff --git a/go/ai/prompt.go b/go/ai/prompt.go
index a04dee1f2..a6e5e73ae 100644
--- a/go/ai/prompt.go
+++ b/go/ai/prompt.go
@@ -26,14 +26,14 @@ import (
 
 // A Prompt is used to render a prompt template,
-// producing a [GenerateRequest] that may be passed to a [Model].
-type Prompt core.Action[any, *GenerateRequest, struct{}]
+// producing a [ModelRequest] that may be passed to a [Model].
+type Prompt core.Action[any, *ModelRequest, struct{}]
 
 // DefinePrompt takes a function that renders a prompt template
-// into a [GenerateRequest] that may be passed to a [Model].
+// into a [ModelRequest] that may be passed to a [Model].
 // The prompt expects some input described by inputSchema.
 // DefinePrompt registers the function as an action,
 // and returns a [Prompt] that runs it.
-func DefinePrompt(provider, name string, metadata map[string]any, inputSchema *jsonschema.Schema, render func(context.Context, any) (*GenerateRequest, error)) *Prompt {
+func DefinePrompt(provider, name string, metadata map[string]any, inputSchema *jsonschema.Schema, render func(context.Context, any) (*ModelRequest, error)) *Prompt {
 	mm := maps.Clone(metadata)
 	if mm == nil {
 		mm = make(map[string]any)
@@ -50,13 +50,13 @@ func IsDefinedPrompt(provider, name string) bool {
 
 // LookupPrompt looks up a [Prompt] registered by [DefinePrompt].
 // It returns nil if the prompt was not defined.
 func LookupPrompt(provider, name string) *Prompt {
-	return (*Prompt)(core.LookupActionFor[any, *GenerateRequest, struct{}](atype.Prompt, provider, name))
+	return (*Prompt)(core.LookupActionFor[any, *ModelRequest, struct{}](atype.Prompt, provider, name))
 }
 
 // Render renders the [Prompt] with some input data.
-func (p *Prompt) Render(ctx context.Context, input any) (*GenerateRequest, error) {
+func (p *Prompt) Render(ctx context.Context, input any) (*ModelRequest, error) {
 	if p == nil {
 		return nil, errors.New("Render called on a nil Prompt; check that all prompts are defined")
 	}
-	return (*core.Action[any, *GenerateRequest, struct{}])(p).Run(ctx, input, nil)
+	return (*core.Action[any, *ModelRequest, struct{}])(p).Run(ctx, input, nil)
 }
diff --git a/go/ai/request_helpers.go b/go/ai/request_helpers.go
index 1f722e003..3632e5c6a 100644
--- a/go/ai/request_helpers.go
+++ b/go/ai/request_helpers.go
@@ -14,10 +14,10 @@
 
 package ai
 
-// NewGenerateRequest create a new GenerateRequest with provided config and
+// NewModelRequest creates a new ModelRequest with the provided config and
 // messages.
-func NewGenerateRequest(config any, messages ...*Message) *GenerateRequest {
-	return &GenerateRequest{
+func NewModelRequest(config any, messages ...*Message) *ModelRequest {
+	return &ModelRequest{
 		Config:   config,
 		Messages: messages,
 	}
diff --git a/go/core/schemas.config b/go/core/schemas.config
index a65858193..ecf4b8784 100644
--- a/go/core/schemas.config
+++ b/go/core/schemas.config
@@ -125,13 +125,15 @@
 the results of running a specific tool on the arguments passed to the client
 by the model in a [ToolRequest].
 .
 
-Candidate pkg ai
-CandidateFinishReason pkg ai
+Candidate omit
+CandidateFinishReason omit
 DocumentData pkg ai
-GenerateResponse pkg ai
-GenerateResponseChunk pkg ai
-GenerateRequest pkg ai
+GenerateResponse omit
+GenerateResponseChunk omit
+GenerateResponseFinishReason name FinishReason
+GenerateRequest omit
 GenerateRequestOutput pkg ai
+GenerateRequestOutput name ModelRequestOutput
 GenerateRequestOutputFormat pkg ai
 GenerationUsage pkg ai
 GenerationUsage.inputCharacters type int
@@ -175,63 +177,79 @@
 RoleUser pkg ai
 RoleModel pkg ai
 RoleTool pkg ai
 
+# ModelRequest
+ModelRequest pkg ai
+ModelRequest.config type any
+ModelRequest.context type []any
+ModelRequest.messages type []*Message
+ModelRequest.output type *ModelRequestOutput
+ModelRequest.tools type []*ToolDefinition
+
+# ModelResponse
+ModelResponse pkg ai
+ModelResponse.custom type any
+ModelResponse.finishMessage type string
+ModelResponseFinishReason pkg ai
+ModelResponseFinishReason name FinishReason
+ModelResponse.latencyMs type float64
+ModelResponse.message type *Message
+ModelResponse.request type *ModelRequest
+ModelResponse.usage type *GenerationUsage
+
+# ModelResponseChunk
+ModelResponseChunk pkg ai
+ModelResponseChunk.aggregated type bool
+ModelResponseChunk.content type []*Part
+ModelResponseChunk.custom type any
-
-GenerateRequest doc
-A GenerateRequest is a request to generate completions from a model.
+GenerationCommonConfig doc
+GenerationCommonConfig holds configuration for generation.
 .
-GenerateRequest.candidates doc
-Candidates indicates the number of responses the model should generate.
-Normally this would be set to 1.
+
+Message doc
+Message is the contents of a model response.
 .
-GenerateRequest.output doc
+
+ToolDefinition doc
+A ToolDefinition describes a tool.
+.
+
+DataPart/properties/metadata name map[string]any
+
+ModelRequest doc
+A ModelRequest is a request to generate completions from a model.
+.
+ModelRequest.output doc
 Output describes the desired response format.
 .
-GenerateRequest.tools doc
+ModelRequest.tools doc
 Tools lists the available tools that the model can ask the client to run.
 .
-GenerateRequestOutput doc
-GenerateRequestOutput describes the structure that the model's output
+ModelRequestOutput doc
+ModelRequestOutput describes the structure that the model's output
 should conform to. If Format is [OutputFormatJSON], then Schema
 can describe the desired form of the generated JSON.
 .
-GenerateRequestOutputFormat doc
+ModelRequestOutputFormat doc
 OutputFormat is the format that the model's output should produce.
 .
-GenerateResponse doc
-A GenerateResponse is a model's response to a [GenerateRequest].
-.
-GenerateResponse.candidates doc
-Candidates are the requested responses from the model. The length of this
-slice will be equal to [GenerateRequest.Candidates].
+ModelResponse doc
+A ModelResponse is a model's response to a [ModelRequest].
 .
-GenerateResponse.latencyMs doc
+ModelResponse.latencyMs doc
 LatencyMs is the time the request took in milliseconds.
 .
-GenerateResponse.request doc
-Request is the [GenerateRequest] struct used to trigger this response.
+ModelResponse.request doc
+Request is the [ModelRequest] struct used to trigger this response.
 .
-GenerateResponse.usage doc
+ModelResponse.usage doc
 Usage describes how many resources were used by this generation request.
 .
-GenerateResponseChunk doc
-A GenerateResponseChunk is the portion of the [GenerateResponse]
+ModelResponseChunk doc
+A ModelResponseChunk is the portion of the [ModelResponse]
 that is passed to a streaming callback.
-.
-
-GenerationCommonConfig doc
-GenerationCommonConfig holds configuration for generation.
-.
-
-Message doc
-Message is the contents of a model response.
-.
-
-ToolDefinition doc
-A ToolDefinition describes a tool.
-.
-
-DataPart/properties/metadata name map[string]any
+.
\ No newline at end of file
diff --git a/go/internal/doc-snippets/modelplugin/modelplugin.go b/go/internal/doc-snippets/modelplugin/modelplugin.go
index cad24a500..39a3b285c 100644
--- a/go/internal/doc-snippets/modelplugin/modelplugin.go
+++ b/go/internal/doc-snippets/modelplugin/modelplugin.go
@@ -26,9 +26,10 @@ const providerID = "mymodels"
 
 // [START cfg]
 type MyModelConfig struct {
 	ai.GenerationCommonConfig
-	CustomOption int
+	CustomOption        int
 	AnotherCustomOption string
 }
+
 // [END cfg]
 
 func Init() error {
@@ -45,16 +46,16 @@ func Init() error {
 		},
 	},
 		func(ctx context.Context,
-			genRequest *ai.GenerateRequest,
+			genRequest *ai.ModelRequest,
 			_ ai.ModelStreamingCallback,
-		) (*ai.GenerateResponse, error) {
+		) (*ai.ModelResponse, error) {
 			// Verify that the request includes a configuration that conforms to
 			// your schema.
 			if _, ok := genRequest.Config.(MyModelConfig); !ok {
 				return nil, fmt.Errorf("request config must be type MyModelConfig")
 			}
 
-			// Use your custom logic to convert Genkit's ai.GenerateRequest
+			// Use your custom logic to convert Genkit's ai.ModelRequest
 			// into a form usable by the model's native API.
 			apiRequest, err := apiRequestFromGenkitRequest(genRequest)
 			if err != nil {
@@ -69,7 +70,7 @@ func Init() error {
 			}
 
-			// Use your custom logic to convert the model's response to Genkin's
-			// ai.GenerateResponse.
+			// Use your custom logic to convert the model's response to Genkit's
+			// ai.ModelResponse.
 			response, err := genResponseFromAPIResponse(apiResponse)
 			if err != nil {
 				return nil, err
 			}
@@ -83,7 +84,7 @@ func Init() error {
 	return nil
 }
 
-func genResponseFromAPIResponse(apiResponse string) (*ai.GenerateResponse, error) {
+func genResponseFromAPIResponse(apiResponse string) (*ai.ModelResponse, error) {
 	panic("unimplemented")
 }
 
@@ -91,6 +92,6 @@ func callModelAPI(apiRequest string) (string, error) {
 	panic("unimplemented")
 }
 
-func apiRequestFromGenkitRequest(genRequest *ai.GenerateRequest) (string, error) {
+func apiRequestFromGenkitRequest(genRequest *ai.ModelRequest) (string, error) {
 	panic("unimplemented")
 }
diff --git a/go/internal/doc-snippets/models.go b/go/internal/doc-snippets/models.go
index b16025f5a..e7edf5587 100644
--- a/go/internal/doc-snippets/models.go
+++ b/go/internal/doc-snippets/models.go
@@ -82,7 +82,7 @@ func streaming() error {
 		ai.WithTextPrompt("Tell a long story about robots and ninjas."),
 		// stream callback
 		ai.WithStreaming(
-			func(ctx context.Context, grc *ai.GenerateResponseChunk) error {
+			func(ctx context.Context, grc *ai.ModelResponseChunk) error {
 				fmt.Printf("Chunk: %s\n", grc.Text())
 				return nil
 			}))
@@ -147,7 +147,7 @@ func history() error {
 	// [END hist1]
 	_ = err
 	// [START hist2]
-	history = append(history, response.Candidates[0].Message)
+	history = append(history, response.Message)
 	// [END hist2]
 
 	// [START hist3]
diff --git a/go/internal/doc-snippets/prompts.go b/go/internal/doc-snippets/prompts.go
index 492a4b0e8..ab066f1a1 100644
--- a/go/internal/doc-snippets/prompts.go
+++ b/go/internal/doc-snippets/prompts.go
@@ -64,7 +64,7 @@ func pr03() error {
 		"helloPrompt",
 		nil, // Additional model config
 		jsonschema.Reflect(&HelloPromptInput{}),
-		func(ctx context.Context, input any) (*ai.GenerateRequest, error) {
+		func(ctx context.Context, input any) (*ai.ModelRequest, error) {
 			params, ok := input.(HelloPromptInput)
 			if !ok {
 				return nil, errors.New("input doesn't satisfy schema")
@@ -72,7 +72,7 @@ func pr03() error {
 			prompt := fmt.Sprintf(
 				"You are a helpful AI assistant named Walt. Say hello to %s.",
 				params.UserName)
-			return &ai.GenerateRequest{Messages: []*ai.Message{
+			return &ai.ModelRequest{Messages: []*ai.Message{
 				{Content: []*ai.Part{ai.NewTextPart(prompt)}},
 			}}, nil
 		},
diff --git a/go/plugins/dotprompt/genkit.go b/go/plugins/dotprompt/genkit.go
index f139b9df6..43aaf842b 100644
--- a/go/plugins/dotprompt/genkit.go
+++ b/go/plugins/dotprompt/genkit.go
@@ -98,10 +98,10 @@ fieldLoop:
 	return m, nil
 }
 
-// buildRequest prepares an [ai.GenerateRequest] based on the prompt,
+// buildRequest prepares an [ai.ModelRequest] based on the prompt,
 // using the input variables and other information in the [ai.PromptRequest].
-func (p *Prompt) buildRequest(ctx context.Context, input any) (*ai.GenerateRequest, error) {
-	req := &ai.GenerateRequest{}
+func (p *Prompt) buildRequest(ctx context.Context, input any) (*ai.ModelRequest, error) {
+	req := &ai.ModelRequest{}
 
 	m, err := p.buildVariables(input)
 	if err != nil {
@@ -111,11 +111,6 @@
 		return nil, err
 	}
 
-	req.Candidates = p.Candidates
-	if req.Candidates == 0 {
-		req.Candidates = 1
-	}
-
 	req.Config = p.GenerationConfig
 
 	var outputSchema map[string]any
@@ -130,7 +125,7 @@
 		}
 	}
 
-	req.Output = &ai.GenerateRequestOutput{
+	req.Output = &ai.ModelRequestOutput{
 		Format: p.OutputFormat,
 		Schema: outputSchema,
 	}
@@ -179,10 +174,10 @@
 // the prompt.
 //
 // This implements the [ai.Prompt] interface.
-func (p *Prompt) Generate(ctx context.Context, pr *PromptRequest, cb func(context.Context, *ai.GenerateResponseChunk) error) (*ai.GenerateResponse, error) {
+func (p *Prompt) Generate(ctx context.Context, pr *PromptRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
 	tracing.SetCustomMetadataAttr(ctx, "subtype", "prompt")
 
-	var genReq *ai.GenerateRequest
+	var genReq *ai.ModelRequest
 	var err error
 	if p.prompt != nil {
 		genReq, err = p.prompt.Render(ctx, pr.Variables)
@@ -194,9 +189,6 @@
 	}
 
 	// Let some fields in pr override those in the prompt config.
-	if pr.Candidates != 0 {
-		genReq.Candidates = pr.Candidates
-	}
 	if pr.Config != nil {
 		genReq.Config = pr.Config
 	}
diff --git a/go/plugins/dotprompt/genkit_test.go b/go/plugins/dotprompt/genkit_test.go
index a724d28d8..438f35dee 100644
--- a/go/plugins/dotprompt/genkit_test.go
+++ b/go/plugins/dotprompt/genkit_test.go
@@ -22,18 +22,14 @@ import (
 	"github.com/firebase/genkit/go/ai"
 )
 
-func testGenerate(ctx context.Context, req *ai.GenerateRequest, cb func(context.Context, *ai.GenerateResponseChunk) error) (*ai.GenerateResponse, error) {
+func testGenerate(ctx context.Context, req *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
 	input := req.Messages[0].Content[0].Text
 	output := fmt.Sprintf("AI reply to %q", input)
 
-	r := &ai.GenerateResponse{
-		Candidates: []*ai.Candidate{
-			{
-				Message: &ai.Message{
-					Content: []*ai.Part{
-						ai.NewTextPart(output),
-					},
-				},
+	r := &ai.ModelResponse{
+		Message: &ai.Message{
+			Content: []*ai.Part{
+				ai.NewTextPart(output),
 			},
 		},
 		Request: req,
@@ -67,24 +63,17 @@ func TestExecute(t *testing.T) {
 	})
 }
 
-func assertResponse(t *testing.T, resp *ai.GenerateResponse) {
-	if len(resp.Candidates) != 1 {
-		t.Errorf("got %d candidates, want 1", len(resp.Candidates))
-		if len(resp.Candidates) < 1 {
-			t.FailNow()
-		}
-	}
-	msg := resp.Candidates[0].Message
-	if msg == nil {
+func assertResponse(t *testing.T, resp *ai.ModelResponse) {
+	if resp.Message == nil {
 		t.Fatal("response has no message")
 	}
-	if len(msg.Content) != 1 {
-		t.Errorf("got %d message parts, want 1", len(msg.Content))
-		if len(msg.Content) < 1 {
+	if len(resp.Message.Content) != 1 {
+		t.Errorf("got %d message parts, want 1", len(resp.Message.Content))
+		if len(resp.Message.Content) < 1 {
 			t.FailNow()
 		}
 	}
-	got := msg.Content[0].Text
+	got := resp.Message.Content[0].Text
 	want := `AI reply to "TestExecute"`
 	if got != want {
 		t.Errorf("fake model replied with %q, want %q", got, want)
 	}
diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go
index 0a5bc6067..20561e94a 100644
--- a/go/plugins/googleai/googleai.go
+++ b/go/plugins/googleai/googleai.go
@@ -149,9 +149,9 @@ func defineModel(name string, caps ai.ModelCapabilities) ai.Model {
 	}
 	return ai.DefineModel(provider, name, meta, func(
 		ctx context.Context,
-		input *ai.GenerateRequest,
-		cb func(context.Context, *ai.GenerateResponseChunk) error,
-	) (*ai.GenerateResponse, error) {
+		input *ai.ModelRequest,
+		cb func(context.Context, *ai.ModelResponseChunk) error,
+	) (*ai.ModelResponse, error) {
 		return generate(ctx, state.gclient, name, input, cb)
 	})
 }
@@ -229,9 +229,9 @@ func generate(
 	ctx context.Context,
 	client *genai.Client,
 	model string,
-	input *ai.GenerateRequest,
-	cb func(context.Context, *ai.GenerateResponseChunk) error,
-) (*ai.GenerateResponse, error) {
+	input *ai.ModelRequest,
+	cb func(context.Context, *ai.ModelResponseChunk) error,
+) (*ai.ModelResponse, error) {
 	gm, err := newModel(client, model, input)
 	if err != nil {
 		return nil, err
@@ -272,7 +272,7 @@ func generate(
 
 	// Streaming version.
 	iter := cs.SendMessageStream(ctx, parts...)
-	var r *ai.GenerateResponse
+	var r *ai.ModelResponse
 	for {
 		chunk, err := iter.Next()
 		if err == iterator.Done {
@@ -285,9 +285,8 @@ func generate(
 		// Send candidates to the callback.
 		for _, c := range chunk.Candidates {
 			tc := translateCandidate(c)
-			err := cb(ctx, &ai.GenerateResponseChunk{
+			err := cb(ctx, &ai.ModelResponseChunk{
 				Content: tc.Message.Content,
-				Index:   tc.Index,
 			})
 			if err != nil {
 				return nil, err
@@ -297,15 +296,15 @@ func generate(
 	if r == nil {
 		// No candidates were returned. Probably rare, but it might avoid a NPE
 		// to return an empty instead of nil result.
-		r = &ai.GenerateResponse{}
+		r = &ai.ModelResponse{}
 	}
 	r.Request = input
 	return r, nil
 }
 
-func newModel(client *genai.Client, model string, input *ai.GenerateRequest) (*genai.GenerativeModel, error) {
+func newModel(client *genai.Client, model string, input *ai.ModelRequest) (*genai.GenerativeModel, error) {
 	gm := client.GenerativeModel(model)
-	gm.SetCandidateCount(int32(input.Candidates))
+	gm.SetCandidateCount(1)
 	if c, ok := input.Config.(*ai.GenerationCommonConfig); ok && c != nil {
 		if c.MaxOutputTokens != 0 {
 			gm.SetMaxOutputTokens(int32(c.MaxOutputTokens))
@@ -341,7 +340,7 @@ func newModel(client *genai.Client, model string, input *ai.GenerateRequest) (*g
 }
 
 // startChat starts a chat session and configures it with the input messages.
-func startChat(gm *genai.GenerativeModel, input *ai.GenerateRequest) (*genai.ChatSession, error) {
+func startChat(gm *genai.GenerativeModel, input *ai.ModelRequest) (*genai.ChatSession, error) {
 	cs := gm.StartChat()
 
 	// All but the last message goes in the history field.
@@ -470,26 +469,25 @@ func castToStringArray(i []any) []string {
 
 //copy:start vertexai.go translateCandidate
 
-// translateCandidate translates from a genai.GenerateContentResponse to an ai.GenerateResponse.
-func translateCandidate(cand *genai.Candidate) *ai.Candidate {
-	c := &ai.Candidate{}
-	c.Index = int(cand.Index)
+// translateCandidate translates from a genai.Candidate to an ai.ModelResponse.
+func translateCandidate(cand *genai.Candidate) *ai.ModelResponse {
+	m := &ai.ModelResponse{}
 	switch cand.FinishReason {
 	case genai.FinishReasonStop:
-		c.FinishReason = ai.FinishReasonStop
+		m.FinishReason = ai.FinishReasonStop
 	case genai.FinishReasonMaxTokens:
-		c.FinishReason = ai.FinishReasonLength
+		m.FinishReason = ai.FinishReasonLength
 	case genai.FinishReasonSafety:
-		c.FinishReason = ai.FinishReasonBlocked
+		m.FinishReason = ai.FinishReasonBlocked
 	case genai.FinishReasonRecitation:
-		c.FinishReason = ai.FinishReasonBlocked
+		m.FinishReason = ai.FinishReasonBlocked
 	case genai.FinishReasonOther:
-		c.FinishReason = ai.FinishReasonOther
+		m.FinishReason = ai.FinishReasonOther
 	default: // Unspecified
-		c.FinishReason = ai.FinishReasonUnknown
+		m.FinishReason = ai.FinishReasonUnknown
 	}
-	m := &ai.Message{}
-	m.Role = ai.Role(cand.Content.Role)
+	msg := &ai.Message{}
+	msg.Role = ai.Role(cand.Content.Role)
 	for _, part := range cand.Content.Parts {
 		var p *ai.Part
 		switch part := part.(type) {
@@ -505,22 +503,20 @@
 		default:
 			panic(fmt.Sprintf("unknown part %#v", part))
 		}
-		m.Content = append(m.Content, p)
+		msg.Content = append(msg.Content, p)
 	}
-	c.Message = m
-	return c
+	m.Message = msg
+	return m
 }
 
 //copy:stop
 
 //copy:start vertexai.go translateResponse
 
-// Translate from a genai.GenerateContentResponse to a ai.GenerateResponse.
-func translateResponse(resp *genai.GenerateContentResponse) *ai.GenerateResponse {
-	r := &ai.GenerateResponse{}
-	for _, c := range resp.Candidates {
-		r.Candidates = append(r.Candidates, translateCandidate(c))
-	}
+// Translate from a genai.GenerateContentResponse to an ai.ModelResponse.
+func translateResponse(resp *genai.GenerateContentResponse) *ai.ModelResponse {
+	r := translateCandidate(resp.Candidates[0])
+
 	r.Usage = &ai.GenerationUsage{}
 	if u := resp.UsageMetadata; u != nil {
 		r.Usage.InputTokens = int(u.PromptTokenCount)
diff --git a/go/plugins/googleai/googleai_test.go b/go/plugins/googleai/googleai_test.go
index d67ac258f..01e6804ee 100644
--- a/go/plugins/googleai/googleai_test.go
+++ b/go/plugins/googleai/googleai_test.go
@@ -85,11 +85,11 @@ func TestLive(t *testing.T) {
 		}
 	})
 	t.Run("generate", func(t *testing.T) {
-		resp, err := ai.Generate(ctx, model, ai.WithCandidates(1), ai.WithTextPrompt("Which country was Napoleon the emperor of?"))
+		resp, err := ai.Generate(ctx, model, ai.WithTextPrompt("Which country was Napoleon the emperor of?"))
 		if err != nil {
 			t.Fatal(err)
 		}
-		out := resp.Candidates[0].Message.Content[0].Text
+		out := resp.Message.Content[0].Text
 		const want = "France"
 		if out != want {
 			t.Errorf("got %q, expecting %q", out, want)
@@ -105,9 +105,8 @@ func TestLive(t *testing.T) {
 		out := ""
 		parts := 0
 		final, err := ai.Generate(ctx, model,
-			ai.WithCandidates(1),
 			ai.WithTextPrompt("Write one paragraph about the North Pole."),
-			ai.WithStreaming(func(ctx context.Context, c *ai.GenerateResponseChunk) error {
+			ai.WithStreaming(func(ctx context.Context, c *ai.ModelResponseChunk) error {
 				parts++
 				out += c.Content[0].Text
 				return nil
@@ -116,7 +115,7 @@ func TestLive(t *testing.T) {
 			t.Fatal(err)
 		}
 		out2 := ""
-		for _, p := range final.Candidates[0].Message.Content {
+		for _, p := range final.Message.Content {
 			out2 += p.Text
 		}
 		if out != out2 {
@@ -136,7 +135,6 @@ func TestLive(t *testing.T) {
 	})
 	t.Run("tool", func(t *testing.T) {
 		resp, err := ai.Generate(ctx, model,
-			ai.WithCandidates(1),
 			ai.WithTextPrompt("what is a gablorken of 2 over 3.5?"),
 			ai.WithTools(gablorkenTool))
 
@@ -144,7 +142,7 @@ func TestLive(t *testing.T) {
 			t.Fatal(err)
 		}
-		out := resp.Candidates[0].Message.Content[0].Text
+		out := resp.Message.Content[0].Text
 		const want = "12.25"
 		if !strings.Contains(out, want) {
 			t.Errorf("got %q, expecting it to contain %q", out, want)
diff --git a/go/plugins/ollama/ollama.go b/go/plugins/ollama/ollama.go
index aa9d1f0ca..bab805973 100644
--- a/go/plugins/ollama/ollama.go
+++ b/go/plugins/ollama/ollama.go
@@ -120,7 +120,7 @@ type ollamaChatRequest struct {
 	Stream bool `json:"stream"`
 }
 
-type ollamaGenerateRequest struct {
+type ollamaModelRequest struct {
 	System string   `json:"system,omitempty"`
 	Images []string `json:"images,omitempty"`
 	Model  string   `json:"model"`
@@ -138,7 +138,7 @@ type ollamaChatResponse struct {
 	} `json:"message"`
 }
 
-type ollamaGenerateResponse struct {
+type ollamaModelResponse struct {
 	Model     string `json:"model"`
 	CreatedAt string `json:"created_at"`
 	Response  string `json:"response"`
@@ -168,7 +168,7 @@ func Init(ctx context.Context, cfg *Config) (err error) {
 }
 
 // generate makes a request to the Ollama API and processes the response.
-func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb func(context.Context, *ai.GenerateResponseChunk) error) (*ai.GenerateResponse, error) {
+func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
 	stream := cb != nil
 	var payload any
@@ -178,7 +178,7 @@
 		if err != nil {
 			return nil, fmt.Errorf("failed to grab image parts: %v", err)
 		}
-		payload = ollamaGenerateRequest{
+		payload = ollamaModelRequest{
 			Model:  g.model.Name,
 			Prompt: concatMessages(input, []ai.Role{ai.RoleUser, ai.RoleModel, ai.RoleTool}),
 			System: concatMessages(input, []ai.Role{ai.RoleSystem}),
@@ -232,11 +232,11 @@
 	if resp.StatusCode != http.StatusOK {
 		return nil, fmt.Errorf("server returned non-200 status: %d, body: %s", resp.StatusCode, body)
 	}
-	var response *ai.GenerateResponse
+	var response *ai.ModelResponse
 	if isChatModel {
 		response, err = translateChatResponse(body)
 	} else {
-		response, err = translateGenerateResponse(body)
+		response, err = translateModelResponse(body)
 	}
 	response.Request = input
 	if err != nil {
@@ -244,11 +244,11 @@
 	}
 	return response, nil
 	} else {
-		var chunks []*ai.GenerateResponseChunk
+		var chunks []*ai.ModelResponseChunk
 		scanner := bufio.NewScanner(resp.Body)
 		for scanner.Scan() {
 			line := scanner.Text()
-			var chunk *ai.GenerateResponseChunk
+			var chunk *ai.ModelResponseChunk
 			if isChatModel {
 				chunk, err = translateChatChunk(line)
 			} else {
@@ -264,20 +264,16 @@
 			return nil, fmt.Errorf("reading response stream: %v", err)
 		}
 		// Create a final response with the merged chunks
-		finalResponse := &ai.GenerateResponse{
-			Request: input,
-			Candidates: []*ai.Candidate{
-				{
-					FinishReason: ai.FinishReason("stop"),
-					Message: &ai.Message{
-						Role: ai.RoleModel,
-					},
-				},
+		finalResponse := &ai.ModelResponse{
+			Request:      input,
+			FinishReason: ai.FinishReason("stop"),
+			Message: &ai.Message{
+				Role: ai.RoleModel,
 			},
 		}
-		// Add all the merged content to the final response's candidate
+		// Add all the merged content to the final response's message
 		for _, chunk := range chunks {
-			finalResponse.Candidates[0].Message.Content = append(finalResponse.Candidates[0].Message.Content, chunk.Content...)
+			finalResponse.Message.Content = append(finalResponse.Message.Content, chunk.Content...)
 		}
 		return finalResponse, nil // Return the final merged response
@@ -308,72 +304,72 @@ func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
 }
 
 // translateChatResponse translates Ollama chat response into a genkit response.
-func translateChatResponse(responseData []byte) (*ai.GenerateResponse, error) {
+func translateChatResponse(responseData []byte) (*ai.ModelResponse, error) {
 	var response ollamaChatResponse
 	if err := json.Unmarshal(responseData, &response); err != nil {
 		return nil, fmt.Errorf("failed to parse response JSON: %v", err)
 	}
-	generateResponse := &ai.GenerateResponse{}
-	aiCandidate := &ai.Candidate{
+	modelResponse := &ai.ModelResponse{
 		FinishReason: ai.FinishReason("stop"),
 		Message: &ai.Message{
 			Role: ai.Role(response.Message.Role),
 		},
 	}
+
 	aiPart := ai.NewTextPart(response.Message.Content)
-	aiCandidate.Message.Content = append(aiCandidate.Message.Content, aiPart)
-	generateResponse.Candidates = append(generateResponse.Candidates, aiCandidate)
-	return generateResponse, nil
+	modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
+
+	return modelResponse, nil
 }
 
-// translateResponse translates Ollama generate response into a genkit response.
-func translateGenerateResponse(responseData []byte) (*ai.GenerateResponse, error) {
-	var response ollamaGenerateResponse
+// translateModelResponse translates an Ollama generate response into a genkit response.
+func translateModelResponse(responseData []byte) (*ai.ModelResponse, error) {
+	var response ollamaModelResponse
 	if err := json.Unmarshal(responseData, &response); err != nil {
 		return nil, fmt.Errorf("failed to parse response JSON: %v", err)
 	}
-	generateResponse := &ai.GenerateResponse{}
-	aiCandidate := &ai.Candidate{
+
+	modelResponse := &ai.ModelResponse{
 		FinishReason: ai.FinishReason("stop"),
 		Message: &ai.Message{
 			Role: ai.RoleModel,
 		},
 	}
+
 	aiPart := ai.NewTextPart(response.Response)
-	aiCandidate.Message.Content = append(aiCandidate.Message.Content, aiPart)
-	generateResponse.Candidates = append(generateResponse.Candidates, aiCandidate)
-	generateResponse.Usage = &ai.GenerationUsage{} // TODO: can we get any of this info?
-	return generateResponse, nil
+	modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
+	modelResponse.Usage = &ai.GenerationUsage{} // TODO: can we get any of this info?
+	return modelResponse, nil
 }
 
-func translateChatChunk(input string) (*ai.GenerateResponseChunk, error) {
+func translateChatChunk(input string) (*ai.ModelResponseChunk, error) {
 	var response ollamaChatResponse
 	if err := json.Unmarshal([]byte(input), &response); err != nil {
 		return nil, fmt.Errorf("failed to parse response JSON: %v", err)
 	}
-	chunk := &ai.GenerateResponseChunk{}
+	chunk := &ai.ModelResponseChunk{}
 	aiPart := ai.NewTextPart(response.Message.Content)
 	chunk.Content = append(chunk.Content, aiPart)
 	return chunk, nil
 }
 
-func translateGenerateChunk(input string) (*ai.GenerateResponseChunk, error) {
-	var response ollamaGenerateResponse
+func translateGenerateChunk(input string) (*ai.ModelResponseChunk, error) {
+	var response ollamaModelResponse
 	if err := json.Unmarshal([]byte(input), &response); err != nil {
 		return nil, fmt.Errorf("failed to parse response JSON: %v", err)
 	}
-	chunk := &ai.GenerateResponseChunk{}
+	chunk := &ai.ModelResponseChunk{}
 	aiPart := ai.NewTextPart(response.Response)
 	chunk.Content = append(chunk.Content, aiPart)
 	return chunk, nil
 }
 
 // concatMessages translates a list of messages into a prompt-style format
-func concatMessages(input *ai.GenerateRequest, roles []ai.Role) string {
+func concatMessages(input *ai.ModelRequest, roles []ai.Role) string {
 	roleSet := make(map[ai.Role]bool)
 	for _, role := range roles {
 		roleSet[role] = true // Create a set for faster lookup
@@ -395,7 +391,7 @@ func concatMessages(input *ai.GenerateRequest, roles []ai.Role) string {
 }
 
 // concatImages grabs the images from genkit message parts
-func concatImages(input *ai.GenerateRequest, roleFilter []ai.Role) ([]string, error) {
+func concatImages(input *ai.ModelRequest, roleFilter []ai.Role) ([]string, error) {
 	roleSet := make(map[ai.Role]bool)
 	for _, role := range roleFilter {
 		roleSet[role] = true
diff --git a/go/plugins/ollama/ollama_live_test.go b/go/plugins/ollama/ollama_live_test.go
index 0fbba2bc7..08c834fb6 100644
--- a/go/plugins/ollama/ollama_live_test.go
+++ b/go/plugins/ollama/ollama_live_test.go
@@ -58,7 +58,7 @@ func TestLive(t *testing.T) {
 
 	// Generate a response from the model
 	resp, err := m.Generate(ctx,
-		ai.NewGenerateRequest(
+		ai.NewModelRequest(
 			&ai.GenerationCommonConfig{Temperature: 1},
 			ai.NewUserTextMessage("I'm hungry, what should I eat?")),
 		nil)
diff --git a/go/plugins/ollama/ollama_test.go b/go/plugins/ollama/ollama_test.go
index 13c0718cd..a3cc591e6 100644
--- a/go/plugins/ollama/ollama_test.go
+++ b/go/plugins/ollama/ollama_test.go
@@ -68,7 +68,7 @@ func TestConcatMessages(t *testing.T) {
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			input := &ai.GenerateRequest{Messages: tt.messages}
+			input := &ai.ModelRequest{Messages: tt.messages}
 			got := concatMessages(input, tt.roles)
 			if got != tt.want {
 				t.Errorf("concatMessages() = %q, want %q", got, tt.want)
@@ -81,13 +81,13 @@ func TestTranslateGenerateChunk(t *testing.T) {
 	tests := []struct {
 		name    string
 		input   string
-		want    *ai.GenerateResponseChunk
+		want    *ai.ModelResponseChunk
 		wantErr bool
 	}{
 		{
 			name:  "Valid JSON response",
 			input: `{"model": "my-model", "created_at": "2024-06-20T12:34:56Z", "response": "This is a test response."}`,
-			want: &ai.GenerateResponseChunk{
+			want: &ai.ModelResponseChunk{
 				Content: []*ai.Part{ai.NewTextPart("This is a test response.")},
 			},
 			wantErr: false,
diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go
index c13881509..b3215e03e 100644
--- a/go/plugins/vertexai/vertexai.go
+++ b/go/plugins/vertexai/vertexai.go
@@ -19,7 +19,6 @@ import (
"fmt" "os" "runtime" - "strconv" "strings" "sync" @@ -167,9 +166,9 @@ func defineModel(name string, caps ai.ModelCapabilities) ai.Model { } return ai.DefineModel(provider, name, meta, func( ctx context.Context, - input *ai.GenerateRequest, - cb func(context.Context, *ai.GenerateResponseChunk) error, - ) (*ai.GenerateResponse, error) { + input *ai.ModelRequest, + cb func(context.Context, *ai.ModelResponseChunk) error, + ) (*ai.ModelResponse, error) { return generate(ctx, state.gclient, name, input, cb) }) } @@ -236,9 +235,9 @@ func generate( ctx context.Context, client *genai.Client, model string, - input *ai.GenerateRequest, - cb func(context.Context, *ai.GenerateResponseChunk) error, -) (*ai.GenerateResponse, error) { + input *ai.ModelRequest, + cb func(context.Context, *ai.ModelResponseChunk) error, +) (*ai.ModelResponse, error) { gm, err := newModel(client, model, input) if err != nil { return nil, err @@ -279,7 +278,7 @@ func generate( // Streaming version. iter := cs.SendMessageStream(ctx, parts...) - var r *ai.GenerateResponse + var r *ai.ModelResponse for { chunk, err := iter.Next() if err == iterator.Done { @@ -292,9 +291,8 @@ func generate( // Send candidates to the callback. for _, c := range chunk.Candidates { tc := translateCandidate(c) - err := cb(ctx, &ai.GenerateResponseChunk{ + err := cb(ctx, &ai.ModelResponseChunk{ Content: tc.Message.Content, - Index: tc.Index, }) if err != nil { return nil, err @@ -304,15 +302,15 @@ func generate( if r == nil { // No candidates were returned. Probably rare, but it might avoid a NPE // to return an empty instead of nil result. - r = &ai.GenerateResponse{} + r = &ai.ModelResponse{} } r.Request = input return r, nil } -func newModel(client *genai.Client, model string, input *ai.GenerateRequest) (*genai.GenerativeModel, error) { +func newModel(client *genai.Client, model string, input *ai.ModelRequest) (*genai.GenerativeModel, error) { gm := client.GenerativeModel(model) - gm.SetCandidateCount(int32(input.Candidates)) + gm.SetCandidateCount(1) if c, ok := input.Config.(*ai.GenerationCommonConfig); ok && c != nil { if c.MaxOutputTokens != 0 { gm.SetMaxOutputTokens(int32(c.MaxOutputTokens)) @@ -348,7 +346,7 @@ func newModel(client *genai.Client, model string, input *ai.GenerateRequest) (*g } // startChat starts a chat session and configures it with the input messages. -func startChat(gm *genai.GenerativeModel, input *ai.GenerateRequest) (*genai.ChatSession, error) { +func startChat(gm *genai.GenerativeModel, input *ai.ModelRequest) (*genai.ChatSession, error) { cs := gm.StartChat() // All but the last message goes in the history field. 
@@ -381,14 +379,9 @@ func convertTools(inTools []*ai.ToolDefinition) ([]*genai.Tool, error) {
-		if err != err {
+		if err != nil {
 			return nil, err
 		}
-		outputSchema, err := convertSchema(t.OutputSchema, t.OutputSchema)
-		if err != err {
-			return nil, err
-		}
 		fd := &genai.FunctionDeclaration{
 			Name:        t.Name,
 			Parameters:  inputSchema,
-			Response:    outputSchema,
 			Description: t.Description,
 		}
 		outTools = append(outTools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{fd}})
 	}
@@ -435,50 +428,9 @@ func convertSchema(originalSchema map[string]any, genkitSchema map[string]any) (
 	if v, ok := genkitSchema["format"]; ok {
 		schema.Format = v.(string)
 	}
-	if v, ok := genkitSchema["pattern"]; ok {
-		schema.Pattern = v.(string)
-	}
-	if v, ok := genkitSchema["title"]; ok {
-		schema.Title = v.(string)
-	}
-	if v, ok := genkitSchema["minItems"]; ok {
-		schema.MinItems = v.(int64)
-	}
-	if v, ok := genkitSchema["maxItems"]; ok {
-		schema.MaxItems = v.(int64)
-	}
-	if v, ok := genkitSchema["minItems"]; ok {
-		schema.MinItems = v.(int64)
-	}
-	if v, ok := genkitSchema["maxProperties"]; ok {
-		schema.MaxProperties = v.(int64)
-	}
-	if v, ok := genkitSchema["minProperties"]; ok {
-		schema.MinProperties = v.(int64)
-	}
-	if v, ok := genkitSchema["maxLength"]; ok {
-		schema.MaxLength = v.(int64)
-	}
-	if v, ok := genkitSchema["minLength"]; ok {
-		schema.MinLength = v.(int64)
-	}
 	if v, ok := genkitSchema["enum"]; ok {
 		schema.Enum = castToStringArray(v.([]any))
 	}
-	if v, ok := genkitSchema["maximum"]; ok {
-		m, err := strconv.ParseFloat(v.(string), 64)
-		if err != nil {
-			return nil, err
-		}
-		schema.Maximum = m
-	}
-	if v, ok := genkitSchema["minimum"]; ok {
-		m, err := strconv.ParseFloat(v.(string), 64)
-		if err != nil {
-			return nil, err
-		}
-		schema.Minimum = m
-	}
 	if v, ok := genkitSchema["items"]; ok {
 		items, err := convertSchema(originalSchema, v.(map[string]any))
 		if err != nil {
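One consequence of the trimming above: keys such as minLength or maximum in a tool's input schema are now silently ignored rather than converted, while type, format, enum, items, and similar keys still map onto genai.Schema. A rough illustration (the schema keys come from the removed branches; the tool schema itself is made up):

// exampleSchema is a hypothetical illustration, not part of this patch.
func exampleSchema() (*genai.Schema, error) {
	inputSchema := map[string]any{
		"type": "object",
		"properties": map[string]any{
			"size": map[string]any{
				"type":      "string",
				"enum":      []any{"small", "large"},
				"minLength": int64(1), // previously copied onto genai.Schema, now dropped
			},
		},
	}
	// Unknown keys are skipped by the ok-guarded lookups, so conversion
	// still succeeds; the bound is simply lost.
	return convertSchema(inputSchema, inputSchema)
}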
@@ -525,26 +477,25 @@ func castToStringArray(i []any) []string {
 
 //copy:sink translateCandidate from ../googleai/googleai.go
 // DO NOT MODIFY below vvvv
 
-// translateCandidate translates from a genai.GenerateContentResponse to an ai.GenerateResponse.
-func translateCandidate(cand *genai.Candidate) *ai.Candidate {
-	c := &ai.Candidate{}
-	c.Index = int(cand.Index)
+// translateCandidate translates a genai.Candidate into an ai.ModelResponse.
+func translateCandidate(cand *genai.Candidate) *ai.ModelResponse {
+	m := &ai.ModelResponse{}
 	switch cand.FinishReason {
 	case genai.FinishReasonStop:
-		c.FinishReason = ai.FinishReasonStop
+		m.FinishReason = ai.FinishReasonStop
 	case genai.FinishReasonMaxTokens:
-		c.FinishReason = ai.FinishReasonLength
+		m.FinishReason = ai.FinishReasonLength
 	case genai.FinishReasonSafety:
-		c.FinishReason = ai.FinishReasonBlocked
+		m.FinishReason = ai.FinishReasonBlocked
 	case genai.FinishReasonRecitation:
-		c.FinishReason = ai.FinishReasonBlocked
+		m.FinishReason = ai.FinishReasonBlocked
 	case genai.FinishReasonOther:
-		c.FinishReason = ai.FinishReasonOther
+		m.FinishReason = ai.FinishReasonOther
 	default: // Unspecified
-		c.FinishReason = ai.FinishReasonUnknown
+		m.FinishReason = ai.FinishReasonUnknown
 	}
-	m := &ai.Message{}
-	m.Role = ai.Role(cand.Content.Role)
+	msg := &ai.Message{}
+	msg.Role = ai.Role(cand.Content.Role)
 	for _, part := range cand.Content.Parts {
 		var p *ai.Part
 		switch part := part.(type) {
@@ -560,10 +511,10 @@ func translateCandidate(cand *genai.Candidate) *ai.Candidate {
 		default:
 			panic(fmt.Sprintf("unknown part %#v", part))
 		}
-		m.Content = append(m.Content, p)
+		msg.Content = append(msg.Content, p)
 	}
-	c.Message = m
-	return c
+	m.Message = msg
+	return m
 }
 
 // DO NOT MODIFY above ^^^^
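Because translateCandidate now produces the response object directly, the finish-reason mapping is easy to pin down in a table-driven test. A sketch (a hypothetical test placed in the same package as the function; the table entries are taken from the switch above):

func TestTranslateCandidateFinishReason(t *testing.T) {
	cases := map[genai.FinishReason]ai.FinishReason{
		genai.FinishReasonStop:       ai.FinishReasonStop,
		genai.FinishReasonMaxTokens:  ai.FinishReasonLength,
		genai.FinishReasonSafety:     ai.FinishReasonBlocked,
		genai.FinishReasonRecitation: ai.FinishReasonBlocked,
		genai.FinishReasonOther:      ai.FinishReasonOther,
	}
	for in, want := range cases {
		// Content must be non-nil because translateCandidate reads its Role;
		// nil Parts is fine since ranging over a nil slice is a no-op.
		resp := translateCandidate(&genai.Candidate{
			FinishReason: in,
			Content:      &genai.Content{Role: "model"},
		})
		if resp.FinishReason != want {
			t.Errorf("finish reason %v: got %v, want %v", in, resp.FinishReason, want)
		}
	}
}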
@@ -572,12 +523,10 @@ func translateCandidate(cand *genai.Candidate) *ai.Candidate {
 
 //copy:sink translateResponse from ../googleai/googleai.go
 // DO NOT MODIFY below vvvv
 
-// Translate from a genai.GenerateContentResponse to a ai.GenerateResponse.
-func translateResponse(resp *genai.GenerateContentResponse) *ai.GenerateResponse {
-	r := &ai.GenerateResponse{}
-	for _, c := range resp.Candidates {
-		r.Candidates = append(r.Candidates, translateCandidate(c))
-	}
+// translateResponse translates a genai.GenerateContentResponse into an ai.ModelResponse.
+func translateResponse(resp *genai.GenerateContentResponse) *ai.ModelResponse {
+	r := translateCandidate(resp.Candidates[0])
+
 	r.Usage = &ai.GenerationUsage{}
 	if u := resp.UsageMetadata; u != nil {
 		r.Usage.InputTokens = int(u.PromptTokenCount)
diff --git a/go/plugins/vertexai/vertexai_test.go b/go/plugins/vertexai/vertexai_test.go
index c057bb271..1b2bb565f 100644
--- a/go/plugins/vertexai/vertexai_test.go
+++ b/go/plugins/vertexai/vertexai_test.go
@@ -53,11 +53,11 @@ func TestLive(t *testing.T) {
 		},
 	)
 	t.Run("model", func(t *testing.T) {
-		resp, err := ai.Generate(ctx, model, ai.WithCandidates(1), ai.WithTextPrompt("Which country was Napoleon the emperor of?"))
+		resp, err := ai.Generate(ctx, model, ai.WithTextPrompt("Which country was Napoleon the emperor of?"))
 		if err != nil {
 			t.Fatal(err)
 		}
-		out := resp.Candidates[0].Message.Content[0].Text
+		out := resp.Message.Content[0].Text
 		if !strings.Contains(out, "France") {
 			t.Errorf("got %q, expecting it to contain \"France\"", out)
 		}
@@ -73,9 +73,8 @@ func TestLive(t *testing.T) {
 		parts := 0
 		model := vertexai.Model(modelName)
 		final, err := ai.Generate(ctx, model,
-			ai.WithCandidates(1),
 			ai.WithTextPrompt("Write one paragraph about the Golden State Warriors."),
-			ai.WithStreaming(func(ctx context.Context, c *ai.GenerateResponseChunk) error {
+			ai.WithStreaming(func(ctx context.Context, c *ai.ModelResponseChunk) error {
 				parts++
 				for _, p := range c.Content {
 					out += p.Text
@@ -86,7 +85,7 @@ func TestLive(t *testing.T) {
 			t.Fatal(err)
 		}
 		out2 := ""
-		for _, p := range final.Candidates[0].Message.Content {
+		for _, p := range final.Message.Content {
 			out2 += p.Text
 		}
 		if out != out2 {
@@ -107,14 +106,13 @@ func TestLive(t *testing.T) {
 	})
 	t.Run("tool", func(t *testing.T) {
 		resp, err := ai.Generate(ctx, model,
-			ai.WithCandidates(1),
 			ai.WithTextPrompt("what is a gablorken of 2 over 3.5?"),
 			ai.WithTools(gablorkenTool))
 		if err != nil {
 			t.Fatal(err)
 		}
-		out := resp.Candidates[0].Message.Content[0].Text
+		out := resp.Message.Content[0].Text
 		if !strings.Contains(out, "12.25") {
 			t.Errorf("got %s, expecting it to contain \"12.25\"", out)
 		}
diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go
index 7d03a9f7d..0b0fdc44b 100755
--- a/go/samples/coffee-shop/main.go
+++ b/go/samples/coffee-shop/main.go
@@ -116,9 +116,9 @@ func main() {
 	}
 
 	simpleGreetingFlow := genkit.DefineStreamingFlow("simpleGreeting", func(ctx context.Context, input *simpleGreetingInput, cb func(context.Context, string) error) (string, error) {
-		var callback func(context.Context, *ai.GenerateResponseChunk) error
+		var callback func(context.Context, *ai.ModelResponseChunk) error
 		if cb != nil {
-			callback = func(ctx context.Context, c *ai.GenerateResponseChunk) error {
+			callback = func(ctx context.Context, c *ai.ModelResponseChunk) error {
 				return cb(ctx, c.Text())
 			}
 		}
@@ -171,9 +171,9 @@ func main() {
 	}
 
 	genkit.DefineStreamingFlow("simpleStructuredGreeting", func(ctx context.Context, input *simpleGreetingInput, cb func(context.Context, string) error) (string, error) {
-		var callback func(context.Context, *ai.GenerateResponseChunk) error
+		var callback func(context.Context, *ai.ModelResponseChunk) error
 		if cb != nil {
-			callback = func(ctx context.Context, c *ai.GenerateResponseChunk) error {
+			callback = func(ctx context.Context, c *ai.ModelResponseChunk) error {
 				return cb(ctx, c.Text())
 			}
 		}
diff --git a/go/samples/menu/s03.go b/go/samples/menu/s03.go
index 13f6a4bd6..4c277b6c7 100644
--- a/go/samples/menu/s03.go
+++ b/go/samples/menu/s03.go
@@ -117,7 +117,7 @@ func setup03(ctx context.Context, model ai.Model) error {
 				return nil, err
 			}
 
-			messages = append(messages, resp.Candidates[0].Message)
+			messages = append(messages, resp.Message)
 			storedHistory.Store(input.SessionID, messages)
 
 			out := &chatSessionOutput{
diff --git a/go/samples/menu/s04.go b/go/samples/menu/s04.go
index 1387fd292..94cec8fc8 100644
--- a/go/samples/menu/s04.go
+++ b/go/samples/menu/s04.go
@@ -107,7 +107,7 @@ func setup04(ctx context.Context, indexer ai.Indexer, retriever ai.Retriever, mo
 			}
 
 			ret := &answerOutput{
-				Answer: presp.Candidates[0].Message.Content[0].Text,
+				Answer: presp.Message.Content[0].Text,
 			}
 			return ret, nil
 		},
diff --git a/go/samples/menu/s05.go b/go/samples/menu/s05.go
index 2d046931d..5d8451a02 100644
--- a/go/samples/menu/s05.go
+++ b/go/samples/menu/s05.go
@@ -97,7 +97,7 @@ func setup05(ctx context.Context, gen, genVision ai.Model) error {
 				return "", err
 			}
 
-			ret := presp.Candidates[0].Message.Content[0].Text
+			ret := presp.Message.Content[0].Text
 			return ret, nil
 		},
 	)
@@ -115,7 +115,7 @@ func setup05(ctx context.Context, gen, genVision ai.Model) error {
 				return nil, err
 			}
 			ret := &answerOutput{
-				Answer: presp.Candidates[0].Message.Content[0].Text,
+				Answer: presp.Message.Content[0].Text,
 			}
 			return ret, nil
 		},
diff --git a/go/tests/api_test.go b/go/tests/api_test.go
index 6b8c51171..b5f24df7e 100644
--- a/go/tests/api_test.go
+++ b/go/tests/api_test.go
@@ -46,6 +46,7 @@ const hostPort = "http://localhost:3100"
 
 func TestReflectionAPI(t *testing.T) {
 	filenames, err := filepath.Glob(filepath.FromSlash("../../tests/*.yaml"))
+
 	if err != nil {
 		t.Fatal(err)
 	}
diff --git a/go/tests/test_app/main.go b/go/tests/test_app/main.go
index 80bcc9abb..8c90a593c 100644
--- a/go/tests/test_app/main.go
+++ b/go/tests/test_app/main.go
@@ -40,19 +40,16 @@ func main() {
 	}
 }
 
-func echo(ctx context.Context, req *ai.GenerateRequest, cb func(context.Context, *ai.GenerateResponseChunk) error) (*ai.GenerateResponse, error) {
+func echo(ctx context.Context, req *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
 	jsonBytes, err := json.Marshal(req)
 	if err != nil {
 		return nil, err
 	}
-	return &ai.GenerateResponse{
-		Candidates: []*ai.Candidate{{
-			Index:        0,
-			FinishReason: "stop",
-			Message: &ai.Message{
-				Role:    "model",
-				Content: []*ai.Part{ai.NewTextPart(string(jsonBytes))},
-			},
-		}},
+	return &ai.ModelResponse{
+		FinishReason: "stop",
+		Message: &ai.Message{
+			Role:    "model",
+			Content: []*ai.Part{ai.NewTextPart(string(jsonBytes))},
+		},
 	}, nil
 }
diff --git a/tests/reflection_api_tests.yaml b/tests/reflection_api_tests.yaml
index 62ac9eba9..0f68a54b1 100644
--- a/tests/reflection_api_tests.yaml
+++ b/tests/reflection_api_tests.yaml
@@ -13,13 +13,11 @@ tests:
         messages: [{ role: user, content: [{ text: hello }] }]
     body:
       result:
-        candidates:
-          - index: 0
-            finishReason: stop
-            message:
-              role: model
-              content:
-                - text: '{"messages":[{"content":[{"text":"hello"}],"role":"user"}]}'
+        finishReason: stop
+        message:
+          role: model
+          content:
+            - text: '{"messages":[{"content":[{"text":"hello"}],"role":"user"}]}'
 
   - path: /api/actions
     body:
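Taken together, the caller-side migration is mechanical: drop WithCandidates, and read the message and finish reason straight off the response. A before/after sketch (the call shapes are taken from the test changes above; the function names are illustrative, and each version compiles only against its respective API):

// Before: one candidate was requested, then indexed into.
func askOld(ctx context.Context, model ai.Model) (string, error) {
	resp, err := ai.Generate(ctx, model,
		ai.WithCandidates(1),
		ai.WithTextPrompt("Which country was Napoleon the emperor of?"))
	if err != nil {
		return "", err
	}
	return resp.Candidates[0].Message.Content[0].Text, nil
}

// After: the response carries Message and FinishReason directly.
func askNew(ctx context.Context, model ai.Model) (string, error) {
	resp, err := ai.Generate(ctx, model,
		ai.WithTextPrompt("Which country was Napoleon the emperor of?"))
	if err != nil {
		return "", err
	}
	return resp.Message.Content[0].Text, nil
}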