From 43fe9bd071291d41d22fad010cc0447212869f70 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 24 May 2024 15:45:28 -0400 Subject: [PATCH] [Go] move most of genkit package into core (#244) This PR greatly simplifies the genkit package, limiting it to the symbols that Genkit app developers, as exemplified by the programs in the "samples" directory, would need. To accomplish this, it moves most of the code to a new package, named core. The core package, rather than the genkit package, is now imported by the ai package and plugins. End-user applications should not normally require it. This overlapped with #229, so unfortunately those (minor) changes are incorporated here as well, in a slightly different form. --- go/ai/embedder.go | 5 +- go/ai/generator.go | 22 ++-- go/ai/retriever.go | 10 +- go/ai/tools.go | 11 +- go/{genkit => core}/action.go | 37 +++--- go/{genkit => core}/action_test.go | 4 +- go/{genkit => core}/conformance_test.go | 5 +- go/core/core.go | 18 +++ go/{genkit => core}/file_flow_state_store.go | 2 +- go/{genkit => core}/flow.go | 79 ++++--------- go/{genkit => core}/flow_state_store.go | 2 +- go/{genkit => core}/flow_test.go | 28 +---- go/core/gen.go | 32 +++++ go/{genkit => core}/metrics.go | 2 +- go/{genkit => core}/registry.go | 12 +- go/{genkit => core}/servers.go | 6 +- go/{genkit => core}/servers_test.go | 16 +-- .../testdata/conformance/basic.json | 0 .../testdata/conformance/run-1.json | 0 go/genkit/gen.go | 15 --- go/genkit/genkit.go | 111 +++++++++++++++++- go/genkit/genkit_test.go | 54 +++++++++ go/genkit/schemas.config | 4 +- go/plugins/dotprompt/genkit.go | 10 +- go/plugins/dotprompt/genkit_test.go | 3 +- go/plugins/googleai/googleai.go | 3 +- go/plugins/googlecloud/googlecloud.go | 4 +- go/plugins/vertexai/vertexai.go | 3 +- go/samples/flow-sample1/main.go | 2 +- 29 files changed, 320 insertions(+), 180 deletions(-) rename go/{genkit => core}/action.go (82%) rename go/{genkit => core}/action_test.go (96%) rename go/{genkit => core}/conformance_test.go (98%) create mode 100644 go/core/core.go rename go/{genkit => core}/file_flow_state_store.go (99%) rename go/{genkit => core}/flow.go (89%) rename go/{genkit => core}/flow_state_store.go (98%) rename go/{genkit => core}/flow_test.go (88%) create mode 100644 go/core/gen.go rename go/{genkit => core}/metrics.go (99%) rename go/{genkit => core}/registry.go (95%) rename go/{genkit => core}/servers.go (98%) rename go/{genkit => core}/servers_test.go (93%) rename go/{genkit => core}/testdata/conformance/basic.json (100%) rename go/{genkit => core}/testdata/conformance/run-1.json (100%) create mode 100644 go/genkit/genkit_test.go diff --git a/go/ai/embedder.go b/go/ai/embedder.go index 3d07b1a91..0aa5a0cdd 100644 --- a/go/ai/embedder.go +++ b/go/ai/embedder.go @@ -17,7 +17,7 @@ package ai import ( "context" - "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/core" ) // Embedder is the interface used to convert a document to a @@ -35,6 +35,5 @@ type EmbedRequest struct { // RegisterEmbedder registers the actions for a specific embedder. func RegisterEmbedder(name string, embedder Embedder) { - genkit.RegisterAction(genkit.ActionTypeEmbedder, name, - genkit.NewAction(name, genkit.ActionTypeEmbedder, nil, embedder.Embed)) + core.RegisterAction(name, core.NewAction(name, core.ActionTypeEmbedder, nil, embedder.Embed)) } diff --git a/go/ai/generator.go b/go/ai/generator.go index 4295cb8aa..cb6cacf0d 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -21,7 +21,7 @@ import ( "slices" "strings" - "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/core" ) // Generator is the interface used to query an AI model. @@ -31,7 +31,7 @@ type Generator interface { // populating the result's Candidates field. // - If the streaming callback returns a non-nil error, generation will stop // and Generate immediately returns that error (and a nil response). - Generate(context.Context, *GenerateRequest, genkit.StreamingCallback[*Candidate]) (*GenerateResponse, error) + Generate(context.Context, *GenerateRequest, func(context.Context, *Candidate) error) (*GenerateResponse, error) } // GeneratorCapabilities describes various capabilities of the generator. @@ -63,16 +63,16 @@ func RegisterGenerator(provider, name string, metadata *GeneratorMetadata, gener } metadataMap["supports"] = supports } - genkit.RegisterAction(genkit.ActionTypeModel, provider, - genkit.NewStreamingAction(name, genkit.ActionTypeModel, map[string]any{ + core.RegisterAction(provider, + core.NewStreamingAction(name, core.ActionTypeModel, map[string]any{ "model": metadataMap, }, generator.Generate)) } // Generate applies a [Generator] to some input, handling tool requests. -func Generate(ctx context.Context, generator Generator, input *GenerateRequest, cb genkit.StreamingCallback[*Candidate]) (*GenerateResponse, error) { +func Generate(ctx context.Context, g Generator, input *GenerateRequest, cb func(context.Context, *Candidate) error) (*GenerateResponse, error) { for { - resp, err := generator.Generate(ctx, input, cb) + resp, err := g.Generate(ctx, input, cb) if err != nil { return nil, err } @@ -89,14 +89,14 @@ func Generate(ctx context.Context, generator Generator, input *GenerateRequest, } } -// generatorActionType is the instantiated genkit.Action type registered +// generatorActionType is the instantiated core.Action type registered // by RegisterGenerator. -type generatorActionType = genkit.Action[*GenerateRequest, *GenerateResponse, *Candidate] +type generatorActionType = core.Action[*GenerateRequest, *GenerateResponse, *Candidate] // LookupGeneratorAction looks up an action registered by [RegisterGenerator] // and returns a generator that invokes the action. func LookupGeneratorAction(provider, name string) (Generator, error) { - action := genkit.LookupAction(genkit.ActionTypeModel, provider, name) + action := core.LookupAction(core.ActionTypeModel, provider, name) if action == nil { return nil, fmt.Errorf("LookupGeneratorAction: no generator action named %q/%q", provider, name) } @@ -113,9 +113,9 @@ type generatorAction struct { } // Generate implements Generator. This is like the [Generate] function, -// but invokes the [genkit.Action] rather than invoking the Generator +// but invokes the [core.Action] rather than invoking the Generator // directly. -func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, cb genkit.StreamingCallback[*Candidate]) (*GenerateResponse, error) { +func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, cb func(context.Context, *Candidate) error) (*GenerateResponse, error) { for { resp, err := ga.action.Run(ctx, input, cb) if err != nil { diff --git a/go/ai/retriever.go b/go/ai/retriever.go index d4f857b72..44a081d0c 100644 --- a/go/ai/retriever.go +++ b/go/ai/retriever.go @@ -17,7 +17,7 @@ package ai import ( "context" - "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/core" ) // Retriever supports adding documents to a database, and @@ -51,11 +51,11 @@ type RetrieverResponse struct { // RegisterRetriever registers the actions for a specific retriever. func RegisterRetriever(name string, retriever Retriever) { - genkit.RegisterAction(genkit.ActionTypeRetriever, name, - genkit.NewAction(name, genkit.ActionTypeRetriever, nil, retriever.Retrieve)) + core.RegisterAction(name, + core.NewAction(name, core.ActionTypeRetriever, nil, retriever.Retrieve)) - genkit.RegisterAction(genkit.ActionTypeIndexer, name, - genkit.NewAction(name, genkit.ActionTypeIndexer, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) { + core.RegisterAction(name, + core.NewAction(name, core.ActionTypeIndexer, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) { err := retriever.Index(ctx, req) return struct{}{}, err })) diff --git a/go/ai/tools.go b/go/ai/tools.go index d584e0d59..7e89a89fe 100644 --- a/go/ai/tools.go +++ b/go/ai/tools.go @@ -19,7 +19,7 @@ import ( "fmt" "maps" - "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/core" ) // A Tool is an implementation of a single tool. @@ -42,18 +42,17 @@ func RegisterTool(name string, definition *ToolDefinition, metadata map[string]a metadata["type"] = "tool" // TODO: There is no provider for a tool. - genkit.RegisterAction(genkit.ActionTypeTool, "tool", - genkit.NewAction(definition.Name, genkit.ActionTypeTool, metadata, fn)) + core.RegisterAction("tool", core.NewAction(definition.Name, core.ActionTypeTool, metadata, fn)) } -// toolActionType is the instantiated genkit.Action type registered +// toolActionType is the instantiated core.Action type registered // by RegisterTool. -type toolActionType = genkit.Action[map[string]any, map[string]any, struct{}] +type toolActionType = core.Action[map[string]any, map[string]any, struct{}] // RunTool looks up a tool registered by [RegisterTool], // runs it with the given input, and returns the result. func RunTool(ctx context.Context, name string, input map[string]any) (map[string]any, error) { - action := genkit.LookupAction(genkit.ActionTypeTool, "tool", name) + action := core.LookupAction(core.ActionTypeTool, "tool", name) if action == nil { return nil, fmt.Errorf("no tool named %q", name) } diff --git a/go/genkit/action.go b/go/core/action.go similarity index 82% rename from go/genkit/action.go rename to go/core/action.go index d74c735d4..945c03ead 100644 --- a/go/genkit/action.go +++ b/go/core/action.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import ( "context" @@ -34,18 +34,16 @@ import ( // stream the results by invoking the callback periodically, ultimately returning // with a final return value. Otherwise, it should ignore the StreamingCallback and // just return a result. -type Func[I, O, S any] func(context.Context, I, StreamingCallback[S]) (O, error) +type Func[I, O, S any] func(context.Context, I, func(context.Context, S) error) (O, error) // TODO(jba): use a generic type alias for the above when they become available? -// StreamingCallback is the type of streaming callbacks, which is passed to action -// functions who should stream their responses. -type StreamingCallback[S any] func(context.Context, S) error - // NoStream indicates that the action or flow does not support streaming. // A Func[I, O, NoStream] will ignore its streaming callback. // Such a function corresponds to a Flow[I, O, struct{}]. -type NoStream = StreamingCallback[struct{}] +type NoStream = func(context.Context, struct{}) error + +type streamingCallback[S any] func(context.Context, S) error // An Action is a named, observable operation. // It consists of a function that takes an input of type I and returns an output @@ -56,6 +54,7 @@ type NoStream = StreamingCallback[struct{}] // Each time an Action is run, it results in a new trace span. type Action[I, O, S any] struct { name string + atype ActionType fn Func[I, O, S] tstate *tracing.State inputSchema *jsonschema.Schema @@ -68,20 +67,21 @@ type Action[I, O, S any] struct { // See js/common/src/types.ts // NewAction creates a new Action with the given name and non-streaming function. -func NewAction[I, O any](name string, actionType ActionType, metadata map[string]any, fn func(context.Context, I) (O, error)) *Action[I, O, struct{}] { - return NewStreamingAction(name, actionType, metadata, func(ctx context.Context, in I, cb NoStream) (O, error) { +func NewAction[I, O any](name string, atype ActionType, metadata map[string]any, fn func(context.Context, I) (O, error)) *Action[I, O, struct{}] { + return NewStreamingAction(name, atype, metadata, func(ctx context.Context, in I, cb NoStream) (O, error) { return fn(ctx, in) }) } // NewStreamingAction creates a new Action with the given name and streaming function. -func NewStreamingAction[I, O, S any](name string, actionType ActionType, metadata map[string]any, fn Func[I, O, S]) *Action[I, O, S] { +func NewStreamingAction[I, O, S any](name string, atype ActionType, metadata map[string]any, fn Func[I, O, S]) *Action[I, O, S] { var i I var o O return &Action[I, O, S]{ - name: name, - fn: func(ctx context.Context, input I, sc StreamingCallback[S]) (O, error) { - tracing.SetCustomMetadataAttr(ctx, "subtype", string(actionType)) + name: name, + atype: atype, + fn: func(ctx context.Context, input I, sc func(context.Context, S) error) (O, error) { + tracing.SetCustomMetadataAttr(ctx, "subtype", string(atype)) return fn(ctx, input, sc) }, inputSchema: inferJSONSchema(i), @@ -93,11 +93,13 @@ func NewStreamingAction[I, O, S any](name string, actionType ActionType, metadat // Name returns the Action's name. func (a *Action[I, O, S]) Name() string { return a.name } +func (a *Action[I, O, S]) actionType() ActionType { return a.atype } + // setTracingState sets the action's tracing.State. func (a *Action[I, O, S]) setTracingState(tstate *tracing.State) { a.tstate = tstate } // Run executes the Action's function in a new trace span. -func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb StreamingCallback[S]) (output O, err error) { +func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb func(context.Context, S) error) (output O, err error) { // TODO: validate input against JSONSchema for I. // TODO: validate output against JSONSchema for O. internal.Logger(ctx).Debug("Action.Run", @@ -128,12 +130,12 @@ func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb StreamingCallback }) } -func (a *Action[I, O, S]) runJSON(ctx context.Context, input json.RawMessage, cb StreamingCallback[json.RawMessage]) (json.RawMessage, error) { +func (a *Action[I, O, S]) runJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) { var in I if err := json.Unmarshal(input, &in); err != nil { return nil, err } - var callback StreamingCallback[S] + var callback func(context.Context, S) error if cb != nil { callback = func(ctx context.Context, s S) error { bytes, err := json.Marshal(s) @@ -157,10 +159,11 @@ func (a *Action[I, O, S]) runJSON(ctx context.Context, input json.RawMessage, cb // action is the type that all Action[I, O, S] have in common. type action interface { Name() string + actionType() ActionType // runJSON uses encoding/json to unmarshal the input, // calls Action.Run, then returns the marshaled result. - runJSON(ctx context.Context, input json.RawMessage, cb StreamingCallback[json.RawMessage]) (json.RawMessage, error) + runJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) // desc returns a description of the action. // It should set all fields of actionDesc except Key, which diff --git a/go/genkit/action_test.go b/go/core/action_test.go similarity index 96% rename from go/genkit/action_test.go rename to go/core/action_test.go index ae28a4322..5cc1ee9d0 100644 --- a/go/genkit/action_test.go +++ b/go/core/action_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import ( "bytes" @@ -55,7 +55,7 @@ func TestNewAction(t *testing.T) { } // count streams the numbers from 0 to n-1, then returns n. -func count(ctx context.Context, n int, cb StreamingCallback[int]) (int, error) { +func count(ctx context.Context, n int, cb func(context.Context, int) error) (int, error) { if cb != nil { for i := 0; i < n; i++ { if err := cb(ctx, i); err != nil { diff --git a/go/genkit/conformance_test.go b/go/core/conformance_test.go similarity index 98% rename from go/genkit/conformance_test.go rename to go/core/conformance_test.go index c7e12ab21..08a9a527b 100644 --- a/go/genkit/conformance_test.go +++ b/go/core/conformance_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import ( "cmp" @@ -82,6 +82,9 @@ func TestFlowConformance(t *testing.T) { if err != nil { t.Fatal(err) } + if len(testFiles) == 0 { + t.Fatal("did not find any test files") + } for _, filename := range testFiles { t.Run(strings.TrimSuffix(filepath.Base(filename), ".json"), func(t *testing.T) { var test conformanceTest diff --git a/go/core/core.go b/go/core/core.go new file mode 100644 index 000000000..41f086b32 --- /dev/null +++ b/go/core/core.go @@ -0,0 +1,18 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package core implements Genkit actions, flows and other essential machinery. +// This package is primarily intended for genkit internals and for plugins. +// Applications using genkit should use the genkit package. +package core diff --git a/go/genkit/file_flow_state_store.go b/go/core/file_flow_state_store.go similarity index 99% rename from go/genkit/file_flow_state_store.go rename to go/core/file_flow_state_store.go index 442434212..a4b3a7b17 100644 --- a/go/genkit/file_flow_state_store.go +++ b/go/core/file_flow_state_store.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import ( "context" diff --git a/go/genkit/flow.go b/go/core/flow.go similarity index 89% rename from go/genkit/flow.go rename to go/core/flow.go index 5d5d57375..433872755 100644 --- a/go/genkit/flow.go +++ b/go/core/flow.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import ( "context" @@ -110,7 +110,7 @@ func defineFlow[I, O, S any](r *registry, name string, fn Func[I, O, S]) *Flow[I // TODO(jba): set stateStore? } a := f.action() - r.registerAction(ActionTypeFlow, name, a) + r.registerAction(name, a) // TODO(jba): this is a roundabout way to transmit the tracing state. Is there a cleaner way? f.tstate = a.tstate r.registerFlow(f) @@ -250,15 +250,16 @@ func (f *Flow[I, O, S]) action() *Action[*flowInstruction[I], *flowState[I, O], "inputSchema": inferJSONSchema(i), "outputSchema": inferJSONSchema(o), } - return NewStreamingAction(f.name, ActionTypeFlow, metadata, func(ctx context.Context, inst *flowInstruction[I], cb StreamingCallback[S]) (*flowState[I, O], error) { + cback := func(ctx context.Context, inst *flowInstruction[I], cb func(context.Context, S) error) (*flowState[I, O], error) { tracing.SpanMetaKey.FromContext(ctx).SetAttr("flow:wrapperAction", "true") - return f.runInstruction(ctx, inst, cb) - }) + return f.runInstruction(ctx, inst, streamingCallback[S](cb)) + } + return NewStreamingAction(f.name, ActionTypeFlow, metadata, cback) } // runInstruction performs one of several actions on a flow, as determined by msg. // (Called runEnvelope in the js.) -func (f *Flow[I, O, S]) runInstruction(ctx context.Context, inst *flowInstruction[I], cb StreamingCallback[S]) (*flowState[I, O], error) { +func (f *Flow[I, O, S]) runInstruction(ctx context.Context, inst *flowInstruction[I], cb streamingCallback[S]) (*flowState[I, O], error) { switch { case inst.Start != nil: // TODO(jba): pass msg.Start.Labels. @@ -284,18 +285,18 @@ type flow interface { // runJSON uses encoding/json to unmarshal the input, // calls Flow.start, then returns the marshaled result. - runJSON(ctx context.Context, input json.RawMessage, cb StreamingCallback[json.RawMessage]) (json.RawMessage, error) + runJSON(ctx context.Context, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) } func (f *Flow[I, O, S]) Name() string { return f.name } -func (f *Flow[I, O, S]) runJSON(ctx context.Context, input json.RawMessage, cb StreamingCallback[json.RawMessage]) (json.RawMessage, error) { +func (f *Flow[I, O, S]) runJSON(ctx context.Context, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) { var in I if err := json.Unmarshal(input, &in); err != nil { return nil, &httpError{http.StatusBadRequest, err} } // If there is a callback, wrap it to turn an S into a json.RawMessage. - var callback StreamingCallback[S] + var callback streamingCallback[S] if cb != nil { callback = func(ctx context.Context, s S) error { bytes, err := json.Marshal(s) @@ -323,7 +324,7 @@ func (f *Flow[I, O, S]) runJSON(ctx context.Context, input json.RawMessage, cb S } // start starts executing the flow with the given input. -func (f *Flow[I, O, S]) start(ctx context.Context, input I, cb StreamingCallback[S]) (_ *flowState[I, O], err error) { +func (f *Flow[I, O, S]) start(ctx context.Context, input I, cb streamingCallback[S]) (_ *flowState[I, O], err error) { flowID, err := generateFlowID() if err != nil { return nil, err @@ -340,7 +341,7 @@ func (f *Flow[I, O, S]) start(ctx context.Context, input I, cb StreamingCallback // // This function corresponds to Flow.executeSteps in the js, but does more: // it creates the flowContext and saves the state. -func (f *Flow[I, O, S]) execute(ctx context.Context, state *flowState[I, O], dispatchType string, cb StreamingCallback[S]) { +func (f *Flow[I, O, S]) execute(ctx context.Context, state *flowState[I, O], dispatchType string, cb streamingCallback[S]) { fctx := newFlowContext(state, f.stateStore, f.tstate) defer func() { if err := fctx.finish(ctx); err != nil { @@ -531,54 +532,18 @@ func RunFlow[I, O, S any](ctx context.Context, flow *Flow[I, O, S], input I) (O, return finishedOpResponse(state.Operation) } -// StreamFlowValue is either a streamed value or a final output of a flow. -type StreamFlowValue[O, S any] struct { - Done bool - Output O // valid if Done is true - Stream S // valid if Done is false -} +// InternalStreamFlow is for use by genkit.StreamFlow exclusively. +// It is not subject to any backwards compatibility guarantees. +func InternalStreamFlow[I, O, S any](ctx context.Context, flow *Flow[I, O, S], input I, callback func(context.Context, S) error) (O, error) { -var errStop = errors.New("stop") - -// StreamFlow runs flow on input and delivers both the streamed values and the final output. -// It returns a function whose argument function (the "yield function") will be repeatedly -// called with the results. -// -// If the yield function is passed a non-nil error, the flow has failed with that -// error; the yield function will not be called again. An error is also passed if -// the flow fails to complete (that is, it has an interrupt). -// -// If the yield function's [StreamFlowValue] argument has Done == true, the value's -// Output field contains the final output; the yield function will not be called -// again. -// -// Otherwise the Stream field of the passed [StreamFlowValue] holds a streamed result. -func StreamFlow[I, O, S any](ctx context.Context, flow *Flow[I, O, S], input I) func(func(*StreamFlowValue[O, S], error) bool) { - return func(yield func(*StreamFlowValue[O, S], error) bool) { - cb := func(ctx context.Context, s S) error { - if ctx.Err() != nil { - return ctx.Err() - } - if !yield(&StreamFlowValue[O, S]{Stream: s}, nil) { - return errStop - } - return nil - } - var output O - state, err := flow.start(ctx, input, cb) - if err == nil { - if ctx.Err() != nil { - err = ctx.Err() - } else { - output, err = finishedOpResponse(state.Operation) - } - } - if err != nil { - yield(nil, err) - } else { - yield(&StreamFlowValue[O, S]{Done: true, Output: output}, nil) - } + state, err := flow.start(ctx, input, callback) + if err != nil { + return internal.Zero[O](), err } + if ctx.Err() != nil { + return internal.Zero[O](), ctx.Err() + } + return finishedOpResponse(state.Operation) } func finishedOpResponse[O any](op *operation[O]) (O, error) { diff --git a/go/genkit/flow_state_store.go b/go/core/flow_state_store.go similarity index 98% rename from go/genkit/flow_state_store.go rename to go/core/flow_state_store.go index dd803f2b9..497a8d2ff 100644 --- a/go/genkit/flow_state_store.go +++ b/go/core/flow_state_store.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import "context" diff --git a/go/genkit/flow_test.go b/go/core/flow_test.go similarity index 88% rename from go/genkit/flow_test.go rename to go/core/flow_test.go index b8d59d2d2..b2761949c 100644 --- a/go/genkit/flow_test.go +++ b/go/core/flow_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import ( "context" @@ -103,32 +103,6 @@ func TestRunFlow(t *testing.T) { } } -func TestStreamFlow(t *testing.T) { - reg, err := newRegistry() - if err != nil { - t.Fatal(err) - } - f := defineFlow(reg, "count", count) - iter := StreamFlow(context.Background(), f, 2) - want := 0 - iter(func(val *StreamFlowValue[int, int], err error) bool { - if err != nil { - t.Fatal(err) - } - var got int - if val.Done { - got = val.Output - } else { - got = val.Stream - } - if got != want { - t.Errorf("got %d, want %d", got, want) - } - want++ - return true - }) -} - func TestFlowState(t *testing.T) { // A flowState is an action output, so it must support JSON marshaling. // Verify that a fully populated flowState can round-trip via JSON. diff --git a/go/core/gen.go b/go/core/gen.go new file mode 100644 index 000000000..0ef7fd96d --- /dev/null +++ b/go/core/gen.go @@ -0,0 +1,32 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file was generated by jsonschemagen. DO NOT EDIT. + +package core + +import "github.com/firebase/genkit/go/gtime" + +type flowError struct { + Error string `json:"error,omitempty"` + Stacktrace string `json:"stacktrace,omitempty"` +} + +type flowExecution struct { + // end time in milliseconds since the epoch + EndTime gtime.Milliseconds `json:"endTime,omitempty"` + // start time in milliseconds since the epoch + StartTime gtime.Milliseconds `json:"startTime,omitempty"` + TraceIDs []string `json:"traceIds,omitempty"` +} diff --git a/go/genkit/metrics.go b/go/core/metrics.go similarity index 99% rename from go/genkit/metrics.go rename to go/core/metrics.go index c29ec9b97..13ca3b237 100644 --- a/go/genkit/metrics.go +++ b/go/core/metrics.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import ( "context" diff --git a/go/genkit/registry.go b/go/core/registry.go similarity index 95% rename from go/genkit/registry.go rename to go/core/registry.go index 93bc45c32..3634c6b99 100644 --- a/go/genkit/registry.go +++ b/go/core/registry.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import ( "context" @@ -96,16 +96,16 @@ const ( // RegisterAction records the action in the global registry. // It panics if an action with the same type, provider and name is already // registered. -func RegisterAction(typ ActionType, provider string, a action) { - globalRegistry.registerAction(typ, provider, a) +func RegisterAction(provider string, a action) { + globalRegistry.registerAction(provider, a) slog.Info("RegisterAction", - "type", typ, + "type", a.actionType(), "provider", provider, "name", a.Name()) } -func (r *registry) registerAction(typ ActionType, provider string, a action) { - key := fmt.Sprintf("/%s/%s/%s", typ, provider, a.Name()) +func (r *registry) registerAction(provider string, a action) { + key := fmt.Sprintf("/%s/%s/%s", a.actionType(), provider, a.Name()) r.mu.Lock() defer r.mu.Unlock() if _, ok := r.actions[key]; ok { diff --git a/go/genkit/servers.go b/go/core/servers.go similarity index 98% rename from go/genkit/servers.go rename to go/core/servers.go index 7d8107a84..3c66fa4fa 100644 --- a/go/genkit/servers.go +++ b/go/core/servers.go @@ -20,7 +20,7 @@ // The production server has a route for each flow. It // is intended for production deployments. -package genkit +package core import ( "context" @@ -117,7 +117,7 @@ func (s *devServer) handleRunAction(w http.ResponseWriter, r *http.Request) erro internal.Logger(ctx).Debug("running action", "key", body.Key, "stream", stream) - var callback StreamingCallback[json.RawMessage] + var callback streamingCallback[json.RawMessage] if stream { // Stream results are newline-separated JSON. callback = func(ctx context.Context, msg json.RawMessage) error { @@ -141,7 +141,7 @@ type telemetry struct { TraceID string `json:"traceId"` } -func runAction(ctx context.Context, reg *registry, key string, input json.RawMessage, cb StreamingCallback[json.RawMessage]) (*runActionResponse, error) { +func runAction(ctx context.Context, reg *registry, key string, input json.RawMessage, cb streamingCallback[json.RawMessage]) (*runActionResponse, error) { action := reg.lookupAction(key) if action == nil { return nil, &httpError{http.StatusNotFound, fmt.Errorf("no action with key %q", key)} diff --git a/go/genkit/servers_test.go b/go/core/servers_test.go similarity index 93% rename from go/genkit/servers_test.go rename to go/core/servers_test.go index 2ded09205..cbe75cfe6 100644 --- a/go/genkit/servers_test.go +++ b/go/core/servers_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import ( "context" @@ -38,17 +38,17 @@ func TestDevServer(t *testing.T) { if err != nil { t.Fatal(err) } - r.registerAction("test", "devServer", NewAction("inc", ActionTypeCustom, map[string]any{ + r.registerAction("devServer", NewAction("inc", ActionTypeCustom, map[string]any{ "foo": "bar", }, inc)) - r.registerAction("test", "devServer", NewAction("dec", ActionTypeCustom, map[string]any{ + r.registerAction("devServer", NewAction("dec", ActionTypeCustom, map[string]any{ "bar": "baz", }, dec)) srv := httptest.NewServer(newDevServeMux(r)) defer srv.Close() t.Run("runAction", func(t *testing.T) { - body := `{"key": "/test/devServer/inc", "input": 3}` + body := `{"key": "/custom/devServer/inc", "input": 3}` res, err := http.Post(srv.URL+"/api/runAction", "application/json", strings.NewReader(body)) if err != nil { t.Fatal(err) @@ -84,15 +84,15 @@ func TestDevServer(t *testing.T) { t.Fatal(err) } want := map[string]actionDesc{ - "/test/devServer/inc": { - Key: "/test/devServer/inc", + "/custom/devServer/inc": { + Key: "/custom/devServer/inc", Name: "inc", InputSchema: &jsonschema.Schema{Type: "integer"}, OutputSchema: &jsonschema.Schema{Type: "integer"}, Metadata: map[string]any{"foo": "bar"}, }, - "/test/devServer/dec": { - Key: "/test/devServer/dec", + "/custom/devServer/dec": { + Key: "/custom/devServer/dec", InputSchema: &jsonschema.Schema{Type: "integer"}, OutputSchema: &jsonschema.Schema{Type: "integer"}, Name: "dec", diff --git a/go/genkit/testdata/conformance/basic.json b/go/core/testdata/conformance/basic.json similarity index 100% rename from go/genkit/testdata/conformance/basic.json rename to go/core/testdata/conformance/basic.json diff --git a/go/genkit/testdata/conformance/run-1.json b/go/core/testdata/conformance/run-1.json similarity index 100% rename from go/genkit/testdata/conformance/run-1.json rename to go/core/testdata/conformance/run-1.json diff --git a/go/genkit/gen.go b/go/genkit/gen.go index 2a64ee8c1..8686b47fa 100644 --- a/go/genkit/gen.go +++ b/go/genkit/gen.go @@ -15,18 +15,3 @@ // This file was generated by jsonschemagen. DO NOT EDIT. package genkit - -import "github.com/firebase/genkit/go/gtime" - -type flowError struct { - Error string `json:"error,omitempty"` - Stacktrace string `json:"stacktrace,omitempty"` -} - -type flowExecution struct { - // end time in milliseconds since the epoch - EndTime gtime.Milliseconds `json:"endTime,omitempty"` - // start time in milliseconds since the epoch - StartTime gtime.Milliseconds `json:"startTime,omitempty"` - TraceIDs []string `json:"traceIds,omitempty"` -} diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 1b7475ee5..dafa91958 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -19,5 +19,114 @@ // Run the Go code generator on the file just created. //go:generate go run ../internal/cmd/jsonschemagen -outdir .. -config schemas.config ../../genkit-tools/genkit-schema.json genkit -// Package genkit is the genkit API for Go. +// Package genkit provides Genkit functionality for application developers. package genkit + +import ( + "context" + "errors" + "net/http" + + "github.com/firebase/genkit/go/core" +) + +// DefineFlow creates a Flow that runs fn, and registers it as an action. +// +// fn takes an input of type I and returns an output of type O, optionally +// streaming values of type S incrementally by invoking a callback. +// If the callback is non-nil and the function supports streaming, it should +// stream the results by invoking the callback periodically, ultimately returning +// with a final return value. Otherwise, it should ignore the StreamingCallback and +// just return a result. +func DefineFlow[I, O, S any]( + name string, + fn func(ctx context.Context, input I, callback func(context.Context, S) error) (O, error), +) *core.Flow[I, O, S] { + return core.DefineFlow(name, core.Func[I, O, S](fn)) +} + +// Run runs the function f in the context of the current flow. +// It returns an error if no flow is active. +// +// Each call to Run results in a new step in the flow. +// A step has its own span in the trace, and its result is cached so that if the flow +// is restarted, f will not be called a second time. +func Run[T any](ctx context.Context, name string, f func() (T, error)) (T, error) { + return core.Run(ctx, name, f) +} + +// RunFlow runs flow in the context of another flow. The flow must run to completion when started +// (that is, it must not have interrupts). +func RunFlow[I, O, S any](ctx context.Context, flow *core.Flow[I, O, S], input I) (O, error) { + return core.RunFlow(ctx, flow, input) +} + +type NoStream = core.NoStream + +// StreamFlowValue is either a streamed value or a final output of a flow. +type StreamFlowValue[O, S any] struct { + Done bool + Output O // valid if Done is true + Stream S // valid if Done is false +} + +// StreamFlow runs flow on input and delivers both the streamed values and the final output. +// It returns a function whose argument function (the "yield function") will be repeatedly +// called with the results. +// +// If the yield function is passed a non-nil error, the flow has failed with that +// error; the yield function will not be called again. An error is also passed if +// the flow fails to complete (that is, it has an interrupt). +// Genkit Go does not yet support interrupts. +// +// If the yield function's [StreamFlowValue] argument has Done == true, the value's +// Output field contains the final output; the yield function will not be called +// again. +// +// Otherwise the Stream field of the passed [StreamFlowValue] holds a streamed result. +func StreamFlow[I, O, S any](ctx context.Context, flow *core.Flow[I, O, S], input I) func(func(*StreamFlowValue[O, S], error) bool) { + return func(yield func(*StreamFlowValue[O, S], error) bool) { + cb := func(ctx context.Context, s S) error { + if ctx.Err() != nil { + return ctx.Err() + } + if !yield(&StreamFlowValue[O, S]{Stream: s}, nil) { + return errStop + } + return nil + } + output, err := core.InternalStreamFlow(ctx, flow, input, cb) + if err != nil { + yield(nil, err) + } else { + yield(&StreamFlowValue[O, S]{Done: true, Output: output}, nil) + } + } +} + +var errStop = errors.New("stop") + +// StartFlowServer starts a server serving the routes described in [NewFlowServeMux]. +// It listens on addr, or if empty, the value of the PORT environment variable, +// or if that is empty, ":3400". +// +// In development mode (if the environment variable GENKIT_ENV=dev), it also starts +// a dev server. +// +// StartFlowServer always returns a non-nil error, the one returned by http.ListenAndServe. +func StartFlowServer(addr string) error { + return core.StartFlowServer(addr) +} + +// NewFlowServeMux constructs a [net/http.ServeMux] where each defined flow is a route. +// All routes take a single query parameter, "stream", which if true will stream the +// flow's results back to the client. (Not all flows support streaming, however.) +// +// To use the returned ServeMux as part of a server with other routes, either add routes +// to it, or install it as part of another ServeMux, like so: +// +// mainMux := http.NewServeMux() +// mainMux.Handle("POST /flow/", http.StripPrefix("/flow/", NewFlowServeMux())) +func NewFlowServeMux() *http.ServeMux { + return core.NewFlowServeMux() +} diff --git a/go/genkit/genkit_test.go b/go/genkit/genkit_test.go new file mode 100644 index 000000000..d9519a6fe --- /dev/null +++ b/go/genkit/genkit_test.go @@ -0,0 +1,54 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package genkit + +import ( + "context" + "testing" +) + +func TestStreamFlow(t *testing.T) { + f := DefineFlow("count", count) + iter := StreamFlow(context.Background(), f, 2) + want := 0 + iter(func(val *StreamFlowValue[int, int], err error) bool { + if err != nil { + t.Fatal(err) + } + var got int + if val.Done { + got = val.Output + } else { + got = val.Stream + } + if got != want { + t.Errorf("got %d, want %d", got, want) + } + want++ + return true + }) +} + +// count streams the numbers from 0 to n-1, then returns n. +func count(ctx context.Context, n int, cb func(context.Context, int) error) (int, error) { + if cb != nil { + for i := 0; i < n; i++ { + if err := cb(ctx, i); err != nil { + return 0, err + } + } + } + return n, nil +} diff --git a/go/genkit/schemas.config b/go/genkit/schemas.config index 3acb49ace..079233d71 100644 --- a/go/genkit/schemas.config +++ b/go/genkit/schemas.config @@ -1,7 +1,7 @@ # This file holds configuration for the genkit-schema.json file # generated by the npm export:schemas script. -genkit import github.com/firebase/genkit/go/gtime +core import github.com/firebase/genkit/go/gtime # DocumentData type was hand-written. DocumentData omit @@ -67,10 +67,12 @@ FlowInvokeEnvelopeMessageRunScheduled omit Operation omit FlowStateExecution name flowExecution +FlowStateExecution pkg core FlowStateExecution.startTime type gtime.Milliseconds FlowStateExecution.endTime type gtime.Milliseconds FlowError name flowError +FlowError pkg core GenerateRequest.messages doc Messages is a list of messages to pass to the model. The first n-1 Messages diff --git a/go/plugins/dotprompt/genkit.go b/go/plugins/dotprompt/genkit.go index 5fc3c860f..5af27483c 100644 --- a/go/plugins/dotprompt/genkit.go +++ b/go/plugins/dotprompt/genkit.go @@ -21,7 +21,7 @@ import ( "strings" "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/internal/tracing" ) @@ -116,12 +116,12 @@ func (p *Prompt) buildRequest(input *ActionInput) (*ai.GenerateRequest, error) { return req, nil } -// Action returns a [genkit.Action] that executes the prompt. +// Action returns a [core.Action] that executes the prompt. // The returned Action will take an [ActionInput] that provides // variables to substitute into the template text. // It will then pass the rendered text to an AI generator, // and return whatever the generator computes. -func (p *Prompt) Action() (*genkit.Action[*ActionInput, *ai.GenerateResponse, struct{}], error) { +func (p *Prompt) Action() (*core.Action[*ActionInput, *ai.GenerateResponse, struct{}], error) { if p.Name == "" { return nil, errors.New("dotprompt: missing name") } @@ -130,7 +130,7 @@ func (p *Prompt) Action() (*genkit.Action[*ActionInput, *ai.GenerateResponse, st name += "." + p.Variant } - a := genkit.NewAction(name, genkit.ActionTypePrompt, nil, p.Execute) + a := core.NewAction(name, core.ActionTypePrompt, nil, p.Execute) a.Metadata = map[string]any{ "type": "prompt", "prompt": p, @@ -153,7 +153,7 @@ func (p *Prompt) Register() error { return err } - genkit.RegisterAction(genkit.ActionTypePrompt, name, action) + core.RegisterAction(name, action) return nil } diff --git a/go/plugins/dotprompt/genkit_test.go b/go/plugins/dotprompt/genkit_test.go index 61d5148c1..8d03f5b24 100644 --- a/go/plugins/dotprompt/genkit_test.go +++ b/go/plugins/dotprompt/genkit_test.go @@ -20,12 +20,11 @@ import ( "testing" "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/genkit" ) type testGenerator struct{} -func (testGenerator) Generate(ctx context.Context, req *ai.GenerateRequest, cb genkit.StreamingCallback[*ai.Candidate]) (*ai.GenerateResponse, error) { +func (testGenerator) Generate(ctx context.Context, req *ai.GenerateRequest, cb func(context.Context, *ai.Candidate) error) (*ai.GenerateResponse, error) { input := req.Messages[0].Content[0].Text() output := fmt.Sprintf("AI reply to %q", input) diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 34f3dbed2..f2dedc32f 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -19,7 +19,6 @@ import ( "fmt" "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/genkit" "github.com/google/generative-ai-go/genai" "google.golang.org/api/iterator" "google.golang.org/api/option" @@ -35,7 +34,7 @@ type generator struct { //session *genai.ChatSession // non-nil if we're in the middle of a chat } -func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb genkit.StreamingCallback[*ai.Candidate]) (*ai.GenerateResponse, error) { +func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb func(context.Context, *ai.Candidate) error) (*ai.GenerateResponse, error) { gm := g.client.GenerativeModel(g.model) // Translate from a ai.GenerateRequest to a genai request. diff --git a/go/plugins/googlecloud/googlecloud.go b/go/plugins/googlecloud/googlecloud.go index 3c5b38821..6344ee388 100644 --- a/go/plugins/googlecloud/googlecloud.go +++ b/go/plugins/googlecloud/googlecloud.go @@ -27,7 +27,7 @@ import ( "cloud.google.com/go/logging" mexporter "github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric" texporter "github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace" - "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/core" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" @@ -64,7 +64,7 @@ func Init(ctx context.Context, projectID string, opts *Options) error { return err } aexp := &adjustingTraceExporter{texp} - genkit.RegisterSpanProcessor(sdktrace.NewBatchSpanProcessor(aexp)) + core.RegisterSpanProcessor(sdktrace.NewBatchSpanProcessor(aexp)) if err := setMeterProvider(projectID, opts.MetricInterval); err != nil { return err } diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index 554a1fce9..ff00c0555 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -20,7 +20,6 @@ import ( "cloud.google.com/go/vertexai/genai" "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/genkit" ) func newClient(ctx context.Context, projectID, location string) (*genai.Client, error) { @@ -32,7 +31,7 @@ type generator struct { client *genai.Client } -func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb genkit.StreamingCallback[*ai.Candidate]) (*ai.GenerateResponse, error) { +func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb func(context.Context, *ai.Candidate) error) (*ai.GenerateResponse, error) { if cb != nil { panic("streaming not supported yet") // TODO: streaming } diff --git a/go/samples/flow-sample1/main.go b/go/samples/flow-sample1/main.go index a42141f18..a0107f9eb 100644 --- a/go/samples/flow-sample1/main.go +++ b/go/samples/flow-sample1/main.go @@ -61,7 +61,7 @@ func main() { Count int `json:"count"` } - genkit.DefineFlow("streamy", func(ctx context.Context, count int, cb genkit.StreamingCallback[chunk]) (string, error) { + genkit.DefineFlow("streamy", func(ctx context.Context, count int, cb func(context.Context, chunk) error) (string, error) { i := 0 if cb != nil { for ; i < count; i++ {