diff --git a/go/ai/embedder.go b/go/ai/embedder.go index 3d07b1a91f..0aa5a0cddf 100644 --- a/go/ai/embedder.go +++ b/go/ai/embedder.go @@ -17,7 +17,7 @@ package ai import ( "context" - "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/core" ) // Embedder is the interface used to convert a document to a @@ -35,6 +35,5 @@ type EmbedRequest struct { // RegisterEmbedder registers the actions for a specific embedder. func RegisterEmbedder(name string, embedder Embedder) { - genkit.RegisterAction(genkit.ActionTypeEmbedder, name, - genkit.NewAction(name, genkit.ActionTypeEmbedder, nil, embedder.Embed)) + core.RegisterAction(name, core.NewAction(name, core.ActionTypeEmbedder, nil, embedder.Embed)) } diff --git a/go/ai/generator.go b/go/ai/generator.go index 4295cb8aa1..cb6cacf0d4 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -21,7 +21,7 @@ import ( "slices" "strings" - "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/core" ) // Generator is the interface used to query an AI model. @@ -31,7 +31,7 @@ type Generator interface { // populating the result's Candidates field. // - If the streaming callback returns a non-nil error, generation will stop // and Generate immediately returns that error (and a nil response). - Generate(context.Context, *GenerateRequest, genkit.StreamingCallback[*Candidate]) (*GenerateResponse, error) + Generate(context.Context, *GenerateRequest, func(context.Context, *Candidate) error) (*GenerateResponse, error) } // GeneratorCapabilities describes various capabilities of the generator. @@ -63,16 +63,16 @@ func RegisterGenerator(provider, name string, metadata *GeneratorMetadata, gener } metadataMap["supports"] = supports } - genkit.RegisterAction(genkit.ActionTypeModel, provider, - genkit.NewStreamingAction(name, genkit.ActionTypeModel, map[string]any{ + core.RegisterAction(provider, + core.NewStreamingAction(name, core.ActionTypeModel, map[string]any{ "model": metadataMap, }, generator.Generate)) } // Generate applies a [Generator] to some input, handling tool requests. -func Generate(ctx context.Context, generator Generator, input *GenerateRequest, cb genkit.StreamingCallback[*Candidate]) (*GenerateResponse, error) { +func Generate(ctx context.Context, g Generator, input *GenerateRequest, cb func(context.Context, *Candidate) error) (*GenerateResponse, error) { for { - resp, err := generator.Generate(ctx, input, cb) + resp, err := g.Generate(ctx, input, cb) if err != nil { return nil, err } @@ -89,14 +89,14 @@ func Generate(ctx context.Context, generator Generator, input *GenerateRequest, } } -// generatorActionType is the instantiated genkit.Action type registered +// generatorActionType is the instantiated core.Action type registered // by RegisterGenerator. -type generatorActionType = genkit.Action[*GenerateRequest, *GenerateResponse, *Candidate] +type generatorActionType = core.Action[*GenerateRequest, *GenerateResponse, *Candidate] // LookupGeneratorAction looks up an action registered by [RegisterGenerator] // and returns a generator that invokes the action. func LookupGeneratorAction(provider, name string) (Generator, error) { - action := genkit.LookupAction(genkit.ActionTypeModel, provider, name) + action := core.LookupAction(core.ActionTypeModel, provider, name) if action == nil { return nil, fmt.Errorf("LookupGeneratorAction: no generator action named %q/%q", provider, name) } @@ -113,9 +113,9 @@ type generatorAction struct { } // Generate implements Generator. This is like the [Generate] function, -// but invokes the [genkit.Action] rather than invoking the Generator +// but invokes the [core.Action] rather than invoking the Generator // directly. -func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, cb genkit.StreamingCallback[*Candidate]) (*GenerateResponse, error) { +func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, cb func(context.Context, *Candidate) error) (*GenerateResponse, error) { for { resp, err := ga.action.Run(ctx, input, cb) if err != nil { diff --git a/go/ai/retriever.go b/go/ai/retriever.go index d4f857b72e..44a081d0c9 100644 --- a/go/ai/retriever.go +++ b/go/ai/retriever.go @@ -17,7 +17,7 @@ package ai import ( "context" - "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/core" ) // Retriever supports adding documents to a database, and @@ -51,11 +51,11 @@ type RetrieverResponse struct { // RegisterRetriever registers the actions for a specific retriever. func RegisterRetriever(name string, retriever Retriever) { - genkit.RegisterAction(genkit.ActionTypeRetriever, name, - genkit.NewAction(name, genkit.ActionTypeRetriever, nil, retriever.Retrieve)) + core.RegisterAction(name, + core.NewAction(name, core.ActionTypeRetriever, nil, retriever.Retrieve)) - genkit.RegisterAction(genkit.ActionTypeIndexer, name, - genkit.NewAction(name, genkit.ActionTypeIndexer, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) { + core.RegisterAction(name, + core.NewAction(name, core.ActionTypeIndexer, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) { err := retriever.Index(ctx, req) return struct{}{}, err })) diff --git a/go/ai/tools.go b/go/ai/tools.go index d584e0d59b..7e89a89fee 100644 --- a/go/ai/tools.go +++ b/go/ai/tools.go @@ -19,7 +19,7 @@ import ( "fmt" "maps" - "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/core" ) // A Tool is an implementation of a single tool. @@ -42,18 +42,17 @@ func RegisterTool(name string, definition *ToolDefinition, metadata map[string]a metadata["type"] = "tool" // TODO: There is no provider for a tool. - genkit.RegisterAction(genkit.ActionTypeTool, "tool", - genkit.NewAction(definition.Name, genkit.ActionTypeTool, metadata, fn)) + core.RegisterAction("tool", core.NewAction(definition.Name, core.ActionTypeTool, metadata, fn)) } -// toolActionType is the instantiated genkit.Action type registered +// toolActionType is the instantiated core.Action type registered // by RegisterTool. -type toolActionType = genkit.Action[map[string]any, map[string]any, struct{}] +type toolActionType = core.Action[map[string]any, map[string]any, struct{}] // RunTool looks up a tool registered by [RegisterTool], // runs it with the given input, and returns the result. func RunTool(ctx context.Context, name string, input map[string]any) (map[string]any, error) { - action := genkit.LookupAction(genkit.ActionTypeTool, "tool", name) + action := core.LookupAction(core.ActionTypeTool, "tool", name) if action == nil { return nil, fmt.Errorf("no tool named %q", name) } diff --git a/go/genkit/action.go b/go/core/action.go similarity index 82% rename from go/genkit/action.go rename to go/core/action.go index d74c735d4b..945c03eadf 100644 --- a/go/genkit/action.go +++ b/go/core/action.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import ( "context" @@ -34,18 +34,16 @@ import ( // stream the results by invoking the callback periodically, ultimately returning // with a final return value. Otherwise, it should ignore the StreamingCallback and // just return a result. -type Func[I, O, S any] func(context.Context, I, StreamingCallback[S]) (O, error) +type Func[I, O, S any] func(context.Context, I, func(context.Context, S) error) (O, error) // TODO(jba): use a generic type alias for the above when they become available? -// StreamingCallback is the type of streaming callbacks, which is passed to action -// functions who should stream their responses. -type StreamingCallback[S any] func(context.Context, S) error - // NoStream indicates that the action or flow does not support streaming. // A Func[I, O, NoStream] will ignore its streaming callback. // Such a function corresponds to a Flow[I, O, struct{}]. -type NoStream = StreamingCallback[struct{}] +type NoStream = func(context.Context, struct{}) error + +type streamingCallback[S any] func(context.Context, S) error // An Action is a named, observable operation. // It consists of a function that takes an input of type I and returns an output @@ -56,6 +54,7 @@ type NoStream = StreamingCallback[struct{}] // Each time an Action is run, it results in a new trace span. type Action[I, O, S any] struct { name string + atype ActionType fn Func[I, O, S] tstate *tracing.State inputSchema *jsonschema.Schema @@ -68,20 +67,21 @@ type Action[I, O, S any] struct { // See js/common/src/types.ts // NewAction creates a new Action with the given name and non-streaming function. -func NewAction[I, O any](name string, actionType ActionType, metadata map[string]any, fn func(context.Context, I) (O, error)) *Action[I, O, struct{}] { - return NewStreamingAction(name, actionType, metadata, func(ctx context.Context, in I, cb NoStream) (O, error) { +func NewAction[I, O any](name string, atype ActionType, metadata map[string]any, fn func(context.Context, I) (O, error)) *Action[I, O, struct{}] { + return NewStreamingAction(name, atype, metadata, func(ctx context.Context, in I, cb NoStream) (O, error) { return fn(ctx, in) }) } // NewStreamingAction creates a new Action with the given name and streaming function. -func NewStreamingAction[I, O, S any](name string, actionType ActionType, metadata map[string]any, fn Func[I, O, S]) *Action[I, O, S] { +func NewStreamingAction[I, O, S any](name string, atype ActionType, metadata map[string]any, fn Func[I, O, S]) *Action[I, O, S] { var i I var o O return &Action[I, O, S]{ - name: name, - fn: func(ctx context.Context, input I, sc StreamingCallback[S]) (O, error) { - tracing.SetCustomMetadataAttr(ctx, "subtype", string(actionType)) + name: name, + atype: atype, + fn: func(ctx context.Context, input I, sc func(context.Context, S) error) (O, error) { + tracing.SetCustomMetadataAttr(ctx, "subtype", string(atype)) return fn(ctx, input, sc) }, inputSchema: inferJSONSchema(i), @@ -93,11 +93,13 @@ func NewStreamingAction[I, O, S any](name string, actionType ActionType, metadat // Name returns the Action's name. func (a *Action[I, O, S]) Name() string { return a.name } +func (a *Action[I, O, S]) actionType() ActionType { return a.atype } + // setTracingState sets the action's tracing.State. func (a *Action[I, O, S]) setTracingState(tstate *tracing.State) { a.tstate = tstate } // Run executes the Action's function in a new trace span. -func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb StreamingCallback[S]) (output O, err error) { +func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb func(context.Context, S) error) (output O, err error) { // TODO: validate input against JSONSchema for I. // TODO: validate output against JSONSchema for O. internal.Logger(ctx).Debug("Action.Run", @@ -128,12 +130,12 @@ func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb StreamingCallback }) } -func (a *Action[I, O, S]) runJSON(ctx context.Context, input json.RawMessage, cb StreamingCallback[json.RawMessage]) (json.RawMessage, error) { +func (a *Action[I, O, S]) runJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) { var in I if err := json.Unmarshal(input, &in); err != nil { return nil, err } - var callback StreamingCallback[S] + var callback func(context.Context, S) error if cb != nil { callback = func(ctx context.Context, s S) error { bytes, err := json.Marshal(s) @@ -157,10 +159,11 @@ func (a *Action[I, O, S]) runJSON(ctx context.Context, input json.RawMessage, cb // action is the type that all Action[I, O, S] have in common. type action interface { Name() string + actionType() ActionType // runJSON uses encoding/json to unmarshal the input, // calls Action.Run, then returns the marshaled result. - runJSON(ctx context.Context, input json.RawMessage, cb StreamingCallback[json.RawMessage]) (json.RawMessage, error) + runJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) // desc returns a description of the action. // It should set all fields of actionDesc except Key, which diff --git a/go/genkit/action_test.go b/go/core/action_test.go similarity index 96% rename from go/genkit/action_test.go rename to go/core/action_test.go index ae28a43229..5cc1ee9d04 100644 --- a/go/genkit/action_test.go +++ b/go/core/action_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import ( "bytes" @@ -55,7 +55,7 @@ func TestNewAction(t *testing.T) { } // count streams the numbers from 0 to n-1, then returns n. -func count(ctx context.Context, n int, cb StreamingCallback[int]) (int, error) { +func count(ctx context.Context, n int, cb func(context.Context, int) error) (int, error) { if cb != nil { for i := 0; i < n; i++ { if err := cb(ctx, i); err != nil { diff --git a/go/genkit/conformance_test.go b/go/core/conformance_test.go similarity index 98% rename from go/genkit/conformance_test.go rename to go/core/conformance_test.go index c7e12ab21c..08a9a527b0 100644 --- a/go/genkit/conformance_test.go +++ b/go/core/conformance_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import ( "cmp" @@ -82,6 +82,9 @@ func TestFlowConformance(t *testing.T) { if err != nil { t.Fatal(err) } + if len(testFiles) == 0 { + t.Fatal("did not find any test files") + } for _, filename := range testFiles { t.Run(strings.TrimSuffix(filepath.Base(filename), ".json"), func(t *testing.T) { var test conformanceTest diff --git a/go/core/core.go b/go/core/core.go new file mode 100644 index 0000000000..41f086b320 --- /dev/null +++ b/go/core/core.go @@ -0,0 +1,18 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package core implements Genkit actions, flows and other essential machinery. +// This package is primarily intended for genkit internals and for plugins. +// Applications using genkit should use the genkit package. +package core diff --git a/go/genkit/file_flow_state_store.go b/go/core/file_flow_state_store.go similarity index 99% rename from go/genkit/file_flow_state_store.go rename to go/core/file_flow_state_store.go index 4424342127..a4b3a7b171 100644 --- a/go/genkit/file_flow_state_store.go +++ b/go/core/file_flow_state_store.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import ( "context" diff --git a/go/genkit/flow.go b/go/core/flow.go similarity index 89% rename from go/genkit/flow.go rename to go/core/flow.go index 5d5d573754..4338727552 100644 --- a/go/genkit/flow.go +++ b/go/core/flow.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import ( "context" @@ -110,7 +110,7 @@ func defineFlow[I, O, S any](r *registry, name string, fn Func[I, O, S]) *Flow[I // TODO(jba): set stateStore? } a := f.action() - r.registerAction(ActionTypeFlow, name, a) + r.registerAction(name, a) // TODO(jba): this is a roundabout way to transmit the tracing state. Is there a cleaner way? f.tstate = a.tstate r.registerFlow(f) @@ -250,15 +250,16 @@ func (f *Flow[I, O, S]) action() *Action[*flowInstruction[I], *flowState[I, O], "inputSchema": inferJSONSchema(i), "outputSchema": inferJSONSchema(o), } - return NewStreamingAction(f.name, ActionTypeFlow, metadata, func(ctx context.Context, inst *flowInstruction[I], cb StreamingCallback[S]) (*flowState[I, O], error) { + cback := func(ctx context.Context, inst *flowInstruction[I], cb func(context.Context, S) error) (*flowState[I, O], error) { tracing.SpanMetaKey.FromContext(ctx).SetAttr("flow:wrapperAction", "true") - return f.runInstruction(ctx, inst, cb) - }) + return f.runInstruction(ctx, inst, streamingCallback[S](cb)) + } + return NewStreamingAction(f.name, ActionTypeFlow, metadata, cback) } // runInstruction performs one of several actions on a flow, as determined by msg. // (Called runEnvelope in the js.) -func (f *Flow[I, O, S]) runInstruction(ctx context.Context, inst *flowInstruction[I], cb StreamingCallback[S]) (*flowState[I, O], error) { +func (f *Flow[I, O, S]) runInstruction(ctx context.Context, inst *flowInstruction[I], cb streamingCallback[S]) (*flowState[I, O], error) { switch { case inst.Start != nil: // TODO(jba): pass msg.Start.Labels. @@ -284,18 +285,18 @@ type flow interface { // runJSON uses encoding/json to unmarshal the input, // calls Flow.start, then returns the marshaled result. - runJSON(ctx context.Context, input json.RawMessage, cb StreamingCallback[json.RawMessage]) (json.RawMessage, error) + runJSON(ctx context.Context, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) } func (f *Flow[I, O, S]) Name() string { return f.name } -func (f *Flow[I, O, S]) runJSON(ctx context.Context, input json.RawMessage, cb StreamingCallback[json.RawMessage]) (json.RawMessage, error) { +func (f *Flow[I, O, S]) runJSON(ctx context.Context, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) { var in I if err := json.Unmarshal(input, &in); err != nil { return nil, &httpError{http.StatusBadRequest, err} } // If there is a callback, wrap it to turn an S into a json.RawMessage. - var callback StreamingCallback[S] + var callback streamingCallback[S] if cb != nil { callback = func(ctx context.Context, s S) error { bytes, err := json.Marshal(s) @@ -323,7 +324,7 @@ func (f *Flow[I, O, S]) runJSON(ctx context.Context, input json.RawMessage, cb S } // start starts executing the flow with the given input. -func (f *Flow[I, O, S]) start(ctx context.Context, input I, cb StreamingCallback[S]) (_ *flowState[I, O], err error) { +func (f *Flow[I, O, S]) start(ctx context.Context, input I, cb streamingCallback[S]) (_ *flowState[I, O], err error) { flowID, err := generateFlowID() if err != nil { return nil, err @@ -340,7 +341,7 @@ func (f *Flow[I, O, S]) start(ctx context.Context, input I, cb StreamingCallback // // This function corresponds to Flow.executeSteps in the js, but does more: // it creates the flowContext and saves the state. -func (f *Flow[I, O, S]) execute(ctx context.Context, state *flowState[I, O], dispatchType string, cb StreamingCallback[S]) { +func (f *Flow[I, O, S]) execute(ctx context.Context, state *flowState[I, O], dispatchType string, cb streamingCallback[S]) { fctx := newFlowContext(state, f.stateStore, f.tstate) defer func() { if err := fctx.finish(ctx); err != nil { @@ -531,54 +532,18 @@ func RunFlow[I, O, S any](ctx context.Context, flow *Flow[I, O, S], input I) (O, return finishedOpResponse(state.Operation) } -// StreamFlowValue is either a streamed value or a final output of a flow. -type StreamFlowValue[O, S any] struct { - Done bool - Output O // valid if Done is true - Stream S // valid if Done is false -} +// InternalStreamFlow is for use by genkit.StreamFlow exclusively. +// It is not subject to any backwards compatibility guarantees. +func InternalStreamFlow[I, O, S any](ctx context.Context, flow *Flow[I, O, S], input I, callback func(context.Context, S) error) (O, error) { -var errStop = errors.New("stop") - -// StreamFlow runs flow on input and delivers both the streamed values and the final output. -// It returns a function whose argument function (the "yield function") will be repeatedly -// called with the results. -// -// If the yield function is passed a non-nil error, the flow has failed with that -// error; the yield function will not be called again. An error is also passed if -// the flow fails to complete (that is, it has an interrupt). -// -// If the yield function's [StreamFlowValue] argument has Done == true, the value's -// Output field contains the final output; the yield function will not be called -// again. -// -// Otherwise the Stream field of the passed [StreamFlowValue] holds a streamed result. -func StreamFlow[I, O, S any](ctx context.Context, flow *Flow[I, O, S], input I) func(func(*StreamFlowValue[O, S], error) bool) { - return func(yield func(*StreamFlowValue[O, S], error) bool) { - cb := func(ctx context.Context, s S) error { - if ctx.Err() != nil { - return ctx.Err() - } - if !yield(&StreamFlowValue[O, S]{Stream: s}, nil) { - return errStop - } - return nil - } - var output O - state, err := flow.start(ctx, input, cb) - if err == nil { - if ctx.Err() != nil { - err = ctx.Err() - } else { - output, err = finishedOpResponse(state.Operation) - } - } - if err != nil { - yield(nil, err) - } else { - yield(&StreamFlowValue[O, S]{Done: true, Output: output}, nil) - } + state, err := flow.start(ctx, input, callback) + if err != nil { + return internal.Zero[O](), err } + if ctx.Err() != nil { + return internal.Zero[O](), ctx.Err() + } + return finishedOpResponse(state.Operation) } func finishedOpResponse[O any](op *operation[O]) (O, error) { diff --git a/go/genkit/flow_state_store.go b/go/core/flow_state_store.go similarity index 98% rename from go/genkit/flow_state_store.go rename to go/core/flow_state_store.go index dd803f2b9c..497a8d2ffd 100644 --- a/go/genkit/flow_state_store.go +++ b/go/core/flow_state_store.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import "context" diff --git a/go/genkit/flow_test.go b/go/core/flow_test.go similarity index 88% rename from go/genkit/flow_test.go rename to go/core/flow_test.go index b8d59d2d22..b2761949ce 100644 --- a/go/genkit/flow_test.go +++ b/go/core/flow_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import ( "context" @@ -103,32 +103,6 @@ func TestRunFlow(t *testing.T) { } } -func TestStreamFlow(t *testing.T) { - reg, err := newRegistry() - if err != nil { - t.Fatal(err) - } - f := defineFlow(reg, "count", count) - iter := StreamFlow(context.Background(), f, 2) - want := 0 - iter(func(val *StreamFlowValue[int, int], err error) bool { - if err != nil { - t.Fatal(err) - } - var got int - if val.Done { - got = val.Output - } else { - got = val.Stream - } - if got != want { - t.Errorf("got %d, want %d", got, want) - } - want++ - return true - }) -} - func TestFlowState(t *testing.T) { // A flowState is an action output, so it must support JSON marshaling. // Verify that a fully populated flowState can round-trip via JSON. diff --git a/go/core/gen.go b/go/core/gen.go new file mode 100644 index 0000000000..0ef7fd96db --- /dev/null +++ b/go/core/gen.go @@ -0,0 +1,32 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file was generated by jsonschemagen. DO NOT EDIT. + +package core + +import "github.com/firebase/genkit/go/gtime" + +type flowError struct { + Error string `json:"error,omitempty"` + Stacktrace string `json:"stacktrace,omitempty"` +} + +type flowExecution struct { + // end time in milliseconds since the epoch + EndTime gtime.Milliseconds `json:"endTime,omitempty"` + // start time in milliseconds since the epoch + StartTime gtime.Milliseconds `json:"startTime,omitempty"` + TraceIDs []string `json:"traceIds,omitempty"` +} diff --git a/go/genkit/metrics.go b/go/core/metrics.go similarity index 99% rename from go/genkit/metrics.go rename to go/core/metrics.go index c29ec9b979..13ca3b2376 100644 --- a/go/genkit/metrics.go +++ b/go/core/metrics.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import ( "context" diff --git a/go/genkit/registry.go b/go/core/registry.go similarity index 95% rename from go/genkit/registry.go rename to go/core/registry.go index 93bc45c320..3634c6b994 100644 --- a/go/genkit/registry.go +++ b/go/core/registry.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import ( "context" @@ -96,16 +96,16 @@ const ( // RegisterAction records the action in the global registry. // It panics if an action with the same type, provider and name is already // registered. -func RegisterAction(typ ActionType, provider string, a action) { - globalRegistry.registerAction(typ, provider, a) +func RegisterAction(provider string, a action) { + globalRegistry.registerAction(provider, a) slog.Info("RegisterAction", - "type", typ, + "type", a.actionType(), "provider", provider, "name", a.Name()) } -func (r *registry) registerAction(typ ActionType, provider string, a action) { - key := fmt.Sprintf("/%s/%s/%s", typ, provider, a.Name()) +func (r *registry) registerAction(provider string, a action) { + key := fmt.Sprintf("/%s/%s/%s", a.actionType(), provider, a.Name()) r.mu.Lock() defer r.mu.Unlock() if _, ok := r.actions[key]; ok { diff --git a/go/genkit/servers.go b/go/core/servers.go similarity index 98% rename from go/genkit/servers.go rename to go/core/servers.go index 7d8107a84b..3c66fa4fac 100644 --- a/go/genkit/servers.go +++ b/go/core/servers.go @@ -20,7 +20,7 @@ // The production server has a route for each flow. It // is intended for production deployments. -package genkit +package core import ( "context" @@ -117,7 +117,7 @@ func (s *devServer) handleRunAction(w http.ResponseWriter, r *http.Request) erro internal.Logger(ctx).Debug("running action", "key", body.Key, "stream", stream) - var callback StreamingCallback[json.RawMessage] + var callback streamingCallback[json.RawMessage] if stream { // Stream results are newline-separated JSON. callback = func(ctx context.Context, msg json.RawMessage) error { @@ -141,7 +141,7 @@ type telemetry struct { TraceID string `json:"traceId"` } -func runAction(ctx context.Context, reg *registry, key string, input json.RawMessage, cb StreamingCallback[json.RawMessage]) (*runActionResponse, error) { +func runAction(ctx context.Context, reg *registry, key string, input json.RawMessage, cb streamingCallback[json.RawMessage]) (*runActionResponse, error) { action := reg.lookupAction(key) if action == nil { return nil, &httpError{http.StatusNotFound, fmt.Errorf("no action with key %q", key)} diff --git a/go/genkit/servers_test.go b/go/core/servers_test.go similarity index 93% rename from go/genkit/servers_test.go rename to go/core/servers_test.go index 2ded092058..cbe75cfe66 100644 --- a/go/genkit/servers_test.go +++ b/go/core/servers_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit +package core import ( "context" @@ -38,17 +38,17 @@ func TestDevServer(t *testing.T) { if err != nil { t.Fatal(err) } - r.registerAction("test", "devServer", NewAction("inc", ActionTypeCustom, map[string]any{ + r.registerAction("devServer", NewAction("inc", ActionTypeCustom, map[string]any{ "foo": "bar", }, inc)) - r.registerAction("test", "devServer", NewAction("dec", ActionTypeCustom, map[string]any{ + r.registerAction("devServer", NewAction("dec", ActionTypeCustom, map[string]any{ "bar": "baz", }, dec)) srv := httptest.NewServer(newDevServeMux(r)) defer srv.Close() t.Run("runAction", func(t *testing.T) { - body := `{"key": "/test/devServer/inc", "input": 3}` + body := `{"key": "/custom/devServer/inc", "input": 3}` res, err := http.Post(srv.URL+"/api/runAction", "application/json", strings.NewReader(body)) if err != nil { t.Fatal(err) @@ -84,15 +84,15 @@ func TestDevServer(t *testing.T) { t.Fatal(err) } want := map[string]actionDesc{ - "/test/devServer/inc": { - Key: "/test/devServer/inc", + "/custom/devServer/inc": { + Key: "/custom/devServer/inc", Name: "inc", InputSchema: &jsonschema.Schema{Type: "integer"}, OutputSchema: &jsonschema.Schema{Type: "integer"}, Metadata: map[string]any{"foo": "bar"}, }, - "/test/devServer/dec": { - Key: "/test/devServer/dec", + "/custom/devServer/dec": { + Key: "/custom/devServer/dec", InputSchema: &jsonschema.Schema{Type: "integer"}, OutputSchema: &jsonschema.Schema{Type: "integer"}, Name: "dec", diff --git a/go/genkit/testdata/conformance/basic.json b/go/core/testdata/conformance/basic.json similarity index 100% rename from go/genkit/testdata/conformance/basic.json rename to go/core/testdata/conformance/basic.json diff --git a/go/genkit/testdata/conformance/run-1.json b/go/core/testdata/conformance/run-1.json similarity index 100% rename from go/genkit/testdata/conformance/run-1.json rename to go/core/testdata/conformance/run-1.json diff --git a/go/genkit/gen.go b/go/genkit/gen.go index 2a64ee8c10..8686b47fa6 100644 --- a/go/genkit/gen.go +++ b/go/genkit/gen.go @@ -15,18 +15,3 @@ // This file was generated by jsonschemagen. DO NOT EDIT. package genkit - -import "github.com/firebase/genkit/go/gtime" - -type flowError struct { - Error string `json:"error,omitempty"` - Stacktrace string `json:"stacktrace,omitempty"` -} - -type flowExecution struct { - // end time in milliseconds since the epoch - EndTime gtime.Milliseconds `json:"endTime,omitempty"` - // start time in milliseconds since the epoch - StartTime gtime.Milliseconds `json:"startTime,omitempty"` - TraceIDs []string `json:"traceIds,omitempty"` -} diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 1b7475ee5b..dafa919583 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -19,5 +19,114 @@ // Run the Go code generator on the file just created. //go:generate go run ../internal/cmd/jsonschemagen -outdir .. -config schemas.config ../../genkit-tools/genkit-schema.json genkit -// Package genkit is the genkit API for Go. +// Package genkit provides Genkit functionality for application developers. package genkit + +import ( + "context" + "errors" + "net/http" + + "github.com/firebase/genkit/go/core" +) + +// DefineFlow creates a Flow that runs fn, and registers it as an action. +// +// fn takes an input of type I and returns an output of type O, optionally +// streaming values of type S incrementally by invoking a callback. +// If the callback is non-nil and the function supports streaming, it should +// stream the results by invoking the callback periodically, ultimately returning +// with a final return value. Otherwise, it should ignore the StreamingCallback and +// just return a result. +func DefineFlow[I, O, S any]( + name string, + fn func(ctx context.Context, input I, callback func(context.Context, S) error) (O, error), +) *core.Flow[I, O, S] { + return core.DefineFlow(name, core.Func[I, O, S](fn)) +} + +// Run runs the function f in the context of the current flow. +// It returns an error if no flow is active. +// +// Each call to Run results in a new step in the flow. +// A step has its own span in the trace, and its result is cached so that if the flow +// is restarted, f will not be called a second time. +func Run[T any](ctx context.Context, name string, f func() (T, error)) (T, error) { + return core.Run(ctx, name, f) +} + +// RunFlow runs flow in the context of another flow. The flow must run to completion when started +// (that is, it must not have interrupts). +func RunFlow[I, O, S any](ctx context.Context, flow *core.Flow[I, O, S], input I) (O, error) { + return core.RunFlow(ctx, flow, input) +} + +type NoStream = core.NoStream + +// StreamFlowValue is either a streamed value or a final output of a flow. +type StreamFlowValue[O, S any] struct { + Done bool + Output O // valid if Done is true + Stream S // valid if Done is false +} + +// StreamFlow runs flow on input and delivers both the streamed values and the final output. +// It returns a function whose argument function (the "yield function") will be repeatedly +// called with the results. +// +// If the yield function is passed a non-nil error, the flow has failed with that +// error; the yield function will not be called again. An error is also passed if +// the flow fails to complete (that is, it has an interrupt). +// Genkit Go does not yet support interrupts. +// +// If the yield function's [StreamFlowValue] argument has Done == true, the value's +// Output field contains the final output; the yield function will not be called +// again. +// +// Otherwise the Stream field of the passed [StreamFlowValue] holds a streamed result. +func StreamFlow[I, O, S any](ctx context.Context, flow *core.Flow[I, O, S], input I) func(func(*StreamFlowValue[O, S], error) bool) { + return func(yield func(*StreamFlowValue[O, S], error) bool) { + cb := func(ctx context.Context, s S) error { + if ctx.Err() != nil { + return ctx.Err() + } + if !yield(&StreamFlowValue[O, S]{Stream: s}, nil) { + return errStop + } + return nil + } + output, err := core.InternalStreamFlow(ctx, flow, input, cb) + if err != nil { + yield(nil, err) + } else { + yield(&StreamFlowValue[O, S]{Done: true, Output: output}, nil) + } + } +} + +var errStop = errors.New("stop") + +// StartFlowServer starts a server serving the routes described in [NewFlowServeMux]. +// It listens on addr, or if empty, the value of the PORT environment variable, +// or if that is empty, ":3400". +// +// In development mode (if the environment variable GENKIT_ENV=dev), it also starts +// a dev server. +// +// StartFlowServer always returns a non-nil error, the one returned by http.ListenAndServe. +func StartFlowServer(addr string) error { + return core.StartFlowServer(addr) +} + +// NewFlowServeMux constructs a [net/http.ServeMux] where each defined flow is a route. +// All routes take a single query parameter, "stream", which if true will stream the +// flow's results back to the client. (Not all flows support streaming, however.) +// +// To use the returned ServeMux as part of a server with other routes, either add routes +// to it, or install it as part of another ServeMux, like so: +// +// mainMux := http.NewServeMux() +// mainMux.Handle("POST /flow/", http.StripPrefix("/flow/", NewFlowServeMux())) +func NewFlowServeMux() *http.ServeMux { + return core.NewFlowServeMux() +} diff --git a/go/genkit/genkit_test.go b/go/genkit/genkit_test.go new file mode 100644 index 0000000000..d9519a6fe7 --- /dev/null +++ b/go/genkit/genkit_test.go @@ -0,0 +1,54 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package genkit + +import ( + "context" + "testing" +) + +func TestStreamFlow(t *testing.T) { + f := DefineFlow("count", count) + iter := StreamFlow(context.Background(), f, 2) + want := 0 + iter(func(val *StreamFlowValue[int, int], err error) bool { + if err != nil { + t.Fatal(err) + } + var got int + if val.Done { + got = val.Output + } else { + got = val.Stream + } + if got != want { + t.Errorf("got %d, want %d", got, want) + } + want++ + return true + }) +} + +// count streams the numbers from 0 to n-1, then returns n. +func count(ctx context.Context, n int, cb func(context.Context, int) error) (int, error) { + if cb != nil { + for i := 0; i < n; i++ { + if err := cb(ctx, i); err != nil { + return 0, err + } + } + } + return n, nil +} diff --git a/go/genkit/schemas.config b/go/genkit/schemas.config index 3acb49ace2..079233d719 100644 --- a/go/genkit/schemas.config +++ b/go/genkit/schemas.config @@ -1,7 +1,7 @@ # This file holds configuration for the genkit-schema.json file # generated by the npm export:schemas script. -genkit import github.com/firebase/genkit/go/gtime +core import github.com/firebase/genkit/go/gtime # DocumentData type was hand-written. DocumentData omit @@ -67,10 +67,12 @@ FlowInvokeEnvelopeMessageRunScheduled omit Operation omit FlowStateExecution name flowExecution +FlowStateExecution pkg core FlowStateExecution.startTime type gtime.Milliseconds FlowStateExecution.endTime type gtime.Milliseconds FlowError name flowError +FlowError pkg core GenerateRequest.messages doc Messages is a list of messages to pass to the model. The first n-1 Messages diff --git a/go/plugins/dotprompt/genkit.go b/go/plugins/dotprompt/genkit.go index 5fc3c860fd..5af27483ca 100644 --- a/go/plugins/dotprompt/genkit.go +++ b/go/plugins/dotprompt/genkit.go @@ -21,7 +21,7 @@ import ( "strings" "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/internal/tracing" ) @@ -116,12 +116,12 @@ func (p *Prompt) buildRequest(input *ActionInput) (*ai.GenerateRequest, error) { return req, nil } -// Action returns a [genkit.Action] that executes the prompt. +// Action returns a [core.Action] that executes the prompt. // The returned Action will take an [ActionInput] that provides // variables to substitute into the template text. // It will then pass the rendered text to an AI generator, // and return whatever the generator computes. -func (p *Prompt) Action() (*genkit.Action[*ActionInput, *ai.GenerateResponse, struct{}], error) { +func (p *Prompt) Action() (*core.Action[*ActionInput, *ai.GenerateResponse, struct{}], error) { if p.Name == "" { return nil, errors.New("dotprompt: missing name") } @@ -130,7 +130,7 @@ func (p *Prompt) Action() (*genkit.Action[*ActionInput, *ai.GenerateResponse, st name += "." + p.Variant } - a := genkit.NewAction(name, genkit.ActionTypePrompt, nil, p.Execute) + a := core.NewAction(name, core.ActionTypePrompt, nil, p.Execute) a.Metadata = map[string]any{ "type": "prompt", "prompt": p, @@ -153,7 +153,7 @@ func (p *Prompt) Register() error { return err } - genkit.RegisterAction(genkit.ActionTypePrompt, name, action) + core.RegisterAction(name, action) return nil } diff --git a/go/plugins/dotprompt/genkit_test.go b/go/plugins/dotprompt/genkit_test.go index 61d5148c17..8d03f5b24f 100644 --- a/go/plugins/dotprompt/genkit_test.go +++ b/go/plugins/dotprompt/genkit_test.go @@ -20,12 +20,11 @@ import ( "testing" "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/genkit" ) type testGenerator struct{} -func (testGenerator) Generate(ctx context.Context, req *ai.GenerateRequest, cb genkit.StreamingCallback[*ai.Candidate]) (*ai.GenerateResponse, error) { +func (testGenerator) Generate(ctx context.Context, req *ai.GenerateRequest, cb func(context.Context, *ai.Candidate) error) (*ai.GenerateResponse, error) { input := req.Messages[0].Content[0].Text() output := fmt.Sprintf("AI reply to %q", input) diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 34f3dbed28..f2dedc32fe 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -19,7 +19,6 @@ import ( "fmt" "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/genkit" "github.com/google/generative-ai-go/genai" "google.golang.org/api/iterator" "google.golang.org/api/option" @@ -35,7 +34,7 @@ type generator struct { //session *genai.ChatSession // non-nil if we're in the middle of a chat } -func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb genkit.StreamingCallback[*ai.Candidate]) (*ai.GenerateResponse, error) { +func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb func(context.Context, *ai.Candidate) error) (*ai.GenerateResponse, error) { gm := g.client.GenerativeModel(g.model) // Translate from a ai.GenerateRequest to a genai request. diff --git a/go/plugins/googlecloud/googlecloud.go b/go/plugins/googlecloud/googlecloud.go index 3c5b38821d..6344ee388c 100644 --- a/go/plugins/googlecloud/googlecloud.go +++ b/go/plugins/googlecloud/googlecloud.go @@ -27,7 +27,7 @@ import ( "cloud.google.com/go/logging" mexporter "github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric" texporter "github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace" - "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/core" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" @@ -64,7 +64,7 @@ func Init(ctx context.Context, projectID string, opts *Options) error { return err } aexp := &adjustingTraceExporter{texp} - genkit.RegisterSpanProcessor(sdktrace.NewBatchSpanProcessor(aexp)) + core.RegisterSpanProcessor(sdktrace.NewBatchSpanProcessor(aexp)) if err := setMeterProvider(projectID, opts.MetricInterval); err != nil { return err } diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index 554a1fce9a..ff00c05558 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -20,7 +20,6 @@ import ( "cloud.google.com/go/vertexai/genai" "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/genkit" ) func newClient(ctx context.Context, projectID, location string) (*genai.Client, error) { @@ -32,7 +31,7 @@ type generator struct { client *genai.Client } -func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb genkit.StreamingCallback[*ai.Candidate]) (*ai.GenerateResponse, error) { +func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb func(context.Context, *ai.Candidate) error) (*ai.GenerateResponse, error) { if cb != nil { panic("streaming not supported yet") // TODO: streaming } diff --git a/go/samples/flow-sample1/main.go b/go/samples/flow-sample1/main.go index a42141f189..a0107f9eb9 100644 --- a/go/samples/flow-sample1/main.go +++ b/go/samples/flow-sample1/main.go @@ -61,7 +61,7 @@ func main() { Count int `json:"count"` } - genkit.DefineFlow("streamy", func(ctx context.Context, count int, cb genkit.StreamingCallback[chunk]) (string, error) { + genkit.DefineFlow("streamy", func(ctx context.Context, count int, cb func(context.Context, chunk) error) (string, error) { i := 0 if cb != nil { for ; i < count; i++ {