Skip to content

Commit

Permalink
[Go] move most of genkit package into core (#244)
Browse files Browse the repository at this point in the history
This PR greatly simplifies the genkit package, limiting it to
the symbols that Genkit app developers, as exemplified by
the programs in the "samples" directory, would need.

To accomplish this, it moves most of the code to a new package, named
core. The core package, rather than the genkit package, is now imported
by the ai package and plugins. End-user applications should not normally
require it.

This overlapped with #229, so unfortunately those (minor) changes
are incorporated here as well, in a slightly different form.
  • Loading branch information
jba authored May 24, 2024
1 parent 5360a11 commit 43fe9bd
Show file tree
Hide file tree
Showing 29 changed files with 320 additions and 180 deletions.
5 changes: 2 additions & 3 deletions go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ package ai
import (
"context"

"github.com/firebase/genkit/go/genkit"
"github.com/firebase/genkit/go/core"
)

// Embedder is the interface used to convert a document to a
Expand All @@ -35,6 +35,5 @@ type EmbedRequest struct {

// RegisterEmbedder registers the actions for a specific embedder.
func RegisterEmbedder(name string, embedder Embedder) {
genkit.RegisterAction(genkit.ActionTypeEmbedder, name,
genkit.NewAction(name, genkit.ActionTypeEmbedder, nil, embedder.Embed))
core.RegisterAction(name, core.NewAction(name, core.ActionTypeEmbedder, nil, embedder.Embed))
}
22 changes: 11 additions & 11 deletions go/ai/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"slices"
"strings"

"github.com/firebase/genkit/go/genkit"
"github.com/firebase/genkit/go/core"
)

// Generator is the interface used to query an AI model.
Expand All @@ -31,7 +31,7 @@ type Generator interface {
// populating the result's Candidates field.
// - If the streaming callback returns a non-nil error, generation will stop
// and Generate immediately returns that error (and a nil response).
Generate(context.Context, *GenerateRequest, genkit.StreamingCallback[*Candidate]) (*GenerateResponse, error)
Generate(context.Context, *GenerateRequest, func(context.Context, *Candidate) error) (*GenerateResponse, error)
}

// GeneratorCapabilities describes various capabilities of the generator.
Expand Down Expand Up @@ -63,16 +63,16 @@ func RegisterGenerator(provider, name string, metadata *GeneratorMetadata, gener
}
metadataMap["supports"] = supports
}
genkit.RegisterAction(genkit.ActionTypeModel, provider,
genkit.NewStreamingAction(name, genkit.ActionTypeModel, map[string]any{
core.RegisterAction(provider,
core.NewStreamingAction(name, core.ActionTypeModel, map[string]any{
"model": metadataMap,
}, generator.Generate))
}

// Generate applies a [Generator] to some input, handling tool requests.
func Generate(ctx context.Context, generator Generator, input *GenerateRequest, cb genkit.StreamingCallback[*Candidate]) (*GenerateResponse, error) {
func Generate(ctx context.Context, g Generator, input *GenerateRequest, cb func(context.Context, *Candidate) error) (*GenerateResponse, error) {
for {
resp, err := generator.Generate(ctx, input, cb)
resp, err := g.Generate(ctx, input, cb)
if err != nil {
return nil, err
}
Expand All @@ -89,14 +89,14 @@ func Generate(ctx context.Context, generator Generator, input *GenerateRequest,
}
}

// generatorActionType is the instantiated genkit.Action type registered
// generatorActionType is the instantiated core.Action type registered
// by RegisterGenerator.
type generatorActionType = genkit.Action[*GenerateRequest, *GenerateResponse, *Candidate]
type generatorActionType = core.Action[*GenerateRequest, *GenerateResponse, *Candidate]

// LookupGeneratorAction looks up an action registered by [RegisterGenerator]
// and returns a generator that invokes the action.
func LookupGeneratorAction(provider, name string) (Generator, error) {
action := genkit.LookupAction(genkit.ActionTypeModel, provider, name)
action := core.LookupAction(core.ActionTypeModel, provider, name)
if action == nil {
return nil, fmt.Errorf("LookupGeneratorAction: no generator action named %q/%q", provider, name)
}
Expand All @@ -113,9 +113,9 @@ type generatorAction struct {
}

// Generate implements Generator. This is like the [Generate] function,
// but invokes the [genkit.Action] rather than invoking the Generator
// but invokes the [core.Action] rather than invoking the Generator
// directly.
func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, cb genkit.StreamingCallback[*Candidate]) (*GenerateResponse, error) {
func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, cb func(context.Context, *Candidate) error) (*GenerateResponse, error) {
for {
resp, err := ga.action.Run(ctx, input, cb)
if err != nil {
Expand Down
10 changes: 5 additions & 5 deletions go/ai/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ package ai
import (
"context"

"github.com/firebase/genkit/go/genkit"
"github.com/firebase/genkit/go/core"
)

// Retriever supports adding documents to a database, and
Expand Down Expand Up @@ -51,11 +51,11 @@ type RetrieverResponse struct {

// RegisterRetriever registers the actions for a specific retriever.
func RegisterRetriever(name string, retriever Retriever) {
genkit.RegisterAction(genkit.ActionTypeRetriever, name,
genkit.NewAction(name, genkit.ActionTypeRetriever, nil, retriever.Retrieve))
core.RegisterAction(name,
core.NewAction(name, core.ActionTypeRetriever, nil, retriever.Retrieve))

genkit.RegisterAction(genkit.ActionTypeIndexer, name,
genkit.NewAction(name, genkit.ActionTypeIndexer, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) {
core.RegisterAction(name,
core.NewAction(name, core.ActionTypeIndexer, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) {
err := retriever.Index(ctx, req)
return struct{}{}, err
}))
Expand Down
11 changes: 5 additions & 6 deletions go/ai/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"fmt"
"maps"

"github.com/firebase/genkit/go/genkit"
"github.com/firebase/genkit/go/core"
)

// A Tool is an implementation of a single tool.
Expand All @@ -42,18 +42,17 @@ func RegisterTool(name string, definition *ToolDefinition, metadata map[string]a
metadata["type"] = "tool"

// TODO: There is no provider for a tool.
genkit.RegisterAction(genkit.ActionTypeTool, "tool",
genkit.NewAction(definition.Name, genkit.ActionTypeTool, metadata, fn))
core.RegisterAction("tool", core.NewAction(definition.Name, core.ActionTypeTool, metadata, fn))
}

// toolActionType is the instantiated genkit.Action type registered
// toolActionType is the instantiated core.Action type registered
// by RegisterTool.
type toolActionType = genkit.Action[map[string]any, map[string]any, struct{}]
type toolActionType = core.Action[map[string]any, map[string]any, struct{}]

// RunTool looks up a tool registered by [RegisterTool],
// runs it with the given input, and returns the result.
func RunTool(ctx context.Context, name string, input map[string]any) (map[string]any, error) {
action := genkit.LookupAction(genkit.ActionTypeTool, "tool", name)
action := core.LookupAction(core.ActionTypeTool, "tool", name)
if action == nil {
return nil, fmt.Errorf("no tool named %q", name)
}
Expand Down
37 changes: 20 additions & 17 deletions go/genkit/action.go → go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package genkit
package core

import (
"context"
Expand All @@ -34,18 +34,16 @@ import (
// stream the results by invoking the callback periodically, ultimately returning
// with a final return value. Otherwise, it should ignore the StreamingCallback and
// just return a result.
type Func[I, O, S any] func(context.Context, I, StreamingCallback[S]) (O, error)
type Func[I, O, S any] func(context.Context, I, func(context.Context, S) error) (O, error)

// TODO(jba): use a generic type alias for the above when they become available?

// StreamingCallback is the type of streaming callbacks, which is passed to action
// functions who should stream their responses.
type StreamingCallback[S any] func(context.Context, S) error

// NoStream indicates that the action or flow does not support streaming.
// A Func[I, O, NoStream] will ignore its streaming callback.
// Such a function corresponds to a Flow[I, O, struct{}].
type NoStream = StreamingCallback[struct{}]
type NoStream = func(context.Context, struct{}) error

type streamingCallback[S any] func(context.Context, S) error

// An Action is a named, observable operation.
// It consists of a function that takes an input of type I and returns an output
Expand All @@ -56,6 +54,7 @@ type NoStream = StreamingCallback[struct{}]
// Each time an Action is run, it results in a new trace span.
type Action[I, O, S any] struct {
name string
atype ActionType
fn Func[I, O, S]
tstate *tracing.State
inputSchema *jsonschema.Schema
Expand All @@ -68,20 +67,21 @@ type Action[I, O, S any] struct {
// See js/common/src/types.ts

// NewAction creates a new Action with the given name and non-streaming function.
func NewAction[I, O any](name string, actionType ActionType, metadata map[string]any, fn func(context.Context, I) (O, error)) *Action[I, O, struct{}] {
return NewStreamingAction(name, actionType, metadata, func(ctx context.Context, in I, cb NoStream) (O, error) {
func NewAction[I, O any](name string, atype ActionType, metadata map[string]any, fn func(context.Context, I) (O, error)) *Action[I, O, struct{}] {
return NewStreamingAction(name, atype, metadata, func(ctx context.Context, in I, cb NoStream) (O, error) {
return fn(ctx, in)
})
}

// NewStreamingAction creates a new Action with the given name and streaming function.
func NewStreamingAction[I, O, S any](name string, actionType ActionType, metadata map[string]any, fn Func[I, O, S]) *Action[I, O, S] {
func NewStreamingAction[I, O, S any](name string, atype ActionType, metadata map[string]any, fn Func[I, O, S]) *Action[I, O, S] {
var i I
var o O
return &Action[I, O, S]{
name: name,
fn: func(ctx context.Context, input I, sc StreamingCallback[S]) (O, error) {
tracing.SetCustomMetadataAttr(ctx, "subtype", string(actionType))
name: name,
atype: atype,
fn: func(ctx context.Context, input I, sc func(context.Context, S) error) (O, error) {
tracing.SetCustomMetadataAttr(ctx, "subtype", string(atype))
return fn(ctx, input, sc)
},
inputSchema: inferJSONSchema(i),
Expand All @@ -93,11 +93,13 @@ func NewStreamingAction[I, O, S any](name string, actionType ActionType, metadat
// Name returns the Action's name.
func (a *Action[I, O, S]) Name() string { return a.name }

func (a *Action[I, O, S]) actionType() ActionType { return a.atype }

// setTracingState sets the action's tracing.State.
func (a *Action[I, O, S]) setTracingState(tstate *tracing.State) { a.tstate = tstate }

// Run executes the Action's function in a new trace span.
func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb StreamingCallback[S]) (output O, err error) {
func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb func(context.Context, S) error) (output O, err error) {
// TODO: validate input against JSONSchema for I.
// TODO: validate output against JSONSchema for O.
internal.Logger(ctx).Debug("Action.Run",
Expand Down Expand Up @@ -128,12 +130,12 @@ func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb StreamingCallback
})
}

func (a *Action[I, O, S]) runJSON(ctx context.Context, input json.RawMessage, cb StreamingCallback[json.RawMessage]) (json.RawMessage, error) {
func (a *Action[I, O, S]) runJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) {
var in I
if err := json.Unmarshal(input, &in); err != nil {
return nil, err
}
var callback StreamingCallback[S]
var callback func(context.Context, S) error
if cb != nil {
callback = func(ctx context.Context, s S) error {
bytes, err := json.Marshal(s)
Expand All @@ -157,10 +159,11 @@ func (a *Action[I, O, S]) runJSON(ctx context.Context, input json.RawMessage, cb
// action is the type that all Action[I, O, S] have in common.
type action interface {
Name() string
actionType() ActionType

// runJSON uses encoding/json to unmarshal the input,
// calls Action.Run, then returns the marshaled result.
runJSON(ctx context.Context, input json.RawMessage, cb StreamingCallback[json.RawMessage]) (json.RawMessage, error)
runJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error)

// desc returns a description of the action.
// It should set all fields of actionDesc except Key, which
Expand Down
4 changes: 2 additions & 2 deletions go/genkit/action_test.go → go/core/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package genkit
package core

import (
"bytes"
Expand Down Expand Up @@ -55,7 +55,7 @@ func TestNewAction(t *testing.T) {
}

// count streams the numbers from 0 to n-1, then returns n.
func count(ctx context.Context, n int, cb StreamingCallback[int]) (int, error) {
func count(ctx context.Context, n int, cb func(context.Context, int) error) (int, error) {
if cb != nil {
for i := 0; i < n; i++ {
if err := cb(ctx, i); err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package genkit
package core

import (
"cmp"
Expand Down Expand Up @@ -82,6 +82,9 @@ func TestFlowConformance(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if len(testFiles) == 0 {
t.Fatal("did not find any test files")
}
for _, filename := range testFiles {
t.Run(strings.TrimSuffix(filepath.Base(filename), ".json"), func(t *testing.T) {
var test conformanceTest
Expand Down
18 changes: 18 additions & 0 deletions go/core/core.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package core implements Genkit actions, flows and other essential machinery.
// This package is primarily intended for genkit internals and for plugins.
// Applications using genkit should use the genkit package.
package core
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package genkit
package core

import (
"context"
Expand Down
Loading

0 comments on commit 43fe9bd

Please sign in to comment.