Skip to content

Commit

Permalink
Improve (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunshineplan authored Jul 11, 2024
1 parent e9d14b7 commit 476a4d3
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 7 deletions.
1 change: 1 addition & 0 deletions ai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ func TestGemini(t *testing.T) {
return
}
gemini, err := gemini.New(
context.Background(),
ai.WithAPIKey(apiKey),
ai.WithEndpoint(os.Getenv("GEMINI_ENDPOINT")),
ai.WithProxy(os.Getenv("GEMINI_PROXY")),
Expand Down
3 changes: 2 additions & 1 deletion client/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"errors"

"github.com/sunshineplan/ai"
Expand All @@ -26,7 +27,7 @@ func New(cfg ai.ClientConfig) (client ai.AI, err error) {
case ai.ChatGPT:
client, err = chatgpt.New(opts...)
case ai.Gemini:
client, err = gemini.New(opts...)
client, err = gemini.New(context.Background(), opts...)
default:
err = errors.New("unknown LLMs: " + string(cfg.LLMs))
}
Expand Down
4 changes: 2 additions & 2 deletions gemini/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type Gemini struct {
limiter *rate.Limiter
}

func New(opts ...ai.ClientOption) (ai.AI, error) {
func New(ctx context.Context, opts ...ai.ClientOption) (ai.AI, error) {
cfg := new(ai.ClientConfig)
for _, i := range opts {
i.Apply(cfg)
Expand All @@ -50,7 +50,7 @@ func New(opts ...ai.ClientOption) (ai.AI, error) {
if cfg.Endpoint != "" {
o = append(o, option.WithEndpoint(cfg.Endpoint))
}
client, err := genai.NewClient(context.Background(), o...)
client, err := genai.NewClient(ctx, o...)
if err != nil {
return nil, err
}
Expand Down
8 changes: 4 additions & 4 deletions prompt/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

const (
defaultTimeout = time.Minute
defaultTimeout = 5 * time.Minute
defaultWorkers = 3
)

Expand Down Expand Up @@ -144,10 +144,10 @@ func (prompt *Prompt) Execute(ai ai.AI, input []string, prefix string) (<-chan *
return c, n, nil
}

func (prompt *Prompt) JobList(ai ai.AI, input []string, prefix string, c chan<- *Result) (*workers.JobList[*Result], error) {
func (prompt *Prompt) JobList(ai ai.AI, input []string, prefix string, c chan<- *Result) (*workers.JobList[*Result], int, error) {
prompts, err := prompt.Prompts(input, prefix)
if err != nil {
return nil, err
return nil, 0, err
}
jobList := workers.NewJobList(workers.NewWorkers(prompt.workers), func(r *Result) {
resp, err := chat(ai, prompt.d, r.Prompt)
Expand All @@ -163,7 +163,7 @@ func (prompt *Prompt) JobList(ai ai.AI, input []string, prefix string, c chan<-
for i, p := range prompts {
jobList.PushBack(&Result{Index: i, Prompt: p})
}
return jobList, nil
return jobList, len(prompts), nil
}

func chat(ai ai.AI, d time.Duration, p string) (ai.ChatResponse, error) {
Expand Down

0 comments on commit 476a4d3

Please sign in to comment.