Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Go] added action type and subtype span metadata to actions #229

Merged
merged 2 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,5 @@ type EmbedRequest struct {
// RegisterEmbedder registers the actions for a specific embedder.
func RegisterEmbedder(name string, embedder Embedder) {
genkit.RegisterAction(genkit.ActionTypeEmbedder, name,
genkit.NewAction(name, nil, embedder.Embed))
genkit.NewAction(name, genkit.ActionTypeEmbedder, nil, embedder.Embed))
}
2 changes: 1 addition & 1 deletion go/ai/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func RegisterGenerator(provider, name string, metadata *GeneratorMetadata, gener
metadataMap["supports"] = supports
}
genkit.RegisterAction(genkit.ActionTypeModel, provider,
genkit.NewStreamingAction(name, map[string]any{
genkit.NewStreamingAction(name, genkit.ActionTypeModel, map[string]any{
"model": metadataMap,
}, generator.Generate))
}
Expand Down
4 changes: 2 additions & 2 deletions go/ai/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ type RetrieverResponse struct {
// RegisterRetriever registers the actions for a specific retriever.
func RegisterRetriever(name string, retriever Retriever) {
genkit.RegisterAction(genkit.ActionTypeRetriever, name,
genkit.NewAction(name, nil, retriever.Retrieve))
genkit.NewAction(name, genkit.ActionTypeRetriever, nil, retriever.Retrieve))

genkit.RegisterAction(genkit.ActionTypeIndexer, name,
genkit.NewAction(name, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) {
genkit.NewAction(name, genkit.ActionTypeIndexer, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) {
err := retriever.Index(ctx, req)
return struct{}{}, err
}))
Expand Down
2 changes: 1 addition & 1 deletion go/ai/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func RegisterTool(name string, definition *ToolDefinition, metadata map[string]a

// TODO: There is no provider for a tool.
genkit.RegisterAction(genkit.ActionTypeTool, "tool",
genkit.NewAction(definition.Name, metadata, fn))
genkit.NewAction(definition.Name, genkit.ActionTypeTool, metadata, fn))
}

// toolActionType is the instantiated genkit.Action type registered
Expand Down
13 changes: 8 additions & 5 deletions go/genkit/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,22 @@ type Action[I, O, S any] struct {
// See js/common/src/types.ts

// NewAction creates a new Action with the given name and non-streaming function.
func NewAction[I, O any](name string, metadata map[string]any, fn func(context.Context, I) (O, error)) *Action[I, O, struct{}] {
return NewStreamingAction(name, metadata, func(ctx context.Context, in I, cb NoStream) (O, error) {
func NewAction[I, O any](name string, actionType ActionType, metadata map[string]any, fn func(context.Context, I) (O, error)) *Action[I, O, struct{}] {
return NewStreamingAction(name, actionType, metadata, func(ctx context.Context, in I, cb NoStream) (O, error) {
return fn(ctx, in)
})
}

// NewStreamingAction creates a new Action with the given name and streaming function.
func NewStreamingAction[I, O, S any](name string, metadata map[string]any, fn Func[I, O, S]) *Action[I, O, S] {
func NewStreamingAction[I, O, S any](name string, actionType ActionType, metadata map[string]any, fn Func[I, O, S]) *Action[I, O, S] {
var i I
var o O
return &Action[I, O, S]{
name: name,
fn: fn,
name: name,
fn: func(ctx context.Context, input I, sc StreamingCallback[S]) (O, error) {
tracing.SetCustomMetadataAttr(ctx, "subtype", string(actionType))
return fn(ctx, input, sc)
},
inputSchema: inferJSONSchema(i),
outputSchema: inferJSONSchema(o),
Metadata: metadata,
Expand Down
10 changes: 5 additions & 5 deletions go/genkit/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func inc(_ context.Context, x int) (int, error) {
}

func TestActionRun(t *testing.T) {
a := NewAction("inc", nil, inc)
a := NewAction("inc", ActionTypeCustom, nil, inc)
got, err := a.Run(context.Background(), 3, nil)
if err != nil {
t.Fatal(err)
Expand All @@ -37,7 +37,7 @@ func TestActionRun(t *testing.T) {
}

func TestActionRunJSON(t *testing.T) {
a := NewAction("inc", nil, inc)
a := NewAction("inc", ActionTypeCustom, nil, inc)
input := []byte("3")
want := []byte("4")
got, err := a.runJSON(context.Background(), input, nil)
Expand All @@ -51,7 +51,7 @@ func TestActionRunJSON(t *testing.T) {

func TestNewAction(t *testing.T) {
// Verify that struct{} can occur in the function signature.
_ = NewAction("f", nil, func(context.Context, int) (struct{}, error) { return struct{}{}, nil })
_ = NewAction("f", ActionTypeCustom, nil, func(context.Context, int) (struct{}, error) { return struct{}{}, nil })
}

// count streams the numbers from 0 to n-1, then returns n.
Expand All @@ -68,7 +68,7 @@ func count(ctx context.Context, n int, cb StreamingCallback[int]) (int, error) {

func TestActionStreaming(t *testing.T) {
ctx := context.Background()
a := NewStreamingAction("count", nil, count)
a := NewStreamingAction("count", ActionTypeCustom, nil, count)
const n = 3

// Non-streaming.
Expand Down Expand Up @@ -101,7 +101,7 @@ func TestActionStreaming(t *testing.T) {
func TestActionTracing(t *testing.T) {
ctx := context.Background()
const actionName = "TestTracing-inc"
a := NewAction(actionName, nil, inc)
a := NewAction(actionName, ActionTypeCustom, nil, inc)
if _, err := a.Run(context.Background(), 3, nil); err != nil {
t.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion go/genkit/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ func (f *Flow[I, O, S]) action() *Action[*flowInstruction[I], *flowState[I, O],
"inputSchema": inferJSONSchema(i),
"outputSchema": inferJSONSchema(o),
}
return NewStreamingAction(f.name, metadata, func(ctx context.Context, inst *flowInstruction[I], cb StreamingCallback[S]) (*flowState[I, O], error) {
return NewStreamingAction(f.name, ActionTypeFlow, metadata, func(ctx context.Context, inst *flowInstruction[I], cb StreamingCallback[S]) (*flowState[I, O], error) {
tracing.SpanMetaKey.FromContext(ctx).SetAttr("flow:wrapperAction", "true")
return f.runInstruction(ctx, inst, cb)
})
Expand Down
1 change: 1 addition & 0 deletions go/genkit/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ const (
ActionTypeModel ActionType = "model"
ActionTypePrompt ActionType = "prompt"
ActionTypeTool ActionType = "tool"
ActionTypeCustom ActionType = "custom"
)

// RegisterAction records the action in the global registry.
Expand Down
4 changes: 2 additions & 2 deletions go/genkit/servers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ func TestDevServer(t *testing.T) {
if err != nil {
t.Fatal(err)
}
r.registerAction("test", "devServer", NewAction("inc", map[string]any{
r.registerAction("test", "devServer", NewAction("inc", ActionTypeCustom, map[string]any{
"foo": "bar",
}, inc))
r.registerAction("test", "devServer", NewAction("dec", map[string]any{
r.registerAction("test", "devServer", NewAction("dec", ActionTypeCustom, map[string]any{
"bar": "baz",
}, dec))
srv := httptest.NewServer(newDevServeMux(r))
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/dotprompt/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func (p *Prompt) Action() (*genkit.Action[*ActionInput, *ai.GenerateResponse, st
name += "." + p.Variant
}

a := genkit.NewAction(name, nil, p.Execute)
a := genkit.NewAction(name, genkit.ActionTypePrompt, nil, p.Execute)
a.Metadata = map[string]any{
"type": "prompt",
"prompt": p,
Expand Down
Loading