Skip to content

Commit

Permalink
[Go] Fix model name lookup in dotprompt. (#1369)
Browse files Browse the repository at this point in the history
  • Loading branch information
apascal07 authored and hugoaguirre committed Nov 25, 2024
1 parent 7ed5e0a commit 0fbba32
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
2 changes: 1 addition & 1 deletion go/plugins/dotprompt/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func (p *Prompt) Generate(ctx context.Context, pr *PromptRequest, cb func(contex
return nil, errors.New("dotprompt model not in provider/name format")
}

model := ai.LookupModel(provider, name)
model = ai.LookupModel(provider, name)
if model == nil {
return nil, fmt.Errorf("no model named %q for provider %q", name, provider)
}
Expand Down
33 changes: 25 additions & 8 deletions go/plugins/dotprompt/genkit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,31 @@ func testGenerate(ctx context.Context, req *ai.GenerateRequest, cb func(context.

func TestExecute(t *testing.T) {
testModel := ai.DefineModel("test", "test", nil, testGenerate)
p, err := New("TestExecute", "TestExecute", Config{Model: testModel})
if err != nil {
t.Fatal(err)
}
resp, err := p.Generate(context.Background(), &PromptRequest{}, nil)
if err != nil {
t.Fatal(err)
}
t.Run("Model", func(t *testing.T) {
p, err := New("TestExecute", "TestExecute", Config{Model: testModel})
if err != nil {
t.Fatal(err)
}
resp, err := p.Generate(context.Background(), &PromptRequest{}, nil)
if err != nil {
t.Fatal(err)
}
assertResponse(t, resp)
})
t.Run("ModelName", func(t *testing.T) {
p, err := New("TestExecute", "TestExecute", Config{ModelName: "test/test"})
if err != nil {
t.Fatal(err)
}
resp, err := p.Generate(context.Background(), &PromptRequest{}, nil)
if err != nil {
t.Fatal(err)
}
assertResponse(t, resp)
})
}

func assertResponse(t *testing.T, resp *ai.GenerateResponse) {
if len(resp.Candidates) != 1 {
t.Errorf("got %d candidates, want 1", len(resp.Candidates))
if len(resp.Candidates) < 1 {
Expand Down

0 comments on commit 0fbba32

Please sign in to comment.