Skip to content

Commit ea76469

Browse files
committed
enhance: update credentials framework for oauth support
Signed-off-by: Grant Linville <grant@acorn.io>
1 parent 9ee00d4 commit ea76469

File tree

2 files changed

+59
-35
lines changed

2 files changed

+59
-35
lines changed

pkg/credentials/credential.go

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,43 +4,58 @@ import (
44
"encoding/json"
55
"fmt"
66
"strings"
7+
"time"
78

89
"github.com/docker/cli/cli/config/types"
910
)
1011

11-
const ctxSeparator = "///"
12-
1312
type CredentialType string
1413

1514
const (
15+
ctxSeparator = "///"
1616
CredentialTypeTool CredentialType = "tool"
1717
CredentialTypeModelProvider CredentialType = "modelProvider"
18+
ExistingCredential = "GPTSCRIPT_EXISTING_CREDENTIAL"
1819
)
1920

2021
type Credential struct {
21-
Context string `json:"context"`
22-
ToolName string `json:"toolName"`
23-
Type CredentialType `json:"type"`
24-
Env map[string]string `json:"env"`
22+
Context string `json:"context"`
23+
ToolName string `json:"toolName"`
24+
Type CredentialType `json:"type"`
25+
Env map[string]string `json:"env"`
26+
ExpiresAt *time.Time `json:"expiresAt"`
27+
RefreshToken string `json:"refreshToken"`
28+
}
29+
30+
func (c Credential) IsExpired() bool {
31+
if c.ExpiresAt == nil {
32+
return false
33+
}
34+
return time.Now().After(*c.ExpiresAt)
2535
}
2636

2737
func (c Credential) toDockerAuthConfig() (types.AuthConfig, error) {
28-
env, err := json.Marshal(c.Env)
38+
cred, err := json.Marshal(c)
2939
if err != nil {
3040
return types.AuthConfig{}, err
3141
}
3242

3343
return types.AuthConfig{
34-
Username: string(c.Type),
35-
Password: string(env),
44+
Username: string(c.Type), // Username is required, but not used
45+
Password: string(cred),
3646
ServerAddress: toolNameWithCtx(c.ToolName, c.Context),
3747
}, nil
3848
}
3949

