Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Go] reorganize and document #200

Merged
merged 5 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions go/ai/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ type ToolDefinition struct {
OutputSchema map[string]any `json:"outputSchema,omitempty"`
}

// A ToolRequest is a request from the model that the client should run
// a specific tool and pass a [ToolResponse] to the model on the next request it makes.
// A ToolRequest is a message from the model to the client that it should run a
// specific tool and pass a [ToolResponse] to the model on the next chat request it makes.
// Any ToolRequest will correspond to some [ToolDefinition] previously sent by the client.
type ToolRequest struct {
// Input is a JSON object describing the input values to the tool.
Expand All @@ -145,7 +145,7 @@ type ToolRequest struct {
Name string `json:"name,omitempty"`
}

// A ToolResponse is a response from the client to the model containing
// A ToolResponse is a message from the client to the model containing
// the results of running a specific tool on the arguments passed to the client
// by the model in a [ToolRequest].
type ToolResponse struct {
Expand Down
34 changes: 34 additions & 0 deletions go/common/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package common provides common functionality for Go Genkit.
package common
jba marked this conversation as resolved.
Show resolved Hide resolved

import (
"time"
)

// Milliseconds represents a time as the number of milliseconds since the Unix epoch.
type Milliseconds float64

func TimeToMilliseconds(t time.Time) Milliseconds {
jba marked this conversation as resolved.
Show resolved Hide resolved
nsec := t.UnixNano()
return Milliseconds(float64(nsec) / 1e6)
}

func (m Milliseconds) Time() time.Time {
sec := int64(m / 1e3)
nsec := int64((float64(m) - float64(sec*1e3)) * 1e6)
return time.Unix(sec, nsec)
}
37 changes: 37 additions & 0 deletions go/common/common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package common

import (
"testing"
"time"
)

func TestMilliseconds(t *testing.T) {
for _, tm := range []time.Time{
time.Unix(0, 0),
time.Unix(1, 0),
time.Unix(100, 554),
time.Date(2024, time.March, 24, 1, 2, 3, 4, time.UTC),
} {
m := TimeToMilliseconds(tm)
got := m.Time()
// Compare to the nearest millisecond. Due to the floating-point operations in the above
// two functions, we can't be sure that the round trip is more accurate than that.
if !got.Round(time.Millisecond).Equal(tm.Round(time.Millisecond)) {
t.Errorf("got %v, want %v", got, tm)
}
}
}
10 changes: 6 additions & 4 deletions go/genkit/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,21 @@ import (
"reflect"
"time"

"github.com/firebase/genkit/go/internal"
"github.com/invopop/jsonschema"
)

// Func is the type of function that Actions and Flows execute.
// It takes an input of type I and returns an output of type O, optionally
// streaming values of type S incrementally by invoking a callback.
// TODO(jba): use a generic type alias when they become available?
// If the StreamingCallback is non-nil and the function supports streaming, it should
// stream the results by invoking the callback periodically, ultimately returning
// with a final return value. Otherwise, it should ignore the StreamingCallback and
// just return a result.
type Func[I, O, S any] func(context.Context, I, StreamingCallback[S]) (O, error)

// TODO(jba): use a generic type alias for the above when they become available?

// StreamingCallback is the type of streaming callbacks, which is passed to action
// functions who should stream their responses.
type StreamingCallback[S any] func(context.Context, S) error
Expand All @@ -44,7 +46,7 @@ type StreamingCallback[S any] func(context.Context, S) error
// Such a function corresponds to a Flow[I, O, struct{}].
type NoStream = StreamingCallback[struct{}]

// An Action is a function with a name.
// An Action is a named, observable operation.
// It consists of a function that takes an input of type I and returns an output
// of type O, optionally streaming values of type S incrementally by invoking a callback.
// It optionally has other metadata, like a description
Expand Down Expand Up @@ -90,7 +92,7 @@ func (a *Action[I, O, S]) Name() string { return a.name }
// setTracingState sets the action's tracingState.
func (a *Action[I, O, S]) setTracingState(tstate *tracingState) { a.tstate = tstate }

// Run executes the Action's function in a new span.
// Run executes the Action's function in a new trace span.
func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb StreamingCallback[S]) (output O, err error) {
// TODO: validate input against JSONSchema for I.
// TODO: validate output against JSONSchema for O.
Expand All @@ -115,7 +117,7 @@ func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb StreamingCallback
latency := time.Since(start)
if err != nil {
writeActionFailure(ctx, a.name, latency, err)
return zero[O](), err
return internal.Zero[O](), err
}
writeActionSuccess(ctx, a.name, latency)
return out, nil
Expand Down
5 changes: 3 additions & 2 deletions go/genkit/conformance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"testing"
"time"

"github.com/firebase/genkit/go/internal"
"golang.org/x/exp/maps"
)

Expand Down Expand Up @@ -84,7 +85,7 @@ func TestFlowConformance(t *testing.T) {
for _, filename := range testFiles {
t.Run(strings.TrimSuffix(filepath.Base(filename), ".json"), func(t *testing.T) {
var test conformanceTest
if err := readJSONFile(filename, &test); err != nil {
if err := internal.ReadJSONFile(filename, &test); err != nil {
t.Fatal(err)
}
// Each test uses its own registry to avoid interference.
Expand All @@ -111,7 +112,7 @@ func TestFlowConformance(t *testing.T) {
}
ts := r.lookupTraceStore(EnvironmentDev)
var gotTrace any
if err := ts.loadAny(resp.Telemetry.TraceID, &gotTrace); err != nil {
if err := ts.LoadAny(resp.Telemetry.TraceID, &gotTrace); err != nil {
t.Fatal(err)
}
renameSpans(t, gotTrace)
Expand Down
11 changes: 6 additions & 5 deletions go/genkit/dev_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"syscall"
"time"

gtrace "github.com/firebase/genkit/go/trace"
"go.opentelemetry.io/otel/trace"
)

Expand Down Expand Up @@ -250,22 +251,22 @@ func (s *devServer) handleListTraces(w http.ResponseWriter, r *http.Request) err
}
}
ctoken := r.FormValue("continuationToken")
tds, ctoken, err := ts.List(r.Context(), &TraceQuery{Limit: limit, ContinuationToken: ctoken})
if errors.Is(err, errBadQuery) {
tds, ctoken, err := ts.List(r.Context(), &gtrace.Query{Limit: limit, ContinuationToken: ctoken})
if errors.Is(err, gtrace.ErrBadQuery) {
return &httpError{http.StatusBadRequest, err}
}
if err != nil {
return err
}
if tds == nil {
tds = []*TraceData{}
tds = []*gtrace.Data{}
}
return writeJSON(r.Context(), w, listTracesResult{tds, ctoken})
}

type listTracesResult struct {
Traces []*TraceData `json:"traces"`
ContinuationToken string `json:"continuationToken"`
Traces []*gtrace.Data `json:"traces"`
ContinuationToken string `json:"continuationToken"`
}

func (s *devServer) handleListFlowStates(w http.ResponseWriter, r *http.Request) error {
Expand Down
17 changes: 9 additions & 8 deletions go/genkit/dev_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"strings"
"testing"

gtrace "github.com/firebase/genkit/go/trace"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The impulse to use an import alias for each import suggests that the package name should actually be gtrace.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)
Expand Down Expand Up @@ -102,7 +103,7 @@ func TestDevServer(t *testing.T) {
if err != nil {
t.Fatal(err)
}
// We may have any result, including zero traces, so don't check anything else.
// We may have any result, including internal.Zero traces, so don't check anything else.
})
}

Expand All @@ -113,13 +114,13 @@ func checkActionTrace(t *testing.T, reg *registry, tid, name string) {
t.Fatal(err)
}
rootSpan := findRootSpan(t, td.Spans)
want := &SpanData{
want := &gtrace.SpanData{
TraceID: tid,
DisplayName: "dev-run-action-wrapper",
SpanKind: "INTERNAL",
SameProcessAsParentSpan: boolValue{Value: true},
Status: Status{Code: 0},
InstrumentationLibrary: InstrumentationLibrary{
SameProcessAsParentSpan: gtrace.BoolValue{Value: true},
Status: gtrace.Status{Code: 0},
InstrumentationLibrary: gtrace.InstrumentationLibrary{
Name: "genkit-tracer",
Version: "v1",
},
Expand All @@ -133,17 +134,17 @@ func checkActionTrace(t *testing.T, reg *registry, tid, name string) {
"genkit:state": "success",
},
}
diff := cmp.Diff(want, rootSpan, cmpopts.IgnoreFields(SpanData{}, "SpanID", "StartTime", "EndTime"))
diff := cmp.Diff(want, rootSpan, cmpopts.IgnoreFields(gtrace.SpanData{}, "SpanID", "StartTime", "EndTime"))
if diff != "" {
t.Errorf("mismatch (-want, +got):\n%s", diff)
}
}

// findRootSpan finds the root span in spans.
// It also verifies that it is unique.
func findRootSpan(t *testing.T, spans map[string]*SpanData) *SpanData {
func findRootSpan(t *testing.T, spans map[string]*gtrace.SpanData) *gtrace.SpanData {
t.Helper()
var root *SpanData
var root *gtrace.SpanData
for _, sd := range spans {
if sd.ParentSpanID == "" {
if root != nil {
Expand Down
6 changes: 4 additions & 2 deletions go/genkit/file_flow_state_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"context"
"os"
"path/filepath"

"github.com/firebase/genkit/go/internal"
)

// A FileFlowStateStore is a FlowStateStore that writes flowStates to files.
Expand All @@ -37,9 +39,9 @@ func NewFileFlowStateStore(dir string) (*FileFlowStateStore, error) {
func (s *FileFlowStateStore) Save(ctx context.Context, id string, fs flowStater) error {
fs.lock()
defer fs.unlock()
return writeJSONFile(filepath.Join(s.dir, clean(id)), fs)
return internal.WriteJSONFile(filepath.Join(s.dir, internal.Clean(id)), fs)
}

func (s *FileFlowStateStore) Load(ctx context.Context, id string, pfs any) error {
return readJSONFile(filepath.Join(s.dir, clean(id)), pfs)
return internal.ReadJSONFile(filepath.Join(s.dir, internal.Clean(id)), pfs)
}
Loading
Loading