Skip to content

Commit

Permalink
Revert "googleai: fix options need add default value" (#627)
Browse files Browse the repository at this point in the history
This reverts commit 40d40ea.
  • Loading branch information
eliben authored Feb 22, 2024
1 parent 3aaa209 commit 55e1a97
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 52 deletions.
30 changes: 0 additions & 30 deletions chains/llm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/stretchr/testify/require"
"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/llms/googleai"
"github.com/tmc/langchaingo/llms/openai"
"github.com/tmc/langchaingo/prompts"
)
Expand Down Expand Up @@ -56,32 +55,3 @@ func TestLLMChainWithChatPromptTemplate(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "AI: foo\nHuman: boo", result)
}

func TestLLMChainWithGoogleAI(t *testing.T) {
t.Parallel()
genaiKey := os.Getenv("GENAI_API_KEY")
if genaiKey == "" {
t.Skip("GENAI_API_KEY not set")
}
model, err := googleai.New(context.Background(), googleai.WithAPIKey(genaiKey))
require.NoError(t, err)
require.NoError(t, err)
model.CallbacksHandler = callbacks.LogHandler{}

prompt := prompts.NewPromptTemplate(
"What is the capital of {{.country}}",
[]string{"country"},
)
require.NoError(t, err)

chain := NewLLMChain(model, prompt)

result, err := Predict(context.Background(), chain,
map[string]any{
"country": "France",
},
WithCallback(callbacks.LogHandler{}),
)
require.NoError(t, err)
require.True(t, strings.Contains(result, "Paris"))
}
22 changes: 0 additions & 22 deletions llms/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageC
for _, opt := range options {
opt(&opts)
}
g.setCallOptionsDefaults(&opts)

model := g.client.GenerativeModel(opts.Model)
model.SetCandidateCount(int32(opts.CandidateCount))
Expand Down Expand Up @@ -83,27 +82,6 @@ func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageC
return response, nil
}

func (g *GoogleAI) setCallOptionsDefaults(opts *llms.CallOptions) {
if opts.Model == "" {
opts.Model = g.opts.defaultModel
}
if opts.CandidateCount == 0 {
opts.CandidateCount = g.opts.defaultCandidateCount
}
if opts.MaxTokens == 0 {
opts.MaxTokens = g.opts.defaultMaxTokens
}
if opts.Temperature == 0 {
opts.Temperature = g.opts.defaultTemperature
}
if opts.TopP == 0 {
opts.TopP = g.opts.defaultTopP
}
if opts.TopK == 0 {
opts.TopK = g.opts.defaultTopK
}
}

// convertCandidates converts a sequence of genai.Candidate to a response.
func convertCandidates(candidates []*genai.Candidate) (*llms.ContentResponse, error) {
var contentResponse llms.ContentResponse
Expand Down

0 comments on commit 55e1a97

Please sign in to comment.