Skip to content

Commit

Permalink
[Go] reorganize and document (#200)
Browse files Browse the repository at this point in the history
* [Go] reorganize and document

Make the API nicer.

- Unexport symbols that don't need to be visible.

- Write documentation some undocumented symbols.

- move some tracing symbols to a separate package
  • Loading branch information
jba authored May 21, 2024
1 parent b068890 commit f390858
Show file tree
Hide file tree
Showing 25 changed files with 359 additions and 282 deletions.
6 changes: 3 additions & 3 deletions go/ai/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ type ToolDefinition struct {
OutputSchema map[string]any `json:"outputSchema,omitempty"`
}

// A ToolRequest is a request from the model that the client should run
// a specific tool and pass a [ToolResponse] to the model on the next request it makes.
// A ToolRequest is a message from the model to the client that it should run a
// specific tool and pass a [ToolResponse] to the model on the next chat request it makes.
// Any ToolRequest will correspond to some [ToolDefinition] previously sent by the client.
type ToolRequest struct {
// Input is a JSON object describing the input values to the tool.
Expand All @@ -145,7 +145,7 @@ type ToolRequest struct {
Name string `json:"name,omitempty"`
}

// A ToolResponse is a response from the client to the model containing
// A ToolResponse is a message from the client to the model containing
// the results of running a specific tool on the arguments passed to the client
// by the model in a [ToolRequest].
type ToolResponse struct {
Expand Down
10 changes: 6 additions & 4 deletions go/genkit/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,21 @@ import (
"reflect"
"time"

"github.com/firebase/genkit/go/internal"
"github.com/invopop/jsonschema"
)

// Func is the type of function that Actions and Flows execute.
// It takes an input of type I and returns an output of type O, optionally
// streaming values of type S incrementally by invoking a callback.
// TODO(jba): use a generic type alias when they become available?
// If the StreamingCallback is non-nil and the function supports streaming, it should
// stream the results by invoking the callback periodically, ultimately returning
// with a final return value. Otherwise, it should ignore the StreamingCallback and
// just return a result.
type Func[I, O, S any] func(context.Context, I, StreamingCallback[S]) (O, error)

// TODO(jba): use a generic type alias for the above when they become available?

// StreamingCallback is the type of streaming callbacks, which is passed to action
// functions who should stream their responses.
type StreamingCallback[S any] func(context.Context, S) error
Expand All @@ -44,7 +46,7 @@ type StreamingCallback[S any] func(context.Context, S) error
// Such a function corresponds to a Flow[I, O, struct{}].
type NoStream = StreamingCallback[struct{}]

// An Action is a function with a name.
// An Action is a named, observable operation.
// It consists of a function that takes an input of type I and returns an output
// of type O, optionally streaming values of type S incrementally by invoking a callback.
// It optionally has other metadata, like a description
Expand Down Expand Up @@ -90,7 +92,7 @@ func (a *Action[I, O, S]) Name() string { return a.name }
// setTracingState sets the action's tracingState.
func (a *Action[I, O, S]) setTracingState(tstate *tracingState) { a.tstate = tstate }

// Run executes the Action's function in a new span.
// Run executes the Action's function in a new trace span.
func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb StreamingCallback[S]) (output O, err error) {
// TODO: validate input against JSONSchema for I.
// TODO: validate output against JSONSchema for O.
Expand All @@ -115,7 +117,7 @@ func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb StreamingCallback
latency := time.Since(start)
if err != nil {
writeActionFailure(ctx, a.name, latency, err)
return zero[O](), err
return internal.Zero[O](), err
}
writeActionSuccess(ctx, a.name, latency)
return out, nil
Expand Down
5 changes: 3 additions & 2 deletions go/genkit/conformance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"testing"
"time"

"github.com/firebase/genkit/go/internal"
"golang.org/x/exp/maps"
)

Expand Down Expand Up @@ -84,7 +85,7 @@ func TestFlowConformance(t *testing.T) {
for _, filename := range testFiles {
t.Run(strings.TrimSuffix(filepath.Base(filename), ".json"), func(t *testing.T) {
var test conformanceTest
if err := readJSONFile(filename, &test); err != nil {
if err := internal.ReadJSONFile(filename, &test); err != nil {
t.Fatal(err)
}
// Each test uses its own registry to avoid interference.
Expand All @@ -111,7 +112,7 @@ func TestFlowConformance(t *testing.T) {
}
ts := r.lookupTraceStore(EnvironmentDev)
var gotTrace any
if err := ts.loadAny(resp.Telemetry.TraceID, &gotTrace); err != nil {
if err := ts.LoadAny(resp.Telemetry.TraceID, &gotTrace); err != nil {
t.Fatal(err)
}
renameSpans(t, gotTrace)
Expand Down
11 changes: 6 additions & 5 deletions go/genkit/dev_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"syscall"
"time"

"github.com/firebase/genkit/go/gtrace"
"go.opentelemetry.io/otel/trace"
)

Expand Down Expand Up @@ -250,22 +251,22 @@ func (s *devServer) handleListTraces(w http.ResponseWriter, r *http.Request) err
}
}
ctoken := r.FormValue("continuationToken")
tds, ctoken, err := ts.List(r.Context(), &TraceQuery{Limit: limit, ContinuationToken: ctoken})
if errors.Is(err, errBadQuery) {
tds, ctoken, err := ts.List(r.Context(), &gtrace.Query{Limit: limit, ContinuationToken: ctoken})
if errors.Is(err, gtrace.ErrBadQuery) {
return &httpError{http.StatusBadRequest, err}
}
if err != nil {
return err
}
if tds == nil {
tds = []*TraceData{}
tds = []*gtrace.Data{}
}
return writeJSON(r.Context(), w, listTracesResult{tds, ctoken})
}

type listTracesResult struct {
Traces []*TraceData `json:"traces"`
ContinuationToken string `json:"continuationToken"`
Traces []*gtrace.Data `json:"traces"`
ContinuationToken string `json:"continuationToken"`
}

func (s *devServer) handleListFlowStates(w http.ResponseWriter, r *http.Request) error {
Expand Down
17 changes: 9 additions & 8 deletions go/genkit/dev_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"strings"
"testing"

"github.com/firebase/genkit/go/gtrace"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)
Expand Down Expand Up @@ -102,7 +103,7 @@ func TestDevServer(t *testing.T) {
if err != nil {
t.Fatal(err)
}
// We may have any result, including zero traces, so don't check anything else.
// We may have any result, including internal.Zero traces, so don't check anything else.
})
}

Expand All @@ -113,13 +114,13 @@ func checkActionTrace(t *testing.T, reg *registry, tid, name string) {
t.Fatal(err)
}
rootSpan := findRootSpan(t, td.Spans)
want := &SpanData{
want := &gtrace.SpanData{
TraceID: tid,
DisplayName: "dev-run-action-wrapper",
SpanKind: "INTERNAL",
SameProcessAsParentSpan: boolValue{Value: true},
Status: Status{Code: 0},
InstrumentationLibrary: InstrumentationLibrary{
SameProcessAsParentSpan: gtrace.BoolValue{Value: true},
Status: gtrace.Status{Code: 0},
InstrumentationLibrary: gtrace.InstrumentationLibrary{
Name: "genkit-tracer",
Version: "v1",
},
Expand All @@ -133,17 +134,17 @@ func checkActionTrace(t *testing.T, reg *registry, tid, name string) {
"genkit:state": "success",
},
}
diff := cmp.Diff(want, rootSpan, cmpopts.IgnoreFields(SpanData{}, "SpanID", "StartTime", "EndTime"))
diff := cmp.Diff(want, rootSpan, cmpopts.IgnoreFields(gtrace.SpanData{}, "SpanID", "StartTime", "EndTime"))
if diff != "" {
t.Errorf("mismatch (-want, +got):\n%s", diff)
}
}

// findRootSpan finds the root span in spans.
// It also verifies that it is unique.
func findRootSpan(t *testing.T, spans map[string]*SpanData) *SpanData {
func findRootSpan(t *testing.T, spans map[string]*gtrace.SpanData) *gtrace.SpanData {
t.Helper()
var root *SpanData
var root *gtrace.SpanData
for _, sd := range spans {
if sd.ParentSpanID == "" {
if root != nil {
Expand Down
6 changes: 4 additions & 2 deletions go/genkit/file_flow_state_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"context"
"os"
"path/filepath"

"github.com/firebase/genkit/go/internal"
)

// A FileFlowStateStore is a FlowStateStore that writes flowStates to files.
Expand All @@ -37,9 +39,9 @@ func NewFileFlowStateStore(dir string) (*FileFlowStateStore, error) {
func (s *FileFlowStateStore) Save(ctx context.Context, id string, fs flowStater) error {
fs.lock()
defer fs.unlock()
return writeJSONFile(filepath.Join(s.dir, clean(id)), fs)
return internal.WriteJSONFile(filepath.Join(s.dir, internal.Clean(id)), fs)
}

func (s *FileFlowStateStore) Load(ctx context.Context, id string, pfs any) error {
return readJSONFile(filepath.Join(s.dir, clean(id)), pfs)
return internal.ReadJSONFile(filepath.Join(s.dir, internal.Clean(id)), pfs)
}
65 changes: 34 additions & 31 deletions go/genkit/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,27 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// This file contains code for flows.
package genkit

import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"sync"
"time"

"github.com/firebase/genkit/go/gtime"
"github.com/firebase/genkit/go/internal"
"github.com/google/uuid"
otrace "go.opentelemetry.io/otel/trace"
)

// TODO(jba): support auth
// TODO(jba): provide a way to start a Flow from user code.

// A Flow is a kind of Action that can be interrupted and resumed.
//
// A Flow[I, O, S] represents a function from I to O (the S parameter is described
// under "Streaming" below). But the function may run in pieces, with interruptions
Expand Down Expand Up @@ -62,25 +82,6 @@
//
// Streaming is only supported for the "start" flow instruction. Currently there is
// no way to schedule or resume a flow with streaming.
package genkit

import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"sync"
"time"

"github.com/google/uuid"
"go.opentelemetry.io/otel/trace"
)

// TODO(jba): support auth
// TODO(jba): provide a way to start a Flow from user code.

// A Flow is a kind of Action that can be interrupted and resumed.
type Flow[I, O, S any] struct {
name string // The last component of the flow's key in the registry.
fn Func[I, O, S] // The function to run.
Expand Down Expand Up @@ -110,6 +111,8 @@ func defineFlow[I, O, S any](r *registry, name string, fn Func[I, O, S]) *Flow[I
return f
}

// TODO(jba): use flowError?

// A flowInstruction is an instruction to follow with a flow.
// It is the input for the flow's action.
// Exactly one field will be non-nil.
Expand Down Expand Up @@ -162,8 +165,8 @@ type flowState[I, O any] struct {
FlowID string `json:"flowId,omitempty"`
FlowName string `json:"name,omitempty"`
// start time in milliseconds since the epoch
StartTime Milliseconds `json:"startTime,omitempty"`
Input I `json:"input,omitempty"`
StartTime gtime.Milliseconds `json:"startTime,omitempty"`
Input I `json:"input,omitempty"`

mu sync.Mutex
Cache map[string]json.RawMessage `json:"cache,omitempty"`
Expand All @@ -179,7 +182,7 @@ func newFlowState[I, O any](id, name string, input I) *flowState[I, O] {
FlowID: id,
FlowName: name,
Input: input,
StartTime: timeToMilliseconds(time.Now()),
StartTime: gtime.ToMilliseconds(time.Now()),
Cache: map[string]json.RawMessage{},
Operation: &Operation[O]{
FlowID: id,
Expand Down Expand Up @@ -283,7 +286,7 @@ func (f *Flow[I, O, S]) execute(ctx context.Context, state *flowState[I, O], dis
}()
ctx = flowContextKey.newContext(ctx, fctx)
exec := &FlowExecution{
StartTime: timeToMilliseconds(time.Now()),
StartTime: gtime.ToMilliseconds(time.Now()),
}
state.mu.Lock()
state.Executions = append(state.Executions, exec)
Expand All @@ -297,7 +300,7 @@ func (f *Flow[I, O, S]) execute(ctx context.Context, state *flowState[I, O], dis
spanMeta.SetAttr("flow:name", f.name)
spanMeta.SetAttr("flow:id", state.FlowID)
spanMeta.SetAttr("flow:dispatchType", dispatchType)
rootSpanContext := trace.SpanContextFromContext(ctx)
rootSpanContext := otrace.SpanContextFromContext(ctx)
traceID := rootSpanContext.TraceID().String()
exec.TraceIDs = append(exec.TraceIDs, traceID)
// TODO(jba): Save rootSpanContext in the state.
Expand Down Expand Up @@ -432,18 +435,18 @@ func Run[T any](ctx context.Context, name string, f func() (T, error)) (T, error
if ok {
var t T
if err := json.Unmarshal(j, &t); err != nil {
return zero[T](), err
return internal.Zero[T](), err
}
spanMeta.SetAttr("flow:state", "cached")
return t, nil
}
t, err := f()
if err != nil {
return zero[T](), err
return internal.Zero[T](), err
}
bytes, err := json.Marshal(t)
if err != nil {
return zero[T](), err
return internal.Zero[T](), err
}
fs.lock()
fs.cache()[uName] = json.RawMessage(bytes)
Expand All @@ -458,7 +461,7 @@ func Run[T any](ctx context.Context, name string, f func() (T, error)) (T, error
func RunFlow[I, O, S any](ctx context.Context, flow *Flow[I, O, S], input I) (O, error) {
state, err := flow.start(ctx, input, nil)
if err != nil {
return zero[O](), err
return internal.Zero[O](), err
}
return finishedOpResponse(state.Operation)
}
Expand Down Expand Up @@ -515,10 +518,10 @@ func StreamFlow[I, O, S any](ctx context.Context, flow *Flow[I, O, S], input I)

func finishedOpResponse[O any](op *Operation[O]) (O, error) {
if !op.Done {
return zero[O](), fmt.Errorf("flow %s did not finish execution", op.FlowID)
return internal.Zero[O](), fmt.Errorf("flow %s did not finish execution", op.FlowID)
}
if op.Result.Error != "" {
return zero[O](), fmt.Errorf("flow %s: %s", op.FlowID, op.Result.Error)
return internal.Zero[O](), fmt.Errorf("flow %s: %s", op.FlowID, op.Result.Error)
}
return op.Result.Response, nil
}
Loading

0 comments on commit f390858

Please sign in to comment.