Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions go/ai/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ type ToolDefinition struct {
OutputSchema map[string]any `json:"outputSchema,omitempty"`
}

// A ToolRequest is a request from the model that the client should run
// a specific tool and pass a [ToolResponse] to the model on the next request it makes.
// A ToolRequest is a message from the model to the client that it should run a
// specific tool and pass a [ToolResponse] to the model on the next chat request it makes.
// Any ToolRequest will correspond to some [ToolDefinition] previously sent by the client.
type ToolRequest struct {
// Input is a JSON object describing the input values to the tool.
Expand All @@ -145,7 +145,7 @@ type ToolRequest struct {
Name string `json:"name,omitempty"`
}

// A ToolResponse is a response from the client to the model containing
// A ToolResponse is a message from the client to the model containing
// the results of running a specific tool on the arguments passed to the client
// by the model in a [ToolRequest].
type ToolResponse struct {
Expand Down
10 changes: 6 additions & 4 deletions go/genkit/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,21 @@ import (
"reflect"
"time"

"github.com/firebase/genkit/go/internal"
"github.com/invopop/jsonschema"
)

// Func is the type of function that Actions and Flows execute.
// It takes an input of type I and returns an output of type O, optionally
// streaming values of type S incrementally by invoking a callback.
// TODO(jba): use a generic type alias when they become available?
// If the StreamingCallback is non-nil and the function supports streaming, it should
// stream the results by invoking the callback periodically, ultimately returning
// with a final return value. Otherwise, it should ignore the StreamingCallback and
// just return a result.
type Func[I, O, S any] func(context.Context, I, StreamingCallback[S]) (O, error)

// TODO(jba): use a generic type alias for the above when they become available?

// StreamingCallback is the type of streaming callbacks, which is passed to action
// functions who should stream their responses.
type StreamingCallback[S any] func(context.Context, S) error
Expand All @@ -44,7 +46,7 @@ type StreamingCallback[S any] func(context.Context, S) error
// Such a function corresponds to a Flow[I, O, struct{}].
type NoStream = StreamingCallback[struct{}]

// An Action is a function with a name.
// An Action is a named, observable operation.
// It consists of a function that takes an input of type I and returns an output
// of type O, optionally streaming values of type S incrementally by invoking a callback.
// It optionally has other metadata, like a description
Expand Down Expand Up @@ -90,7 +92,7 @@ func (a *Action[I, O, S]) Name() string { return a.name }
// setTracingState sets the action's tracingState.
func (a *Action[I, O, S]) setTracingState(tstate *tracingState) { a.tstate = tstate }

// Run executes the Action's function in a new span.
// Run executes the Action's function in a new trace span.
func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb StreamingCallback[S]) (output O, err error) {
// TODO: validate input against JSONSchema for I.
// TODO: validate output against JSONSchema for O.
Expand All @@ -115,7 +117,7 @@ func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb StreamingCallback
latency := time.Since(start)
if err != nil {
writeActionFailure(ctx, a.name, latency, err)
return zero[O](), err
return internal.Zero[O](), err
}
writeActionSuccess(ctx, a.name, latency)
return out, nil
Expand Down
5 changes: 3 additions & 2 deletions go/genkit/conformance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"testing"
"time"

"github.com/firebase/genkit/go/internal"
"golang.org/x/exp/maps"
)

Expand Down Expand Up @@ -84,7 +85,7 @@ func TestFlowConformance(t *testing.T) {
for _, filename := range testFiles {
t.Run(strings.TrimSuffix(filepath.Base(filename), ".json"), func(t *testing.T) {
var test conformanceTest
if err := readJSONFile(filename, &test); err != nil {
if err := internal.ReadJSONFile(filename, &test); err != nil {
t.Fatal(err)
}
// Each test uses its own registry to avoid interference.
Expand All @@ -111,7 +112,7 @@ func TestFlowConformance(t *testing.T) {
}
ts := r.lookupTraceStore(EnvironmentDev)
var gotTrace any
if err := ts.loadAny(resp.Telemetry.TraceID, &gotTrace); err != nil {
if err := ts.LoadAny(resp.Telemetry.TraceID, &gotTrace); err != nil {
t.Fatal(err)
}
renameSpans(t, gotTrace)
Expand Down
11 changes: 6 additions & 5 deletions go/genkit/dev_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"syscall"
"time"

"github.com/firebase/genkit/go/gtrace"
"go.opentelemetry.io/otel/trace"
)

Expand Down Expand Up @@ -250,22 +251,22 @@ func (s *devServer) handleListTraces(w http.ResponseWriter, r *http.Request) err
}
}
ctoken := r.FormValue("continuationToken")
tds, ctoken, err := ts.List(r.Context(), &TraceQuery{Limit: limit, ContinuationToken: ctoken})
if errors.Is(err, errBadQuery) {
tds, ctoken, err := ts.List(r.Context(), &gtrace.Query{Limit: limit, ContinuationToken: ctoken})
if errors.Is(err, gtrace.ErrBadQuery) {
return &httpError{http.StatusBadRequest, err}
}
if err != nil {
return err
}
if tds == nil {
tds = []*TraceData{}
tds = []*gtrace.Data{}
}
return writeJSON(r.Context(), w, listTracesResult{tds, ctoken})
}

type listTracesResult struct {
Traces []*TraceData `json:"traces"`
ContinuationToken string `json:"continuationToken"`
Traces []*gtrace.Data `json:"traces"`
ContinuationToken string `json:"continuationToken"`
}

func (s *devServer) handleListFlowStates(w http.ResponseWriter, r *http.Request) error {
Expand Down
17 changes: 9 additions & 8 deletions go/genkit/dev_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"strings"
"testing"

"github.com/firebase/genkit/go/gtrace"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)
Expand Down Expand Up @@ -102,7 +103,7 @@ func TestDevServer(t *testing.T) {
if err != nil {
t.Fatal(err)
}
// We may have any result, including zero traces, so don't check anything else.
// We may have any result, including internal.Zero traces, so don't check anything else.
})
}

Expand All @@ -113,13 +114,13 @@ func checkActionTrace(t *testing.T, reg *registry, tid, name string) {
t.Fatal(err)
}
rootSpan := findRootSpan(t, td.Spans)
want := &SpanData{
want := &gtrace.SpanData{
TraceID: tid,
DisplayName: "dev-run-action-wrapper",
SpanKind: "INTERNAL",
SameProcessAsParentSpan: boolValue{Value: true},
Status: Status{Code: 0},
InstrumentationLibrary: InstrumentationLibrary{
SameProcessAsParentSpan: gtrace.BoolValue{Value: true},
Status: gtrace.Status{Code: 0},
InstrumentationLibrary: gtrace.InstrumentationLibrary{
Name: "genkit-tracer",
Version: "v1",
},
Expand All @@ -133,17 +134,17 @@ func checkActionTrace(t *testing.T, reg *registry, tid, name string) {
"genkit:state": "success",
},
}
diff := cmp.Diff(want, rootSpan, cmpopts.IgnoreFields(SpanData{}, "SpanID", "StartTime", "EndTime"))
diff := cmp.Diff(want, rootSpan, cmpopts.IgnoreFields(gtrace.SpanData{}, "SpanID", "StartTime", "EndTime"))
if diff != "" {
t.Errorf("mismatch (-want, +got):\n%s", diff)
}
}

// findRootSpan finds the root span in spans.
// It also verifies that it is unique.
func findRootSpan(t *testing.T, spans map[string]*SpanData) *SpanData {
func findRootSpan(t *testing.T, spans map[string]*gtrace.SpanData) *gtrace.SpanData {
t.Helper()
var root *SpanData
var root *gtrace.SpanData
for _, sd := range spans {
if sd.ParentSpanID == "" {
if root != nil {
Expand Down
6 changes: 4 additions & 2 deletions go/genkit/file_flow_state_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"context"
"os"
"path/filepath"

"github.com/firebase/genkit/go/internal"
)

// A FileFlowStateStore is a FlowStateStore that writes flowStates to files.
Expand All @@ -37,9 +39,9 @@ func NewFileFlowStateStore(dir string) (*FileFlowStateStore, error) {
func (s *FileFlowStateStore) Save(ctx context.Context, id string, fs flowStater) error {
fs.lock()
defer fs.unlock()
return writeJSONFile(filepath.Join(s.dir, clean(id)), fs)
return internal.WriteJSONFile(filepath.Join(s.dir, internal.Clean(id)), fs)
}

func (s *FileFlowStateStore) Load(ctx context.Context, id string, pfs any) error {
return readJSONFile(filepath.Join(s.dir, clean(id)), pfs)
return internal.ReadJSONFile(filepath.Join(s.dir, internal.Clean(id)), pfs)
}
65 changes: 34 additions & 31 deletions go/genkit/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,27 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// This file contains code for flows.
package genkit

import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"sync"
"time"

"github.com/firebase/genkit/go/gtime"
"github.com/firebase/genkit/go/internal"
"github.com/google/uuid"
otrace "go.opentelemetry.io/otel/trace"
)

// TODO(jba): support auth
// TODO(jba): provide a way to start a Flow from user code.

// A Flow is a kind of Action that can be interrupted and resumed.
//
// A Flow[I, O, S] represents a function from I to O (the S parameter is described
// under "Streaming" below). But the function may run in pieces, with interruptions
Expand Down Expand Up @@ -62,25 +82,6 @@
//
// Streaming is only supported for the "start" flow instruction. Currently there is
// no way to schedule or resume a flow with streaming.
package genkit

import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"sync"
"time"

"github.com/google/uuid"
"go.opentelemetry.io/otel/trace"
)

// TODO(jba): support auth
// TODO(jba): provide a way to start a Flow from user code.

// A Flow is a kind of Action that can be interrupted and resumed.
type Flow[I, O, S any] struct {
name string // The last component of the flow's key in the registry.
fn Func[I, O, S] // The function to run.
Expand Down Expand Up @@ -110,6 +111,8 @@ func defineFlow[I, O, S any](r *registry, name string, fn Func[I, O, S]) *Flow[I
return f
}

// TODO(jba): use flowError?

// A flowInstruction is an instruction to follow with a flow.
// It is the input for the flow's action.
// Exactly one field will be non-nil.
Expand Down Expand Up @@ -162,8 +165,8 @@ type flowState[I, O any] struct {
FlowID string `json:"flowId,omitempty"`
FlowName string `json:"name,omitempty"`
// start time in milliseconds since the epoch
StartTime Milliseconds `json:"startTime,omitempty"`
Input I `json:"input,omitempty"`
StartTime gtime.Milliseconds `json:"startTime,omitempty"`
Input I `json:"input,omitempty"`

mu sync.Mutex
Cache map[string]json.RawMessage `json:"cache,omitempty"`
Expand All @@ -179,7 +182,7 @@ func newFlowState[I, O any](id, name string, input I) *flowState[I, O] {
FlowID: id,
FlowName: name,
Input: input,
StartTime: timeToMilliseconds(time.Now()),
StartTime: gtime.ToMilliseconds(time.Now()),
Cache: map[string]json.RawMessage{},
Operation: &Operation[O]{
FlowID: id,
Expand Down Expand Up @@ -283,7 +286,7 @@ func (f *Flow[I, O, S]) execute(ctx context.Context, state *flowState[I, O], dis
}()
ctx = flowContextKey.newContext(ctx, fctx)
exec := &FlowExecution{
StartTime: timeToMilliseconds(time.Now()),
StartTime: gtime.ToMilliseconds(time.Now()),
}
state.mu.Lock()
state.Executions = append(state.Executions, exec)
Expand All @@ -297,7 +300,7 @@ func (f *Flow[I, O, S]) execute(ctx context.Context, state *flowState[I, O], dis
spanMeta.SetAttr("flow:name", f.name)
spanMeta.SetAttr("flow:id", state.FlowID)
spanMeta.SetAttr("flow:dispatchType", dispatchType)
rootSpanContext := trace.SpanContextFromContext(ctx)
rootSpanContext := otrace.SpanContextFromContext(ctx)
traceID := rootSpanContext.TraceID().String()
exec.TraceIDs = append(exec.TraceIDs, traceID)
// TODO(jba): Save rootSpanContext in the state.
Expand Down Expand Up @@ -432,18 +435,18 @@ func Run[T any](ctx context.Context, name string, f func() (T, error)) (T, error
if ok {
var t T
if err := json.Unmarshal(j, &t); err != nil {
return zero[T](), err
return internal.Zero[T](), err
}
spanMeta.SetAttr("flow:state", "cached")
return t, nil
}
t, err := f()
if err != nil {
return zero[T](), err
return internal.Zero[T](), err
}
bytes, err := json.Marshal(t)
if err != nil {
return zero[T](), err
return internal.Zero[T](), err
}
fs.lock()
fs.cache()[uName] = json.RawMessage(bytes)
Expand All @@ -458,7 +461,7 @@ func Run[T any](ctx context.Context, name string, f func() (T, error)) (T, error
func RunFlow[I, O, S any](ctx context.Context, flow *Flow[I, O, S], input I) (O, error) {
state, err := flow.start(ctx, input, nil)
if err != nil {
return zero[O](), err
return internal.Zero[O](), err
}
return finishedOpResponse(state.Operation)
}
Expand Down Expand Up @@ -515,10 +518,10 @@ func StreamFlow[I, O, S any](ctx context.Context, flow *Flow[I, O, S], input I)

func finishedOpResponse[O any](op *Operation[O]) (O, error) {
if !op.Done {
return zero[O](), fmt.Errorf("flow %s did not finish execution", op.FlowID)
return internal.Zero[O](), fmt.Errorf("flow %s did not finish execution", op.FlowID)
}
if op.Result.Error != "" {
return zero[O](), fmt.Errorf("flow %s: %s", op.FlowID, op.Result.Error)
return internal.Zero[O](), fmt.Errorf("flow %s: %s", op.FlowID, op.Result.Error)
}
return op.Result.Response, nil
}
Loading