Skip to content

Commit

Permalink
Add Context support to auth methods (#1949)
Browse files Browse the repository at this point in the history
Signed-off-by: Jon Johnson <jon.johnson@chainguard.dev>
  • Loading branch information
jonjohnsonjr authored May 16, 2024
1 parent ff385a9 commit 39d1148
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 33 deletions.
8 changes: 4 additions & 4 deletions cmd/crane/cmd/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ $ curl -H "$(crane auth token -H ubuntu)" https://index.docker.io/v2/library/ubu
return err
}

auth, err := o.Keychain.Resolve(repo)
auth, err := authn.Resolve(cmd.Context(), o.Keychain, repo)
if err != nil {
return err
}
Expand Down Expand Up @@ -152,7 +152,7 @@ func NewCmdAuthGet(options []crane.Option, argv ...string) *cobra.Command {
Short: "Implements a credential helper",
Example: eg,
Args: cobra.MaximumNArgs(1),
RunE: func(_ *cobra.Command, args []string) error {
RunE: func(cmd *cobra.Command, args []string) error {
registryAddr := ""
if len(args) == 1 {
registryAddr = args[0]
Expand All @@ -168,7 +168,7 @@ func NewCmdAuthGet(options []crane.Option, argv ...string) *cobra.Command {
if err != nil {
return err
}
authorizer, err := crane.GetOptions(options...).Keychain.Resolve(reg)
authorizer, err := authn.Resolve(cmd.Context(), crane.GetOptions(options...).Keychain, reg)
if err != nil {
return err
}
Expand All @@ -182,7 +182,7 @@ func NewCmdAuthGet(options []crane.Option, argv ...string) *cobra.Command {
os.Exit(1)
}

auth, err := authorizer.Authorization()
auth, err := authn.Authorization(cmd.Context(), authorizer)
if err != nil {
return err
}
Expand Down
17 changes: 17 additions & 0 deletions pkg/authn/authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package authn

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
Expand All @@ -27,6 +28,22 @@ type Authenticator interface {
Authorization() (*AuthConfig, error)
}

// ContextAuthenticator is like Authenticator, but allows for context to be passed in.
type ContextAuthenticator interface {
// Authorization returns the value to use in an http transport's Authorization header.
AuthorizationContext(context.Context) (*AuthConfig, error)
}

// Authorization calls AuthorizationContext with ctx if the given [Authenticator] implements [ContextAuthenticator],
// otherwise it calls Resolve with the given [Resource].
func Authorization(ctx context.Context, authn Authenticator) (*AuthConfig, error) {
if actx, ok := authn.(ContextAuthenticator); ok {
return actx.AuthorizationContext(ctx)
}

return authn.Authorization()
}

// AuthConfig contains authorization information for connecting to a Registry
// Inlined what we use from github.com/docker/cli/cli/config/types
type AuthConfig struct {
Expand Down
41 changes: 37 additions & 4 deletions pkg/authn/keychain.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package authn

import (
"context"
"os"
"path/filepath"
"sync"
Expand Down Expand Up @@ -45,6 +46,11 @@ type Keychain interface {
Resolve(Resource) (Authenticator, error)
}

// ContextKeychain is like Keychain, but allows for context to be passed in.
type ContextKeychain interface {
ResolveContext(context.Context, Resource) (Authenticator, error)
}

// defaultKeychain implements Keychain with the semantics of the standard Docker
// credential keychain.
type defaultKeychain struct {
Expand All @@ -62,8 +68,23 @@ const (
DefaultAuthKey = "https://" + name.DefaultRegistry + "/v1/"
)

// Resolve implements Keychain.
// Resolve calls ResolveContext with ctx if the given [Keychain] implements [ContextKeychain],
// otherwise it calls Resolve with the given [Resource].
func Resolve(ctx context.Context, keychain Keychain, target Resource) (Authenticator, error) {
if rctx, ok := keychain.(ContextKeychain); ok {
return rctx.ResolveContext(ctx, target)
}

return keychain.Resolve(target)
}

// ResolveContext implements ContextKeychain.
func (dk *defaultKeychain) Resolve(target Resource) (Authenticator, error) {
return dk.ResolveContext(context.Background(), target)
}

// Resolve implements Keychain.
func (dk *defaultKeychain) ResolveContext(ctx context.Context, target Resource) (Authenticator, error) {
dk.mu.Lock()
defer dk.mu.Unlock()

Expand Down Expand Up @@ -180,6 +201,10 @@ func NewKeychainFromHelper(h Helper) Keychain { return wrapper{h} }
type wrapper struct{ h Helper }

func (w wrapper) Resolve(r Resource) (Authenticator, error) {
return w.ResolveContext(context.Background(), r)
}

func (w wrapper) ResolveContext(ctx context.Context, r Resource) (Authenticator, error) {
u, p, err := w.h.Get(r.RegistryStr())
if err != nil {
return Anonymous, nil
Expand All @@ -206,8 +231,12 @@ type refreshingKeychain struct {
}

func (r *refreshingKeychain) Resolve(target Resource) (Authenticator, error) {
return r.ResolveContext(context.Background(), target)
}

func (r *refreshingKeychain) ResolveContext(ctx context.Context, target Resource) (Authenticator, error) {
last := time.Now()
auth, err := r.keychain.Resolve(target)
auth, err := Resolve(ctx, r.keychain, target)
if err != nil || auth == Anonymous {
return auth, err
}
Expand Down Expand Up @@ -236,17 +265,21 @@ type refreshing struct {
}

func (r *refreshing) Authorization() (*AuthConfig, error) {
return r.AuthorizationContext(context.Background())
}

func (r *refreshing) AuthorizationContext(ctx context.Context) (*AuthConfig, error) {
r.Lock()
defer r.Unlock()
if r.cached == nil || r.expired() {
r.last = r.now()
auth, err := r.keychain.Resolve(r.target)
auth, err := Resolve(ctx, r.keychain, r.target)
if err != nil {
return nil, err
}
r.cached = auth
}
return r.cached.Authorization()
return Authorization(ctx, r.cached)
}

func (r *refreshing) now() time.Time {
Expand Down
8 changes: 7 additions & 1 deletion pkg/authn/multikeychain.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

package authn

import "context"

type multiKeychain struct {
keychains []Keychain
}
Expand All @@ -28,8 +30,12 @@ func NewMultiKeychain(kcs ...Keychain) Keychain {

// Resolve implements Keychain.
func (mk *multiKeychain) Resolve(target Resource) (Authenticator, error) {
return mk.ResolveContext(context.Background(), target)
}

func (mk *multiKeychain) ResolveContext(ctx context.Context, target Resource) (Authenticator, error) {
for _, kc := range mk.keychains {
auth, err := kc.Resolve(target)
auth, err := Resolve(ctx, kc, target)
if err != nil {
return nil, err
}
Expand Down
18 changes: 10 additions & 8 deletions pkg/v1/google/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,23 @@ import (
const cloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"

// GetGcloudCmd is exposed so we can test this.
var GetGcloudCmd = func() *exec.Cmd {
var GetGcloudCmd = func(ctx context.Context) *exec.Cmd {
// This is odd, but basically what docker-credential-gcr does.
//
// config-helper is undocumented, but it's purportedly the only supported way
// of accessing tokens (`gcloud auth print-access-token` is discouraged).
//
// --force-auth-refresh means we are getting a token that is valid for about
// an hour (we reuse it until it's expired).
return exec.Command("gcloud", "config", "config-helper", "--force-auth-refresh", "--format=json(credential)")
return exec.CommandContext(ctx, "gcloud", "config", "config-helper", "--force-auth-refresh", "--format=json(credential)")
}

// NewEnvAuthenticator returns an authn.Authenticator that generates access
// tokens from the environment we're running in.
//
// See: https://godoc.org/golang.org/x/oauth2/google#FindDefaultCredentials
func NewEnvAuthenticator() (authn.Authenticator, error) {
ts, err := googauth.DefaultTokenSource(context.Background(), cloudPlatformScope)
func NewEnvAuthenticator(ctx context.Context) (authn.Authenticator, error) {
ts, err := googauth.DefaultTokenSource(ctx, cloudPlatformScope)
if err != nil {
return nil, err
}
Expand All @@ -62,14 +62,14 @@ func NewEnvAuthenticator() (authn.Authenticator, error) {

// NewGcloudAuthenticator returns an oauth2.TokenSource that generates access
// tokens by shelling out to the gcloud sdk.
func NewGcloudAuthenticator() (authn.Authenticator, error) {
func NewGcloudAuthenticator(ctx context.Context) (authn.Authenticator, error) {
if _, err := exec.LookPath("gcloud"); err != nil {
// gcloud is not available, fall back to anonymous
logs.Warn.Println("gcloud binary not found")
return authn.Anonymous, nil
}

ts := gcloudSource{GetGcloudCmd}
ts := gcloudSource{ctx, GetGcloudCmd}

// Attempt to fetch a token to ensure gcloud is installed and we can run it.
token, err := ts.Token()
Expand Down Expand Up @@ -143,13 +143,15 @@ type gcloudOutput struct {
}

type gcloudSource struct {
ctx context.Context

// This is passed in so that we mock out gcloud and test Token.
exec func() *exec.Cmd
exec func(ctx context.Context) *exec.Cmd
}

// Token implements oauath2.TokenSource.
func (gs gcloudSource) Token() (*oauth2.Token, error) {
cmd := gs.exec()
cmd := gs.exec(gs.ctx)
var out bytes.Buffer
cmd.Stdout = &out

Expand Down
15 changes: 9 additions & 6 deletions pkg/v1/google/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package google

import (
"bytes"
"context"
"fmt"
"os"
"os/exec"
Expand Down Expand Up @@ -84,15 +85,16 @@ func TestMain(m *testing.M) {
}
}

func newGcloudCmdMock(env string) func() *exec.Cmd {
return func() *exec.Cmd {
cmd := exec.Command(os.Args[0])
func newGcloudCmdMock(env string) func(context.Context) *exec.Cmd {
return func(ctx context.Context) *exec.Cmd {
cmd := exec.CommandContext(ctx, os.Args[0])
cmd.Env = []string{fmt.Sprintf("GO_TEST_MODE=%s", env)}
return cmd
}
}

func TestGcloudErrors(t *testing.T) {
ctx := context.Background()
cases := []struct {
env string

Expand All @@ -113,7 +115,7 @@ func TestGcloudErrors(t *testing.T) {
t.Run(tc.env, func(t *testing.T) {
GetGcloudCmd = newGcloudCmdMock(tc.env)

if _, err := NewGcloudAuthenticator(); err == nil {
if _, err := NewGcloudAuthenticator(ctx); err == nil {
t.Errorf("wanted error, got nil")
} else if got := err.Error(); !strings.HasPrefix(got, tc.wantPrefix) {
t.Errorf("wanted error prefix %q, got %q", tc.wantPrefix, got)
Expand All @@ -123,13 +125,14 @@ func TestGcloudErrors(t *testing.T) {
}

func TestGcloudSuccess(t *testing.T) {
ctx := context.Background()
// Stupid coverage to make sure it doesn't panic.
var b bytes.Buffer
logs.Debug.SetOutput(&b)

GetGcloudCmd = newGcloudCmdMock("success")

auth, err := NewGcloudAuthenticator()
auth, err := NewGcloudAuthenticator(ctx)
if err != nil {
t.Fatalf("NewGcloudAuthenticator got error %v", err)
}
Expand Down Expand Up @@ -263,7 +266,7 @@ func TestNewEnvAuthenticatorFailure(t *testing.T) {
}

// Expect error.
_, err := NewEnvAuthenticator()
_, err := NewEnvAuthenticator(context.Background())
if err == nil {
t.Errorf("expected err, got nil")
}
Expand Down
14 changes: 10 additions & 4 deletions pkg/v1/google/keychain.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package google

import (
"context"
"strings"
"sync"

Expand Down Expand Up @@ -52,26 +53,31 @@ type googleKeychain struct {
// In general, we don't worry about that here because we expect to use the same
// gcloud configuration in the scope of this one process.
func (gk *googleKeychain) Resolve(target authn.Resource) (authn.Authenticator, error) {
return gk.ResolveContext(context.Background(), target)
}

// ResolveContext implements authn.ContextKeychain.
func (gk *googleKeychain) ResolveContext(ctx context.Context, target authn.Resource) (authn.Authenticator, error) {
// Only authenticate GCR and AR so it works with authn.NewMultiKeychain to fallback.
if !isGoogle(target.RegistryStr()) {
return authn.Anonymous, nil
}

gk.once.Do(func() {
gk.auth = resolve()
gk.auth = resolve(ctx)
})

return gk.auth, nil
}

func resolve() authn.Authenticator {
auth, envErr := NewEnvAuthenticator()
func resolve(ctx context.Context) authn.Authenticator {
auth, envErr := NewEnvAuthenticator(ctx)
if envErr == nil && auth != authn.Anonymous {
logs.Debug.Println("google.Keychain: using Application Default Credentials")
return auth
}

auth, gErr := NewGcloudAuthenticator()
auth, gErr := NewGcloudAuthenticator(ctx)
if gErr == nil && auth != authn.Anonymous {
logs.Debug.Println("google.Keychain: using gcloud fallback")
return auth
Expand Down
2 changes: 1 addition & 1 deletion pkg/v1/remote/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type fetcher struct {
func makeFetcher(ctx context.Context, target resource, o *options) (*fetcher, error) {
auth := o.auth
if o.keychain != nil {
kauth, err := o.keychain.Resolve(target)
kauth, err := authn.Resolve(ctx, o.keychain, target)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/v1/remote/transport/basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ var _ http.RoundTripper = (*basicTransport)(nil)
// RoundTrip implements http.RoundTripper
func (bt *basicTransport) RoundTrip(in *http.Request) (*http.Response, error) {
if bt.auth != authn.Anonymous {
auth, err := bt.auth.Authorization()
auth, err := authn.Authorization(in.Context(), bt.auth)
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/v1/remote/transport/bearer.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func Exchange(ctx context.Context, reg name.Registry, auth authn.Authenticator,
if err != nil {
return nil, err
}
authcfg, err := auth.Authorization()
authcfg, err := authn.Authorization(ctx, auth)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -190,7 +190,7 @@ func (bt *bearerTransport) RoundTrip(in *http.Request) (*http.Response, error) {
// The basic token exchange is attempted first, falling back to the oauth flow.
// If the IdentityToken is set, this indicates that we should start with the oauth flow.
func (bt *bearerTransport) refresh(ctx context.Context) error {
auth, err := bt.basic.Authorization()
auth, err := authn.Authorization(ctx, bt.basic)
if err != nil {
return err
}
Expand Down Expand Up @@ -295,7 +295,7 @@ func canonicalAddress(host, scheme string) (address string) {

// https://docs.docker.com/registry/spec/auth/oauth/
func (bt *bearerTransport) refreshOauth(ctx context.Context) ([]byte, error) {
auth, err := bt.basic.Authorization()
auth, err := authn.Authorization(ctx, bt.basic)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 39d1148

Please sign in to comment.