Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 92 additions & 98 deletions go/genkit/dotprompt/dotprompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,50 @@ type Prompt struct {
// The name of the prompt. Optional unless the prompt is
// registered as an action.
Name string
// The name of the prompt variant, if any.
// Variants can be used for easy testing of a tweaked prompt.
Variant string

// A hash of the prompt contents.
Hash string
Config

// The template for the prompt.
Template *raymond.Template

// The prompt YAML frontmatter.
Frontmatter *Frontmatter
// A hash of the prompt contents.
hash string

// A Generator to use. If not nil, this is used by the
// [genkit.Action] returned by [Prompt.Action] to execute the prompt.
generator ai.Generator
}

// Config is optional configuration for a [Prompt].
type Config struct {
// The prompt variant.
Variant string
// The name of the model for which the prompt is input.
Model string

// TODO(iant): document
Tools []*ai.ToolDefinition

// Number of candidates to generate when passing the prompt
// to a generator. If 0, uses 1.
Candidates int

// Details for the model.
GenerationConfig *ai.GenerationCommonConfig

InputSchema *jsonschema.Schema // schema for input variables
VariableDefaults map[string]any // default input variable values

// Desired output format.
OutputFormat ai.OutputFormat

// Desired output schema, for JSON output.
OutputSchema map[string]any // TODO: use *jsonschema.Schema

// Arbitrary metadata.
Metadata map[string]any
}

