Skip to content

enhance: update credentials framework for OAuth support #305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 50 additions & 30 deletions pkg/cli/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"sort"
"strings"
"text/tabwriter"
"time"

cmd2 "github.com/acorn-io/cmd"
"github.com/gptscript-ai/gptscript/pkg/cache"
Expand All @@ -14,6 +15,11 @@ import (
"github.com/spf13/cobra"
)

const (
expiresNever = "never"
expiresExpired = "expired"
)

type Credential struct {
root *GPTScript
AllContexts bool `usage:"List credentials for all contexts" local:"true"`
Expand Down Expand Up @@ -46,6 +52,7 @@ func (c *Credential) Run(_ *cobra.Command, _ []string) error {
}
opts.Cache = cache.Complete(opts.Cache)

// Initialize the credential store and get all the credentials.
store, err := credentials.NewStore(cfg, ctx, opts.Cache.CacheDir)
if err != nil {
return fmt.Errorf("failed to get credentials store: %w", err)
Expand All @@ -56,6 +63,10 @@ func (c *Credential) Run(_ *cobra.Command, _ []string) error {
return fmt.Errorf("failed to list credentials: %w", err)
}

w := tabwriter.NewWriter(os.Stdout, 10, 1, 3, ' ', 0)
defer w.Flush()

// Sort credentials and print column names, depending on the options.
if c.AllContexts {
// Sort credentials by context
sort.Slice(creds, func(i, j int) bool {
Expand All @@ -65,25 +76,10 @@ func (c *Credential) Run(_ *cobra.Command, _ []string) error {
return creds[i].Context < creds[j].Context
})

w := tabwriter.NewWriter(os.Stdout, 10, 1, 3, ' ', 0)
defer w.Flush()

if c.ShowEnvVars {
_, _ = w.Write([]byte("CONTEXT\tCREDENTIAL\tENVIRONMENT VARIABLES\n"))

for _, cred := range creds {
envVars := make([]string, 0, len(cred.Env))
for envVar := range cred.Env {
envVars = append(envVars, envVar)
}
sort.Strings(envVars)
_, _ = fmt.Fprintf(w, "%s\t%s\t%s\n", cred.Context, cred.ToolName, strings.Join(envVars, ", "))
}
_, _ = w.Write([]byte("CONTEXT\tCREDENTIAL\tEXPIRES IN\tENV\n"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: does the linter complain if you omit _, _ =?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep. golangci-lint complains, and GoLand underlines it

} else {
_, _ = w.Write([]byte("CONTEXT\tCREDENTIAL\n"))
for _, cred := range creds {
_, _ = fmt.Fprintf(w, "%s\t%s\n", cred.Context, cred.ToolName)
}
_, _ = w.Write([]byte("CONTEXT\tCREDENTIAL\tEXPIRES IN\n"))
}
} else {
// Sort credentials by tool name
Expand All @@ -92,24 +88,48 @@ func (c *Credential) Run(_ *cobra.Command, _ []string) error {
})

if c.ShowEnvVars {
w := tabwriter.NewWriter(os.Stdout, 10, 1, 3, ' ', 0)
defer w.Flush()
_, _ = w.Write([]byte("CREDENTIAL\tENVIRONMENT VARIABLES\n"))

for _, cred := range creds {
envVars := make([]string, 0, len(cred.Env))
for envVar := range cred.Env {
envVars = append(envVars, envVar)
}
sort.Strings(envVars)
_, _ = fmt.Fprintf(w, "%s\t%s\n", cred.ToolName, strings.Join(envVars, ", "))
_, _ = w.Write([]byte("CREDENTIAL\tEXPIRES IN\tENV\n"))
} else {
_, _ = w.Write([]byte("CREDENTIAL\tEXPIRES IN\n"))
}
}

for _, cred := range creds {
expires := expiresNever
if cred.ExpiresAt != nil {
expires = expiresExpired
if !cred.IsExpired() {
expires = time.Until(*cred.ExpiresAt).Truncate(time.Second).String()
}
}

var fields []any
if c.AllContexts {
fields = []any{cred.Context, cred.ToolName, expires}
} else {
for _, cred := range creds {
fmt.Println(cred.ToolName)
fields = []any{cred.ToolName, expires}
}

if c.ShowEnvVars {
envVars := make([]string, 0, len(cred.Env))
for envVar := range cred.Env {
envVars = append(envVars, envVar)
}
sort.Strings(envVars)
fields = append(fields, strings.Join(envVars, ", "))
}

printFields(w, fields)
}

return nil
}

func printFields(w *tabwriter.Writer, fields []any) {
if len(fields) == 0 {
return
}

fmtStr := strings.Repeat("%s\t", len(fields)-1) + "%s\n"
_, _ = fmt.Fprintf(w, fmtStr, fields...)
}
47 changes: 32 additions & 15 deletions pkg/credentials/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,58 @@ import (
"encoding/json"
"fmt"
"strings"
"time"

"github.com/docker/cli/cli/config/types"
)

const ctxSeparator = "///"

type CredentialType string

const (
ctxSeparator = "///"
CredentialTypeTool CredentialType = "tool"
CredentialTypeModelProvider CredentialType = "modelProvider"
ExistingCredential = "GPTSCRIPT_EXISTING_CREDENTIAL"
)

type Credential struct {
Context string `json:"context"`
ToolName string `json:"toolName"`
Type CredentialType `json:"type"`
Env map[string]string `json:"env"`
Context string `json:"context"`
ToolName string `json:"toolName"`
Type CredentialType `json:"type"`
Env map[string]string `json:"env"`
ExpiresAt *time.Time `json:"expiresAt"`
RefreshToken string `json:"refreshToken"`
}

func (c Credential) IsExpired() bool {
if c.ExpiresAt == nil {
return false
}
return time.Now().After(*c.ExpiresAt)
}

func (c Credential) toDockerAuthConfig() (types.AuthConfig, error) {
env, err := json.Marshal(c.Env)
cred, err := json.Marshal(c)
if err != nil {
return types.AuthConfig{}, err
}

return types.AuthConfig{
Username: string(c.Type),
Password: string(env),
Password: string(cred),
ServerAddress: toolNameWithCtx(c.ToolName, c.Context),
}, nil
}

func credentialFromDockerAuthConfig(authCfg types.AuthConfig) (Credential, error) {
var env map[string]string
if err := json.Unmarshal([]byte(authCfg.Password), &env); err != nil {
return Credential{}, err
var cred Credential
if err := json.Unmarshal([]byte(authCfg.Password), &cred); err != nil || len(cred.Env) == 0 {
// Legacy: try unmarshalling into just an env map
var env map[string]string
if err := json.Unmarshal([]byte(authCfg.Password), &env); err != nil {
return Credential{}, err
}
cred.Env = env
}

// We used to hardcode the username as "gptscript" before CredentialType was introduced, so
Expand All @@ -62,10 +77,12 @@ func credentialFromDockerAuthConfig(authCfg types.AuthConfig) (Credential, error
}

return Credential{
Context: ctx,
ToolName: tool,
Type: CredentialType(credType),
Env: env,
Context: ctx,
ToolName: tool,
Type: CredentialType(credType),
Env: cred.Env,
ExpiresAt: cred.ExpiresAt,
RefreshToken: cred.RefreshToken,
}, nil
}

Expand Down
45 changes: 26 additions & 19 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ var (
EventTypeRunFinish EventType = "runFinish"
)

func getContextInput(prg *types.Program, ref types.ToolReference, input string) (string, error) {
func getToolRefInput(prg *types.Program, ref types.ToolReference, input string) (string, error) {
if ref.Arg == "" {
return "", nil
}
Expand Down Expand Up @@ -355,7 +355,7 @@ func (r *Runner) getContext(callCtx engine.Context, state *State, monitor Monito
continue
}

contextInput, err := getContextInput(callCtx.Program, toolRef, input)
contextInput, err := getToolRefInput(callCtx.Program, toolRef, input)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -867,7 +867,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
}

var (
cred *credentials.Credential
c *credentials.Credential
exists bool
)

Expand All @@ -879,25 +879,39 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
// Only try to look up the cred if the tool is on GitHub or has an alias.
// If it is a GitHub tool and has an alias, the alias overrides the tool name, so we use it as the credential name.
if isGitHubTool(toolName) && credentialAlias == "" {
cred, exists, err = r.credStore.Get(toolName)
c, exists, err = r.credStore.Get(toolName)
if err != nil {
return nil, fmt.Errorf("failed to get credentials for tool %s: %w", toolName, err)
}
} else if credentialAlias != "" {
cred, exists, err = r.credStore.Get(credentialAlias)
c, exists, err = r.credStore.Get(credentialAlias)
if err != nil {
return nil, fmt.Errorf("failed to get credentials for tool %s: %w", credentialAlias, err)
}
}

if c == nil {
c = &credentials.Credential{}
}

// If the credential doesn't already exist in the store, run the credential tool in order to get the value,
// and save it in the store.
if !exists {
if !exists || c.IsExpired() {
credToolRefs, ok := callCtx.Tool.ToolMapping[credToolName]
if !ok || len(credToolRefs) != 1 {
return nil, fmt.Errorf("failed to find ID for tool %s", credToolName)
}

// If the existing credential is expired, we need to provide it to the cred tool through the environment.
if exists && c.IsExpired() {
credJSON, err := json.Marshal(c)
if err != nil {
return nil, fmt.Errorf("failed to marshal credential: %w", err)
}
env = append(env, fmt.Sprintf("%s=%s", credentials.ExistingCredential, string(credJSON)))
}

// Get the input for the credential tool, if there is any.
var input string
if args != nil {
inputBytes, err := json.Marshal(args)
Expand All @@ -916,21 +930,14 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", credToolName)
}

var envMap struct {
Env map[string]string `json:"env"`
}
if err := json.Unmarshal([]byte(*res.Result), &envMap); err != nil {
if err := json.Unmarshal([]byte(*res.Result), &c); err != nil {
return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", credToolName, err)
}

cred = &credentials.Credential{
Type: credentials.CredentialTypeTool,
Env: envMap.Env,
ToolName: credName,
}
c.ToolName = credName
c.Type = credentials.CredentialTypeTool

isEmpty := true
for _, v := range cred.Env {
for _, v := range c.Env {
if v != "" {
isEmpty = false
break
Expand All @@ -941,15 +948,15 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[credToolRefs[0].ToolID].Source.Repo != nil) || credentialAlias != "" {
if isEmpty {
log.Warnf("Not saving empty credential for tool %s", toolName)
} else if err := r.credStore.Add(*cred); err != nil {
} else if err := r.credStore.Add(*c); err != nil {
return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err)
}
} else {
log.Warnf("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases.", toolName)
}
}

for k, v := range cred.Env {
for k, v := range c.Env {
env = append(env, fmt.Sprintf("%s=%s", k, v))
}
}
Expand Down
Loading