From d42c71a3955a41992d7df2bec5f9641a92088ada Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Thu, 23 May 2024 09:11:47 -0400 Subject: [PATCH] Added action type and subtype span metadata to actions --- go/ai/embedder.go | 2 +- go/ai/generator.go | 2 +- go/ai/retriever.go | 4 ++-- go/ai/tools.go | 2 +- go/genkit/action.go | 13 ++++++++----- go/genkit/action_test.go | 10 +++++----- go/genkit/dev_server_test.go | 4 ++-- go/genkit/dotprompt/genkit.go | 2 +- go/genkit/flow.go | 2 +- go/genkit/registry.go | 1 + 10 files changed, 23 insertions(+), 19 deletions(-) diff --git a/go/ai/embedder.go b/go/ai/embedder.go index 665738cd5..3d07b1a91 100644 --- a/go/ai/embedder.go +++ b/go/ai/embedder.go @@ -36,5 +36,5 @@ type EmbedRequest struct { // RegisterEmbedder registers the actions for a specific embedder. func RegisterEmbedder(name string, embedder Embedder) { genkit.RegisterAction(genkit.ActionTypeEmbedder, name, - genkit.NewAction(name, nil, embedder.Embed)) + genkit.NewAction(name, genkit.ActionTypeEmbedder, nil, embedder.Embed)) } diff --git a/go/ai/generator.go b/go/ai/generator.go index c1dc3ab53..4295cb8aa 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -64,7 +64,7 @@ func RegisterGenerator(provider, name string, metadata *GeneratorMetadata, gener metadataMap["supports"] = supports } genkit.RegisterAction(genkit.ActionTypeModel, provider, - genkit.NewStreamingAction(name, map[string]any{ + genkit.NewStreamingAction(name, genkit.ActionTypeModel, map[string]any{ "model": metadataMap, }, generator.Generate)) } diff --git a/go/ai/retriever.go b/go/ai/retriever.go index 19bb3def8..d4f857b72 100644 --- a/go/ai/retriever.go +++ b/go/ai/retriever.go @@ -52,10 +52,10 @@ type RetrieverResponse struct { // RegisterRetriever registers the actions for a specific retriever. func RegisterRetriever(name string, retriever Retriever) { genkit.RegisterAction(genkit.ActionTypeRetriever, name, - genkit.NewAction(name, nil, retriever.Retrieve)) + genkit.NewAction(name, genkit.ActionTypeRetriever, nil, retriever.Retrieve)) genkit.RegisterAction(genkit.ActionTypeIndexer, name, - genkit.NewAction(name, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) { + genkit.NewAction(name, genkit.ActionTypeIndexer, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) { err := retriever.Index(ctx, req) return struct{}{}, err })) diff --git a/go/ai/tools.go b/go/ai/tools.go index 0a16ad6ed..d584e0d59 100644 --- a/go/ai/tools.go +++ b/go/ai/tools.go @@ -43,7 +43,7 @@ func RegisterTool(name string, definition *ToolDefinition, metadata map[string]a // TODO: There is no provider for a tool. genkit.RegisterAction(genkit.ActionTypeTool, "tool", - genkit.NewAction(definition.Name, metadata, fn)) + genkit.NewAction(definition.Name, genkit.ActionTypeTool, metadata, fn)) } // toolActionType is the instantiated genkit.Action type registered diff --git a/go/genkit/action.go b/go/genkit/action.go index f9c1774e0..1880c55c6 100644 --- a/go/genkit/action.go +++ b/go/genkit/action.go @@ -68,19 +68,22 @@ type Action[I, O, S any] struct { // See js/common/src/types.ts // NewAction creates a new Action with the given name and non-streaming function. -func NewAction[I, O any](name string, metadata map[string]any, fn func(context.Context, I) (O, error)) *Action[I, O, struct{}] { - return NewStreamingAction(name, metadata, func(ctx context.Context, in I, cb NoStream) (O, error) { +func NewAction[I, O any](name string, actionType ActionType, metadata map[string]any, fn func(context.Context, I) (O, error)) *Action[I, O, struct{}] { + return NewStreamingAction(name, actionType, metadata, func(ctx context.Context, in I, cb NoStream) (O, error) { return fn(ctx, in) }) } // NewStreamingAction creates a new Action with the given name and streaming function. -func NewStreamingAction[I, O, S any](name string, metadata map[string]any, fn Func[I, O, S]) *Action[I, O, S] { +func NewStreamingAction[I, O, S any](name string, actionType ActionType, metadata map[string]any, fn Func[I, O, S]) *Action[I, O, S] { var i I var o O return &Action[I, O, S]{ - name: name, - fn: fn, + name: name, + fn: func(ctx context.Context, input I, sc StreamingCallback[S]) (O, error) { + tracing.SetCustomMetadataAttr(ctx, "subtype", string(actionType)) + return fn(ctx, input, sc) + }, inputSchema: inferJSONSchema(i), outputSchema: inferJSONSchema(o), Metadata: metadata, diff --git a/go/genkit/action_test.go b/go/genkit/action_test.go index d7e23bae1..ae28a4322 100644 --- a/go/genkit/action_test.go +++ b/go/genkit/action_test.go @@ -26,7 +26,7 @@ func inc(_ context.Context, x int) (int, error) { } func TestActionRun(t *testing.T) { - a := NewAction("inc", nil, inc) + a := NewAction("inc", ActionTypeCustom, nil, inc) got, err := a.Run(context.Background(), 3, nil) if err != nil { t.Fatal(err) @@ -37,7 +37,7 @@ func TestActionRun(t *testing.T) { } func TestActionRunJSON(t *testing.T) { - a := NewAction("inc", nil, inc) + a := NewAction("inc", ActionTypeCustom, nil, inc) input := []byte("3") want := []byte("4") got, err := a.runJSON(context.Background(), input, nil) @@ -51,7 +51,7 @@ func TestActionRunJSON(t *testing.T) { func TestNewAction(t *testing.T) { // Verify that struct{} can occur in the function signature. - _ = NewAction("f", nil, func(context.Context, int) (struct{}, error) { return struct{}{}, nil }) + _ = NewAction("f", ActionTypeCustom, nil, func(context.Context, int) (struct{}, error) { return struct{}{}, nil }) } // count streams the numbers from 0 to n-1, then returns n. @@ -68,7 +68,7 @@ func count(ctx context.Context, n int, cb StreamingCallback[int]) (int, error) { func TestActionStreaming(t *testing.T) { ctx := context.Background() - a := NewStreamingAction("count", nil, count) + a := NewStreamingAction("count", ActionTypeCustom, nil, count) const n = 3 // Non-streaming. @@ -101,7 +101,7 @@ func TestActionStreaming(t *testing.T) { func TestActionTracing(t *testing.T) { ctx := context.Background() const actionName = "TestTracing-inc" - a := NewAction(actionName, nil, inc) + a := NewAction(actionName, ActionTypeCustom, nil, inc) if _, err := a.Run(context.Background(), 3, nil); err != nil { t.Fatal(err) } diff --git a/go/genkit/dev_server_test.go b/go/genkit/dev_server_test.go index 77238dc50..0f9c58542 100644 --- a/go/genkit/dev_server_test.go +++ b/go/genkit/dev_server_test.go @@ -38,10 +38,10 @@ func TestDevServer(t *testing.T) { if err != nil { t.Fatal(err) } - r.registerAction("test", "devServer", NewAction("inc", map[string]any{ + r.registerAction("test", "devServer", NewAction("inc", ActionTypeCustom, map[string]any{ "foo": "bar", }, inc)) - r.registerAction("test", "devServer", NewAction("dec", map[string]any{ + r.registerAction("test", "devServer", NewAction("dec", ActionTypeCustom, map[string]any{ "bar": "baz", }, dec)) srv := httptest.NewServer(newDevServerMux(r)) diff --git a/go/genkit/dotprompt/genkit.go b/go/genkit/dotprompt/genkit.go index fc4c9173a..5fc3c860f 100644 --- a/go/genkit/dotprompt/genkit.go +++ b/go/genkit/dotprompt/genkit.go @@ -130,7 +130,7 @@ func (p *Prompt) Action() (*genkit.Action[*ActionInput, *ai.GenerateResponse, st name += "." + p.Variant } - a := genkit.NewAction(name, nil, p.Execute) + a := genkit.NewAction(name, genkit.ActionTypePrompt, nil, p.Execute) a.Metadata = map[string]any{ "type": "prompt", "prompt": p, diff --git a/go/genkit/flow.go b/go/genkit/flow.go index 86c330424..b4654a456 100644 --- a/go/genkit/flow.go +++ b/go/genkit/flow.go @@ -239,7 +239,7 @@ type FlowResult[O any] struct { // action creates an action for the flow. See the comment at the top of this file for more information. func (f *Flow[I, O, S]) action() *Action[*flowInstruction[I], *flowState[I, O], S] { - return NewStreamingAction(f.name, nil, func(ctx context.Context, inst *flowInstruction[I], cb StreamingCallback[S]) (*flowState[I, O], error) { + return NewStreamingAction(f.name, ActionTypeFlow, nil, func(ctx context.Context, inst *flowInstruction[I], cb StreamingCallback[S]) (*flowState[I, O], error) { tracing.SpanMetaKey.FromContext(ctx).SetAttr("flow:wrapperAction", "true") return f.runInstruction(ctx, inst, cb) }) diff --git a/go/genkit/registry.go b/go/genkit/registry.go index 50ac53d9c..1c540a395 100644 --- a/go/genkit/registry.go +++ b/go/genkit/registry.go @@ -89,6 +89,7 @@ const ( ActionTypeModel ActionType = "model" ActionTypePrompt ActionType = "prompt" ActionTypeTool ActionType = "tool" + ActionTypeCustom ActionType = "custom" ) // RegisterAction records the action in the global registry.