// Open opens and parses a dotprompt file.
// The name is a base file name, without the ".prompt" extension.
func Open(name string) (*Prompt, error) {
Expand Down Expand Up @@ -114,44 +140,9 @@ func OpenVariant(name, variant string) (*Prompt, error) {
return Parse(name, variant, data)
}

// Frontmatter is the data we may see, YAML encoded, at the
// start of a dotprompt file. It appears within --- lines.
// These fields are optional.
type Frontmatter struct {
// The name of the prompt.
Name string
// The prompt variant.
Variant string
// The name of the model for which the prompt is input.
Model string

// TODO(iant): document
Tools []*ai.ToolDefinition

// Number of candidates to generate when passing the prompt
// to a generator. If 0, uses 1.
Candidates int

// Details for the model.
Config *ai.GenerationCommonConfig

// Description of input data.
Input FrontmatterInput

// Desired output format.
Output *ai.GenerateRequestOutput

// Arbitrary metadata.
Metadata map[string]any
}

// FrontmatterInput describes the input data.
type FrontmatterInput struct {
Schema *jsonschema.Schema // schema for input variables
Default map[string]any // default input variables
}

// frontmatterYAML is the type we use to unpack the frontmatter.
// (Frontmatter is the data we may see, YAML encoded, at the
// start of a dotprompt file. It appears within --- lines.)
// We do it this way so that we can handle the input and output
// fields as picoschema, while returning them as jsonschema.Schema.
type frontmatterYAML struct {
Expand All @@ -175,76 +166,75 @@ type frontmatterYAML struct {
// Parse parses the contents of a dotprompt file.
func Parse(name, variant string, data []byte) (*Prompt, error) {
const header = "---\n"
var front *Frontmatter
var fmName string
var cfg Config
if bytes.HasPrefix(data, []byte(header)) {
var err error
front, data, err = parseFrontmatter(data[len(header):])
fmName, cfg, data, err = parseFrontmatter(data[len(header):])
if err != nil {
return nil, err
}
}
// The name argument takes precedence over the name in the frontmatter.
if name == "" {
name = fmName
}

return define(name, variant, front, fmt.Sprintf("%02x", sha256.Sum256(data)), string(data), nil)
return newPrompt(name, string(data), fmt.Sprintf("%02x", sha256.Sum256(data)), cfg)
}

// define defines a new prompt from frontmatter and template text.
func define(name, variant string, frontmatter *Frontmatter, hash, text string, generator ai.Generator) (*Prompt, error) {
template, err := raymond.Parse(text)
// newPrompt creates a new prompt.
// templateText should be a handlebars template.
// hash is its SHA256 hash as a hex string.
func newPrompt(name, templateText, hash string, config Config) (*Prompt, error) {
template, err := raymond.Parse(templateText)
if err != nil {
return nil, fmt.Errorf("failed to parse template: %w", err)
}
template.RegisterHelpers(templateHelpers)

prompt := &Prompt{
Name: name,
Variant: variant,
Hash: hash,
Template: template,
Frontmatter: frontmatter,
generator: generator,
}
return prompt, nil
return &Prompt{
Name: name,
Config: config,
hash: hash,
Template: template,
}, nil
}

// parseFrontmatter parses the initial YAML frontmatter of a dotprompt file.
// Along with the frontmatter itself, it returns the remaining data.
func parseFrontmatter(data []byte) (*Frontmatter, []byte, error) {
// It returns the frontmatter as a Config along with the remaining data.
func parseFrontmatter(data []byte) (name string, c Config, rest []byte, err error) {
const footer = "\n---\n"
end := bytes.Index(data, []byte(footer))
if end == -1 {
return nil, nil, errors.New("dotprompt: missing marker for end of frontmatter")
return "", Config{}, nil, errors.New("dotprompt: missing marker for end of frontmatter")
}
input := data[:end]
var fy frontmatterYAML
if err := yaml.Unmarshal(input, &fy); err != nil {
return nil, nil, fmt.Errorf("dotprompt: failed to parse YAML frontmatter: %w", err)
return "", Config{}, nil, fmt.Errorf("dotprompt: failed to parse YAML frontmatter: %w", err)
}

ret := &Frontmatter{
Name: fy.Name,
Variant: fy.Variant,
Model: fy.Model,
Tools: fy.Tools,
Candidates: fy.Candidates,
Config: fy.Config,
Metadata: fy.Metadata,
ret := Config{
Variant: fy.Variant,
Model: fy.Model,
Tools: fy.Tools,
Candidates: fy.Candidates,
GenerationConfig: fy.Config,
VariableDefaults: fy.Input.Default,
Metadata: fy.Metadata,
}

inputSchema, err := picoschemaToJSONSchema(fy.Input.Schema)
if err != nil {
return nil, nil, fmt.Errorf("dotprompt: can't parse input: %w", err)
}
ret.Input = FrontmatterInput{
Schema: inputSchema,
Default: fy.Input.Default,
return "", Config{}, nil, fmt.Errorf("dotprompt: can't parse input: %w", err)
}
ret.InputSchema = inputSchema

outputSchema, err := picoschemaToJSONSchema(fy.Output.Schema)
if err != nil {
return nil, nil, fmt.Errorf("dotprompt: can't parse output: %w", err)
return "", Config{}, nil, fmt.Errorf("dotprompt: can't parse output: %w", err)
}

var generateOutputSchema map[string]any
if outputSchema != nil {
// We have a jsonschema.Schema and we want a map[string]any.
// TODO(iant): This conversion is useless.
Expand All @@ -255,43 +245,47 @@ func parseFrontmatter(data []byte) (*Frontmatter, []byte, error) {

data, err := json.Marshal(outputSchema)
if err != nil {
return nil, nil, fmt.Errorf("dotprompt: can't JSON marshal JSON schema: %w", err)
return "", Config{}, nil, fmt.Errorf("dotprompt: can't JSON marshal JSON schema: %w", err)
}
if err := json.Unmarshal(data, &generateOutputSchema); err != nil {
return nil, nil, fmt.Errorf("dotprompt: can't unmarshal JSON schema: %w", err)
if err := json.Unmarshal(data, &ret.OutputSchema); err != nil {
return "", Config{}, nil, fmt.Errorf("dotprompt: can't unmarshal JSON schema: %w", err)
}
}

// TODO(iant): The TypeScript codes supports media also,
// but there is no ai.OutputFormatMedia.
var of ai.OutputFormat
switch fy.Output.Format {
case "":
case string(ai.OutputFormatJSON):
of = ai.OutputFormatJSON
ret.OutputFormat = ai.OutputFormatJSON
case string(ai.OutputFormatText):
of = ai.OutputFormatText
ret.OutputFormat = ai.OutputFormatText
default:
return nil, nil, fmt.Errorf("dotprompt: unrecognized output format %q", fy.Output.Format)
return "", Config{}, nil, fmt.Errorf("dotprompt: unrecognized output format %q", fy.Output.Format)
}
return fy.Name, ret, data[end+len(footer):], nil
}

if of != "" || generateOutputSchema != nil {
ret.Output = &ai.GenerateRequestOutput{
Format: of,
Schema: generateOutputSchema,
}
// Define creates and registers a new Prompt. This can be called from code that
// doesn't have a prompt file.
func Define(name, templateText string, cfg *Config) (*Prompt, error) {
p, err := New(name, templateText, cfg)
if err != nil {
return nil, err
}

return ret, data[end+len(footer):], nil
p.Register()
return p, nil
}

// Define defines a new Prompt. This can be called from code that
// doesn't have a prompt file. If not nil, the generator argument
// will implement the AI model given the substituted prompt text.
// New creates a new Prompt without registering it.
// This may be used for testing or for direct calls not using the
// genkit action and flow mechanisms.
func Define(name string, frontmatter *Frontmatter, text string, generator ai.Generator) (*Prompt, error) {
return define(name, "", frontmatter, fmt.Sprintf("%02x", sha256.Sum256([]byte(text))), text, generator)
func New(name, templateText string, cfg *Config) (*Prompt, error) {
if cfg == nil {
cfg = &Config{}
}
hash := fmt.Sprintf("%02x", sha256.Sum256([]byte(templateText)))
return newPrompt(name, templateText, hash, *cfg)
}

// sortSchemaSlices sorts the slices in a jsonschema to permit
Expand Down
12 changes: 6 additions & 6 deletions go/genkit/dotprompt/dotprompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,23 +110,23 @@ func TestPrompts(t *testing.T) {
t.Fatal(err)
}

if prompt.Frontmatter.Model != test.model {
t.Errorf("got model %q want %q", prompt.Frontmatter.Model, test.model)
if prompt.Model != test.model {
t.Errorf("got model %q want %q", prompt.Model, test.model)
}
if diff := cmpSchema(t, prompt.Frontmatter.Input.Schema, test.input); diff != "" {
if diff := cmpSchema(t, prompt.InputSchema, test.input); diff != "" {
t.Errorf("input schema mismatch (-want, +got):\n%s", diff)
}

if test.output == "" {
if prompt.Frontmatter.Output != nil && prompt.Frontmatter.Output.Schema != nil {
t.Errorf("unexpected output schema: %v", prompt.Frontmatter.Output.Schema)
if prompt.OutputSchema != nil {
t.Errorf("unexpected output schema: %v", prompt.OutputSchema)
}
} else {
var output map[string]any
if err := json.Unmarshal([]byte(test.output), &output); err != nil {
t.Fatalf("JSON unmarshal of %q failed: %v", test.output, err)
}
if diff := cmp.Diff(output, prompt.Frontmatter.Output.Schema); diff != "" {
if diff := cmp.Diff(output, prompt.OutputSchema); diff != "" {
t.Errorf("output schema mismatch (-want, +got):\n%s", diff)
}
}
Expand Down
24 changes: 8 additions & 16 deletions go/genkit/dotprompt/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,24 +98,19 @@ func (p *Prompt) buildRequest(input *ActionInput) (*ai.GenerateRequest, error) {
}

req.Candidates = input.Candidates
if req.Candidates == 0 && p.Frontmatter != nil {
req.Candidates = p.Frontmatter.Candidates
if req.Candidates == 0 {
req.Candidates = p.Candidates
}
if req.Candidates == 0 {
req.Candidates = 1
}

if p.Frontmatter != nil {
req.Config = p.Frontmatter.Config
}

if p.Frontmatter != nil {
req.Output = p.Frontmatter.Output
}

if p.Frontmatter != nil {
req.Tools = p.Frontmatter.Tools
req.Config = p.GenerationConfig
req.Output = &ai.GenerateRequestOutput{
Format: p.OutputFormat,
Schema: p.OutputSchema,
}
req.Tools = p.Tools

return req, nil
}
Expand Down Expand Up @@ -174,10 +169,7 @@ func (p *Prompt) Execute(ctx context.Context, input *ActionInput) (*ai.GenerateR

generator := p.generator
if generator == nil {
var model string
if p.Frontmatter != nil {
model = p.Frontmatter.Model
}
model := p.Model
if input.Model != "" {
model = input.Model
}
Expand Down
3 changes: 2 additions & 1 deletion go/genkit/dotprompt/genkit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ func (testGenerator) Generate(ctx context.Context, req *ai.GenerateRequest, cb g
}

func TestExecute(t *testing.T) {
p, err := Define("TestExecute", &Frontmatter{}, "TestExecute", testGenerator{})
p, err := New("TestExecute", "TestExecute", nil)
if err != nil {
t.Fatal(err)
}
p.generator = testGenerator{}
resp, err := p.Execute(context.Background(), &ActionInput{})
if err != nil {
t.Fatal(err)
Expand Down
4 changes: 2 additions & 2 deletions go/genkit/dotprompt/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ func (p *Prompt) RenderText(variables map[string]any) (string, error) {

// RenderMessages executes the prompt's template and converts it into messages.
func (p *Prompt) RenderMessages(variables map[string]any) ([]*ai.Message, error) {
if p.Frontmatter != nil && p.Frontmatter.Input.Default != nil {
if p.VariableDefaults != nil {
nv := make(map[string]any)
maps.Copy(nv, p.Frontmatter.Input.Default)
maps.Copy(nv, p.VariableDefaults)
maps.Copy(nv, variables)
variables = nv
}
Expand Down
Loading