Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Go] production server #233

Merged
merged 4 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 53 additions & 3 deletions go/genkit/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"sync"
"time"
Expand Down Expand Up @@ -112,6 +113,7 @@ func defineFlow[I, O, S any](r *registry, name string, fn Func[I, O, S]) *Flow[I
r.registerAction(ActionTypeFlow, name, a)
// TODO(jba): this is a roundabout way to transmit the tracing state. Is there a cleaner way?
f.tstate = a.tstate
r.registerFlow(f)
return f
}

Expand Down Expand Up @@ -231,7 +233,10 @@ type operation[O any] struct {
type FlowResult[O any] struct {
Response O `json:"response,omitempty"`
Error string `json:"error,omitempty"`
// TODO(jba): keep the actual error around so that RunFlow can use it.
// The Error field above is not used in the code, but it gets marshaled
// into JSON.
// TODO(jba): replace with a type that implements error and json.Marshaler.
err error
StackTrace string `json:"stacktrace,omitempty"`
}

Expand Down Expand Up @@ -273,6 +278,50 @@ func (f *Flow[I, O, S]) runInstruction(ctx context.Context, inst *flowInstructio
}
}

// flow is the type that all Flow[I, O, S] have in common.
type flow interface {
Name() string

// runJSON uses encoding/json to unmarshal the input,
// calls Flow.start, then returns the marshaled result.
runJSON(ctx context.Context, input json.RawMessage, cb StreamingCallback[json.RawMessage]) (json.RawMessage, error)
}

func (f *Flow[I, O, S]) Name() string { return f.name }

func (f *Flow[I, O, S]) runJSON(ctx context.Context, input json.RawMessage, cb StreamingCallback[json.RawMessage]) (json.RawMessage, error) {
var in I
if err := json.Unmarshal(input, &in); err != nil {
return nil, &httpError{http.StatusBadRequest, err}
}
// If there is a callback, wrap it to turn an S into a json.RawMessage.
var callback StreamingCallback[S]
if cb != nil {
callback = func(ctx context.Context, s S) error {
bytes, err := json.Marshal(s)
if err != nil {
return err
}
return cb(ctx, json.RawMessage(bytes))
}
}
fstate, err := f.start(ctx, in, callback)
if err != nil {
return nil, err
}
if fstate.Operation == nil {
return nil, errors.New("nil operation")
}
res := fstate.Operation.Result
if res == nil {
return nil, errors.New("nil result")
}
if res.err != nil {
return nil, res.err
}
return json.Marshal(res.Response)
}

// start starts executing the flow with the given input.
func (f *Flow[I, O, S]) start(ctx context.Context, input I, cb StreamingCallback[S]) (_ *flowState[I, O], err error) {
flowID, err := generateFlowID()
Expand Down Expand Up @@ -346,6 +395,7 @@ func (f *Flow[I, O, S]) execute(ctx context.Context, state *flowState[I, O], dis
state.Operation.Done = true
if err != nil {
state.Operation.Result = &FlowResult[O]{
err: err,
Error: err.Error(),
// TODO(jba): stack trace?
}
Expand Down Expand Up @@ -535,8 +585,8 @@ func finishedOpResponse[O any](op *operation[O]) (O, error) {
if !op.Done {
return internal.Zero[O](), fmt.Errorf("flow %s did not finish execution", op.FlowID)
}
if op.Result.Error != "" {
return internal.Zero[O](), fmt.Errorf("flow %s: %s", op.FlowID, op.Result.Error)
if op.Result.err != nil {
return internal.Zero[O](), fmt.Errorf("flow %s: %w", op.FlowID, op.Result.err)
}
return op.Result.Response, nil
}
9 changes: 7 additions & 2 deletions go/genkit/flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package genkit
import (
"context"
"encoding/json"
"errors"
"slices"
"testing"

Expand Down Expand Up @@ -46,7 +47,10 @@ func TestFlowStart(t *testing.T) {
Response: 2,
},
}
if diff := cmp.Diff(want, got, cmpopts.IgnoreFields(operation[int]{}, "FlowID")); diff != "" {
diff := cmp.Diff(want, got,
cmpopts.IgnoreFields(operation[int]{}, "FlowID"),
cmpopts.IgnoreUnexported(FlowResult[int]{}, flowState[int, int]{}))
if diff != "" {
t.Errorf("mismatch (-want, +got):\n%s", diff)
}
}
Expand Down Expand Up @@ -147,6 +151,7 @@ func TestFlowState(t *testing.T) {
Metadata: "meta",
Result: &FlowResult[int]{
Response: 6,
err: errors.New("err"),
Error: "err",
StackTrace: "st",
},
Expand All @@ -161,7 +166,7 @@ func TestFlowState(t *testing.T) {
if err := json.Unmarshal(data, &got); err != nil {
t.Fatal(err)
}
diff := cmp.Diff(fs, got, cmpopts.IgnoreUnexported(flowState[int, int]{}))
diff := cmp.Diff(fs, got, cmpopts.IgnoreUnexported(flowState[int, int]{}, FlowResult[int]{}))
if diff != "" {
t.Errorf("mismatch (-want, +got):\n%s", diff)
}
Expand Down
15 changes: 15 additions & 0 deletions go/genkit/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type registry struct {
tstate *tracing.State
mu sync.Mutex
actions map[string]action
flows []flow
// TraceStores, at most one for each [Environment].
// Only the prod trace store is actually registered; the dev one is
// always created automatically. But it's simpler if we keep them together here.
Expand Down Expand Up @@ -144,6 +145,20 @@ func (r *registry) listActions() []actionDesc {
return ads
}

// registerFlow stores the flow for use by the production server (see [NewFlowServeMux]).
// It doesn't check for duplicates because registerAction will do that.
func (r *registry) registerFlow(f flow) {
r.mu.Lock()
defer r.mu.Unlock()
r.flows = append(r.flows, f)
}

func (r *registry) listFlows() []flow {
r.mu.Lock()
defer r.mu.Unlock()
return r.flows
}

// RegisterTraceStore uses the given trace.Store to record traces in the prod environment.
// (A trace.Store that writes to the local filesystem is always installed in the dev environment.)
// The returned function should be called before the program ends to ensure that
Expand Down
Loading
Loading