From 259e094f80b0872edb33bc0033644576418328b2 Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Thu, 28 Mar 2024 02:46:54 -0700 Subject: [PATCH] bug: fix model provider working with no openai key set --- go.mod | 2 ++ go.sum | 4 ++-- pkg/engine/cmd.go | 3 +++ pkg/engine/daemon.go | 31 ++++++++++++++++++++---- pkg/env/env.go | 4 +++- pkg/gptscript/gptscript.go | 3 ++- pkg/openai/client.go | 48 ++++++++++++++++++++++++++++---------- pkg/parser/parser.go | 2 ++ pkg/remote/remote.go | 18 ++++---------- pkg/types/tool.go | 4 ++++ 10 files changed, 85 insertions(+), 34 deletions(-) diff --git a/go.mod b/go.mod index 0c118fff..d2ab7bed 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,8 @@ module github.com/gptscript-ai/gptscript go 1.22.0 +replace github.com/sashabaranov/go-openai => github.com/gptscript-ai/go-openai v0.0.0-20240328093028-7993661f9eab + require ( github.com/AlecAivazis/survey/v2 v2.3.7 github.com/BurntSushi/locker v0.0.0-20171006230638-a6e239ea1c69 diff --git a/go.sum b/go.sum index db983116..24a9d14e 100644 --- a/go.sum +++ b/go.sum @@ -116,6 +116,8 @@ github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gptscript-ai/go-openai v0.0.0-20240328093028-7993661f9eab h1:uZP7zZqtQI5lfK0fGBmi2ZUrI973tNCnCDx326LG00k= +github.com/gptscript-ai/go-openai v0.0.0-20240328093028-7993661f9eab/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= @@ -206,8 +208,6 @@ github.com/samber/lo v1.38.1 
h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= github.com/samber/slog-logrus v1.0.0 h1:SsrN0p9akjCEaYd42Q5GtisMdHm0q11UD4fp4XCZi04= github.com/samber/slog-logrus v1.0.0/go.mod h1:ZTdPCmVWljwlfjz6XflKNvW4TcmYlexz4HMUOO/42bI= -github.com/sashabaranov/go-openai v1.20.1 h1:cFnTixAtc0I0cCBFr8gkvEbGCm6Rjf2JyoVWCjXwy9g= -github.com/sashabaranov/go-openai v1.20.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= diff --git a/pkg/engine/cmd.go b/pkg/engine/cmd.go index c2dbfd8a..af2dffc8 100644 --- a/pkg/engine/cmd.go +++ b/pkg/engine/cmd.go @@ -151,6 +151,9 @@ func appendInputAsEnv(env []string, input string) []string { func (e *Engine) newCommand(ctx context.Context, extraEnv []string, tool types.Tool, input string) (*exec.Cmd, func(), error) { envvars := append(e.Env[:], extraEnv...) 
envvars = appendInputAsEnv(envvars, input) + if log.IsDebug() { + envvars = append(envvars, "GPTSCRIPT_DEBUG=true") + } interpreter, rest, _ := strings.Cut(tool.Instructions, "\n") interpreter = strings.TrimSpace(interpreter)[2:] diff --git a/pkg/engine/daemon.go b/pkg/engine/daemon.go index 607b7870..01e77533 100644 --- a/pkg/engine/daemon.go +++ b/pkg/engine/daemon.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "math/rand" "net/http" "os" "strings" @@ -18,7 +19,7 @@ var ( daemonLock sync.Mutex startPort, endPort int64 - nextPort int64 + usedPorts map[int64]struct{} daemonCtx context.Context daemonClose func() daemonWG sync.WaitGroup @@ -41,10 +42,29 @@ func (e *Engine) getNextPort() int64 { startPort = 10240 endPort = 11240 } - count := endPort - startPort - nextPort++ - nextPort = nextPort % count - return startPort + nextPort + // This is a pretty simple and inefficient approach, but it also never releases ports + count := endPort - startPort + 1 + toTry := make([]int64, 0, count) + for i := startPort; i <= endPort; i++ { + toTry = append(toTry, i) + } + + rand.Shuffle(len(toTry), func(i, j int) { + toTry[i], toTry[j] = toTry[j], toTry[i] + }) + + for _, nextPort := range toTry { + if _, ok := usedPorts[nextPort]; ok { + continue + } + if usedPorts == nil { + usedPorts = map[int64]struct{}{} + } + usedPorts[nextPort] = struct{}{} + return nextPort + } + + panic("Ran out of usable ports") } func getPath(instructions string) (string, string) { @@ -92,6 +112,7 @@ func (e *Engine) startDaemon(_ context.Context, tool types.Tool) (string, error) cmd, stop, err := e.newCommand(ctx, []string{ fmt.Sprintf("PORT=%d", port), + fmt.Sprintf("GPTSCRIPT_PORT=%d", port), }, tool, "{}", diff --git a/pkg/env/env.go b/pkg/env/env.go index 7402426d..4d84af0d 100644 --- a/pkg/env/env.go +++ b/pkg/env/env.go @@ -13,7 +13,9 @@ func execEquals(bin, check string) bool { } func ToEnvLike(v string) string { - return strings.ToUpper(strings.ReplaceAll(v, "-", "_")) + v = 
strings.ReplaceAll(v, ".", "_") + v = strings.ReplaceAll(v, "-", "_") + return strings.ToUpper(v) } func Matches(cmd []string, bin string) bool { diff --git a/pkg/gptscript/gptscript.go b/pkg/gptscript/gptscript.go index 323917a7..9c64b2ee 100644 --- a/pkg/gptscript/gptscript.go +++ b/pkg/gptscript/gptscript.go @@ -54,7 +54,8 @@ func New(opts *Options) (*GPTScript, error) { } oAIClient, err := openai.NewClient(append([]openai.Options{opts.OpenAI}, openai.Options{ - Cache: cacheClient, + Cache: cacheClient, + SetSeed: true, })...) if err != nil { return nil, err diff --git a/pkg/openai/client.go b/pkg/openai/client.go index 884acb08..8ee9e0d2 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -33,11 +33,12 @@ var ( ) type Client struct { - url string - key string defaultModel string c *openai.Client cache *cache.Client + invalidAuth bool + cacheKeyBase string + setSeed bool } type Options struct { @@ -47,6 +48,8 @@ type Options struct { APIType openai.APIType `usage:"OpenAI API Type (valid: OPEN_AI, AZURE, AZURE_AD)" name:"openai-api-type" env:"OPENAI_API_TYPE"` OrgID string `usage:"OpenAI organization ID" name:"openai-org-id" env:"OPENAI_ORG_ID"` DefaultModel string `usage:"Default LLM model to use" default:"gpt-4-turbo-preview"` + SetSeed bool `usage:"-"` + CacheKey string `usage:"-"` Cache *cache.Client } @@ -59,6 +62,8 @@ func complete(opts ...Options) (result Options, err error) { result.APIVersion = types.FirstSet(opt.APIVersion, result.APIVersion) result.APIType = types.FirstSet(opt.APIType, result.APIType) result.DefaultModel = types.FirstSet(opt.DefaultModel, result.DefaultModel) + result.SetSeed = types.FirstSet(opt.SetSeed, result.SetSeed) + result.CacheKey = types.FirstSet(opt.CacheKey, result.CacheKey) } if result.Cache == nil { @@ -75,10 +80,6 @@ func complete(opts ...Options) (result Options, err error) { result.APIKey = key } - if result.APIKey == "" && result.BaseURL == "" { - return result, fmt.Errorf("OPENAI_API_KEY is not set. 
Please set the OPENAI_API_KEY environment variable") - } - return result, err } @@ -112,13 +113,28 @@ func NewClient(opts ...Options) (*Client, error) { cfg.APIVersion = types.FirstSet(opt.APIVersion, cfg.APIVersion) cfg.APIType = types.FirstSet(opt.APIType, cfg.APIType) + cacheKeyBase := opt.CacheKey + if cacheKeyBase == "" { + cacheKeyBase = hash.ID(opt.APIKey, opt.BaseURL) + } + return &Client{ c: openai.NewClientWithConfig(cfg), cache: opt.Cache, defaultModel: opt.DefaultModel, + cacheKeyBase: cacheKeyBase, + invalidAuth: opt.APIKey == "" && opt.BaseURL == "", + setSeed: opt.SetSeed, }, nil } +func (c *Client) ValidAuth() error { + if c.invalidAuth { + return fmt.Errorf("OPENAI_API_KEY is not set. Please set the OPENAI_API_KEY environment variable") + } + return nil +} + func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) { models, err := c.ListModels(ctx) if err != nil { @@ -133,6 +149,10 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result [] return nil, nil } + if err := c.ValidAuth(); err != nil { + return nil, err + } + models, err := c.c.ListModels(ctx) if err != nil { return nil, err @@ -146,8 +166,7 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result [] func (c *Client) cacheKey(request openai.ChatCompletionRequest) string { return hash.Encode(map[string]any{ - "url": c.url, - "key": c.key, + "base": c.cacheKeyBase, "request": request, }) } @@ -277,6 +296,10 @@ func toMessages(request types.CompletionRequest) (result []openai.ChatCompletion } func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) { + if err := c.ValidAuth(); err != nil { + return nil, err + } + if messageRequest.Model == "" { messageRequest.Model = c.defaultModel } @@ -296,10 +319,9 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques } if messageRequest.Temperature == nil { 
- // this is a hack because the field is marked as omitempty, so we need it to be set to a non-zero value but arbitrarily small - request.Temperature = 1e-08 + request.Temperature = new(float32) } else { - request.Temperature = *messageRequest.Temperature + request.Temperature = messageRequest.Temperature } if messageRequest.JSONResponse { @@ -330,7 +352,9 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques } var cacheResponse bool - request.Seed = ptr(c.seed(request)) + if c.setSeed { + request.Seed = ptr(c.seed(request)) + } response, ok, err := c.fromCache(ctx, messageRequest, request) if err != nil { return nil, err diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go index 7c93c03d..003ddcf2 100644 --- a/pkg/parser/parser.go +++ b/pkg/parser/parser.go @@ -76,6 +76,8 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) { switch normalize(key) { case "name": tool.Parameters.Name = strings.ToLower(value) + case "modelprovider": + tool.Parameters.ModelProvider = true case "model", "modelname": tool.Parameters.ModelName = value case "description": diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 6dcd3e60..2dba17de 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -5,12 +5,12 @@ import ( "fmt" "net/url" "os" - "slices" "sort" "strings" "sync" "github.com/gptscript-ai/gptscript/pkg/cache" + env2 "github.com/gptscript-ai/gptscript/pkg/env" "github.com/gptscript-ai/gptscript/pkg/loader" "github.com/gptscript-ai/gptscript/pkg/openai" "github.com/gptscript-ai/gptscript/pkg/runner" @@ -78,15 +78,6 @@ func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) { return false, err } - models, err := client.ListModels(ctx) - if err != nil { - return false, err - } - - if !slices.Contains(models, modelNameSuffix) { - return false, fmt.Errorf("Failed in find model [%s], supported [%s]", modelNameSuffix, strings.Join(models, ", ")) - } - c.clientsLock.Lock() defer 
c.clientsLock.Unlock() @@ -108,7 +99,7 @@ func (c *Client) clientFromURL(apiURL string) (*openai.Client, error) { if err != nil { return nil, err } - env := strings.ToUpper(strings.ReplaceAll(parsed.Hostname(), ".", "_")) + "_API_KEY" + env := "GPTSCRIPT_PROVIDER_" + env2.ToEnvLike(parsed.Hostname()) + "_API_KEY" apiKey := os.Getenv(env) if apiKey == "" { apiKey = "" @@ -159,8 +150,9 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err } client, err = openai.NewClient(openai.Options{ - BaseURL: url, - Cache: c.cache, + BaseURL: url, + Cache: c.cache, + CacheKey: prg.EntryToolID, }) if err != nil { return nil, err diff --git a/pkg/types/tool.go b/pkg/types/tool.go index 7b58d21f..26e114b8 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -39,6 +39,7 @@ type Parameters struct { Description string `json:"description,omitempty"` MaxTokens int `json:"maxTokens,omitempty"` ModelName string `json:"modelName,omitempty"` + ModelProvider bool `json:"modelProvider,omitempty"` JSONResponse bool `json:"jsonResponse,omitempty"` Temperature *float32 `json:"temperature,omitempty"` Cache *bool `json:"cache,omitempty"` @@ -81,6 +82,9 @@ func (t Tool) String() string { if t.Parameters.ModelName != "" { _, _ = fmt.Fprintf(buf, "Model Name: %s\n", t.Parameters.ModelName) } + if t.Parameters.ModelProvider { + _, _ = fmt.Fprintf(buf, "Model Provider: true\n") + } if t.Parameters.JSONResponse { _, _ = fmt.Fprintln(buf, "JSON Response: true") }