4050
func credentialFromDockerAuthConfig(authCfg types.AuthConfig) (Credential, error) {
41-
var env map[string]string
42-
if err := json.Unmarshal([]byte(authCfg.Password), &env); err != nil {
43-
return Credential{}, err
51+
var cred Credential
52+
if err := json.Unmarshal([]byte(authCfg.Password), &cred); err != nil {
53+
// Legacy: try unmarshalling into just an env map
54+
var env map[string]string
55+
if err := json.Unmarshal([]byte(authCfg.Password), &env); err != nil {
56+
return Credential{}, err
57+
}
58+
cred.Env = env
4459
}
4560

4661
// We used to hardcode the username as "gptscript" before CredentialType was introduced, so
@@ -62,10 +77,12 @@ func credentialFromDockerAuthConfig(authCfg types.AuthConfig) (Credential, error
6277
}
6378

6479
return Credential{
65-
Context: ctx,
66-
ToolName: tool,
67-
Type: CredentialType(credType),
68-
Env: env,
80+
Context: ctx,
81+
ToolName: tool,
82+
Type: CredentialType(credType),
83+
Env: cred.Env,
84+
ExpiresAt: cred.ExpiresAt,
85+
RefreshToken: cred.RefreshToken,
6986
}, nil
7087
}
7188

pkg/runner/runner.go

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ var (
246246
EventTypeRunFinish EventType = "runFinish"
247247
)
248248

249-
func getContextInput(prg *types.Program, ref types.ToolReference, input string) (string, error) {
249+
func getToolRefInput(prg *types.Program, ref types.ToolReference, input string) (string, error) {
250250
if ref.Arg == "" {
251251
return "", nil
252252
}
@@ -351,7 +351,7 @@ func (r *Runner) getContext(callCtx engine.Context, state *State, monitor Monito
351351
continue
352352
}
353353

354-
contextInput, err := getContextInput(callCtx.Program, toolRef, input)
354+
contextInput, err := getToolRefInput(callCtx.Program, toolRef, input)
355355
if err != nil {
356356
return nil, nil, err
357357
}
@@ -842,7 +842,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
842842
}
843843

844844
var (
845-
cred *credentials.Credential
845+
c *credentials.Credential
846846
exists bool
847847
)
848848

@@ -854,25 +854,39 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
854854
// Only try to look up the cred if the tool is on GitHub or has an alias.
855855
// If it is a GitHub tool and has an alias, the alias overrides the tool name, so we use it as the credential name.
856856
if isGitHubTool(toolName) && credentialAlias == "" {
857-
cred, exists, err = r.credStore.Get(toolName)
857+
c, exists, err = r.credStore.Get(toolName)
858858
if err != nil {
859859
return nil, fmt.Errorf("failed to get credentials for tool %s: %w", toolName, err)
860860
}
861861
} else if credentialAlias != "" {
862-
cred, exists, err = r.credStore.Get(credentialAlias)
862+
c, exists, err = r.credStore.Get(credentialAlias)
863863
if err != nil {
864864
return nil, fmt.Errorf("failed to get credentials for tool %s: %w", credentialAlias, err)
865865
}
866866
}
867867

868+
if c == nil {
869+
c = &credentials.Credential{}
870+
}
871+
868872
// If the credential doesn't already exist in the store, run the credential tool in order to get the value,
869873
// and save it in the store.
870-
if !exists {
874+
if !exists || c.IsExpired() {
871875
credToolRefs, ok := callCtx.Tool.ToolMapping[credToolName]
872876
if !ok || len(credToolRefs) != 1 {
873877
return nil, fmt.Errorf("failed to find ID for tool %s", credToolName)
874878
}
875879

880+
// If the existing credential is expired, we need to provide it to the cred tool through the environment.
881+
if exists && c.IsExpired() {
882+
credJson, err := json.Marshal(c)
883+
if err != nil {
884+
return nil, fmt.Errorf("failed to marshal credential: %w", err)
885+
}
886+
env = append(env, fmt.Sprintf("%s=%s", credentials.ExistingCredential, string(credJson)))
887+
}
888+
889+
// Get the input for the credential tool, if there is any.
876890
var input string
877891
if args != nil {
878892
inputBytes, err := json.Marshal(args)
@@ -882,6 +896,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
882896
input = string(inputBytes)
883897
}
884898

899+
// Prepare and execute the subcall.
885900
subCtx, err := callCtx.SubCall(callCtx.Ctx, input, credToolRefs[0].ToolID, "", engine.CredentialToolCategory) // leaving callID as "" will cause it to be set by the engine
886901
if err != nil {
887902
return nil, fmt.Errorf("failed to create subcall context for tool %s: %w", credToolName, err)
@@ -896,21 +911,13 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
896911
return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", credToolName)
897912
}
898913

899-
var envMap struct {
900-
Env map[string]string `json:"env"`
901-
}
902-
if err := json.Unmarshal([]byte(*res.Result), &envMap); err != nil {
914+
if err := json.Unmarshal([]byte(*res.Result), &c); err != nil {
903915
return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", credToolName, err)
904916
}
905-
906-
cred = &credentials.Credential{
907-
Type: credentials.CredentialTypeTool,
908-
Env: envMap.Env,
909-
ToolName: credName,
910-
}
917+
c.ToolName = credName
911918

912919
isEmpty := true
913-
for _, v := range cred.Env {
920+
for _, v := range c.Env {
914921
if v != "" {
915922
isEmpty = false
916923
break
@@ -921,15 +928,15 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
921928
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[credToolRefs[0].ToolID].Source.Repo != nil) || credentialAlias != "" {
922929
if isEmpty {
923930
log.Warnf("Not saving empty credential for tool %s", toolName)
924-
} else if err := r.credStore.Add(*cred); err != nil {
931+
} else if err := r.credStore.Add(*c); err != nil {
925932
return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err)
926933
}
927934
} else {
928935
log.Warnf("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases.", toolName)
929936
}
930937
}
931938

932-
for k, v := range cred.Env {
939+
for k, v := range c.Env {
933940
env = append(env, fmt.Sprintf("%s=%s", k, v))
934941
}
935942
}

0 commit comments

Comments
 (0)