Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cache/remotecache/registry/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func ResolveCacheExporterFunc(sm *session.Manager, hosts docker.RegistryHosts) r
insecure = b
}

scope, hosts := registryConfig(hosts, ref, "push", insecure)
scope, hosts := registryConfig(hosts, ref, resolver.ScopeType{Push: true}, insecure)
remote := resolver.DefaultPool.GetResolver(hosts, refString, scope, sm, g)
pusher, err := push.Pusher(ctx, remote, refString)
if err != nil {
Expand All @@ -116,7 +116,7 @@ func ResolveCacheImporterFunc(sm *session.Manager, cs content.Store, hosts docke
insecure = b
}

scope, hosts := registryConfig(hosts, ref, "pull", insecure)
scope, hosts := registryConfig(hosts, ref, resolver.ScopeType{}, insecure)
remote := resolver.DefaultPool.GetResolver(hosts, refString, scope, sm, g)
xref, desc, err := remote.Resolve(ctx, refString)
if err != nil {
Expand Down Expand Up @@ -173,7 +173,7 @@ func (dsl *withDistributionSourceLabel) SnapshotLabels(descs []ocispecs.Descript
return labels
}

func registryConfig(hosts docker.RegistryHosts, ref reference.Named, scope string, insecure bool) (string, docker.RegistryHosts) {
func registryConfig(hosts docker.RegistryHosts, ref reference.Named, scope resolver.ScopeType, insecure bool) (resolver.ScopeType, docker.RegistryHosts) {
if insecure {
insecureTrue := true
httpTrue := true
Expand All @@ -183,7 +183,7 @@ func registryConfig(hosts docker.RegistryHosts, ref reference.Named, scope strin
PlainHTTP: &httpTrue,
},
})
scope += ":insecure"
scope.Insecure = true
}
return scope, hosts
}
2 changes: 1 addition & 1 deletion cmd/buildkitd/main_oci_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ func sourceWithSession(hosts docker.RegistryHosts, sm *session.Manager) sgzsourc
// Get source information based on labels and RegistryHosts containing
// session-based authorizer.
parse := sgzsource.FromDefaultLabels(func(ref reference.Spec) ([]docker.RegistryHost, error) {
return resolver.DefaultPool.GetResolver(hosts, named.String(), "pull", sm, session.NewGroup(sids...)).
return resolver.DefaultPool.GetResolver(hosts, named.String(), resolver.ScopeType{}, sm, session.NewGroup(sids...)).
HostsFunc(ref.Hostname())
})
if s, err := parse(map[string]string{
Expand Down
4 changes: 2 additions & 2 deletions source/containerimage/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (p *puller) CacheKey(ctx context.Context, jobCtx solver.JobContext, index i
var getResolver pull.SessionResolver
switch p.ResolverType {
case ResolverTypeRegistry:
resolver := resolver.DefaultPool.GetResolver(p.RegistryHosts, p.Ref, "pull", p.SessionManager, g).WithImageStore(p.ImageStore, p.Mode)
resolver := resolver.DefaultPool.GetResolver(p.RegistryHosts, p.Ref, resolver.ScopeType{}, p.SessionManager, g).WithImageStore(p.ImageStore, p.Mode)
p.Resolver = resolver
getResolver = func(g session.Group) remotes.Resolver { return resolver.WithSession(g) }
case ResolverTypeOCILayout:
Expand Down Expand Up @@ -218,7 +218,7 @@ func (p *puller) Snapshot(ctx context.Context, jobCtx solver.JobContext) (ir cac
var getResolver pull.SessionResolver
switch p.ResolverType {
case ResolverTypeRegistry:
resolver := resolver.DefaultPool.GetResolver(p.RegistryHosts, p.Ref, "pull", p.SessionManager, g).WithImageStore(p.ImageStore, p.Mode)
resolver := resolver.DefaultPool.GetResolver(p.RegistryHosts, p.Ref, resolver.ScopeType{}, p.SessionManager, g).WithImageStore(p.ImageStore, p.Mode)
p.Resolver = resolver
getResolver = func(g session.Group) remotes.Resolver { return resolver.WithSession(g) }
case ResolverTypeOCILayout:
Expand Down
2 changes: 1 addition & 1 deletion source/containerimage/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func (is *Source) ResolveImageMetadata(ctx context.Context, id *ImageIdentifier,
if err != nil {
return nil, err
}
rslvr := resolver.DefaultPool.GetResolver(is.RegistryHosts, ref, "pull", sm, g).WithImageStore(is.ImageStore, rm)
rslvr := resolver.DefaultPool.GetResolver(is.RegistryHosts, ref, resolver.ScopeType{}, sm, g).WithImageStore(is.ImageStore, rm)
key += rm.String()

ret := &sourceresolver.ResolveImageResponse{}
Expand Down
4 changes: 2 additions & 2 deletions util/push/push.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func Push(ctx context.Context, sm *session.Manager, sid string, provider content
ref = r.String()
}

scope := "push"
scope := resolver.ScopeType{Push: true}
if insecure {
insecureTrue := true
httpTrue := true
Expand All @@ -79,7 +79,7 @@ func Push(ctx context.Context, sm *session.Manager, sid string, provider content
PlainHTTP: &httpTrue,
},
})
scope += ":insecure"
scope.Insecure = true
}

resolver := resolver.DefaultPool.GetResolver(hosts, ref, scope, sm, session.NewGroup(sid))
Expand Down
78 changes: 40 additions & 38 deletions util/resolver/authorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const defaultExpiration = 60
type authHandlerNS struct {
counter int64 // needs to be 64bit aligned for 32bit systems

handlers map[string]*authHandler
fetchers map[string]*authFetcher
muHandlers sync.Mutex
hosts map[string][]docker.RegistryHost
muHosts sync.Mutex
Expand All @@ -40,21 +40,21 @@ type authHandlerNS struct {

func newAuthHandlerNS(sm *session.Manager) *authHandlerNS {
return &authHandlerNS{
handlers: map[string]*authHandler{},
fetchers: map[string]*authFetcher{},
hosts: map[string][]docker.RegistryHost{},
sm: sm,
}
}

func (a *authHandlerNS) get(ctx context.Context, host string, sm *session.Manager, g session.Group) *authHandler {
func (a *authHandlerNS) get(ctx context.Context, host string, sm *session.Manager, g session.Group) *authFetcher {
if g != nil {
if iter := g.SessionIterator(); iter != nil {
for {
id := iter.NextSession()
if id == "" {
break
}
h, ok := a.handlers[host+"/"+id]
h, ok := a.fetchers[host+"/"+id]
if ok {
h.lastUsed = time.Now()
return h
Expand All @@ -63,8 +63,8 @@ func (a *authHandlerNS) get(ctx context.Context, host string, sm *session.Manage
}
}

// link another handler
for k, h := range a.handlers {
// link existing fetcher
for k, h := range a.fetchers {
parts := strings.SplitN(k, "/", 2)
if len(parts) != 2 {
continue
Expand All @@ -73,15 +73,15 @@ func (a *authHandlerNS) get(ctx context.Context, host string, sm *session.Manage
if h.authority != nil {
sessionID, ok, err := sessionauth.VerifyTokenAuthority(ctx, host, h.authority, sm, g)
if err == nil && ok {
a.handlers[host+"/"+sessionID] = h
a.fetchers[host+"/"+sessionID] = h
h.lastUsed = time.Now()
return h
}
} else {
sessionID, username, password, err := sessionauth.CredentialsFunc(sm, g)(host)
if err == nil {
if username == h.common.Username && password == h.common.Secret {
a.handlers[host+"/"+sessionID] = h
a.fetchers[host+"/"+sessionID] = h
h.lastUsed = time.Now()
return h
}
Expand All @@ -93,40 +93,40 @@ func (a *authHandlerNS) get(ctx context.Context, host string, sm *session.Manage
return nil
}

func (a *authHandlerNS) set(host, session string, h *authHandler) {
a.handlers[host+"/"+session] = h
func (a *authHandlerNS) set(host, session string, f *authFetcher) {
a.fetchers[host+"/"+session] = f
}

func (a *authHandlerNS) delete(h *authHandler) {
maps.DeleteFunc(a.handlers, func(_ string, v *authHandler) bool {
return v == h
func (a *authHandlerNS) delete(f *authFetcher) {
maps.DeleteFunc(a.fetchers, func(_ string, v *authFetcher) bool {
return v == f
})
}

type dockerAuthorizer struct {
client *http.Client

sm *session.Manager
session session.Group
handlers *authHandlerNS
sm *session.Manager
session session.Group
handlerNS *authHandlerNS
}

func newDockerAuthorizer(client *http.Client, handlers *authHandlerNS, sm *session.Manager, group session.Group) *dockerAuthorizer {
func newDockerAuthorizer(client *http.Client, handlerNS *authHandlerNS, sm *session.Manager, group session.Group) *dockerAuthorizer {
return &dockerAuthorizer{
client: client,
handlers: handlers,
sm: sm,
session: group,
client: client,
handlerNS: handlerNS,
sm: sm,
session: group,
}
}

// Authorize handles auth request.
func (a *dockerAuthorizer) Authorize(ctx context.Context, req *http.Request) error {
a.handlers.muHandlers.Lock()
defer a.handlers.muHandlers.Unlock()
a.handlerNS.muHandlers.Lock()
defer a.handlerNS.muHandlers.Unlock()

// skip if there is no auth handler
ah := a.handlers.get(ctx, req.URL.Host, a.sm, a.session)
ah := a.handlerNS.get(ctx, req.URL.Host, a.sm, a.session)
if ah == nil {
return nil
}
Expand All @@ -145,20 +145,22 @@ func (a *dockerAuthorizer) getCredentials(host string) (sessionID, username, sec
}

func (a *dockerAuthorizer) AddResponses(ctx context.Context, responses []*http.Response) error {
a.handlers.muHandlers.Lock()
defer a.handlers.muHandlers.Unlock()
handlerNS := a.handlerNS

handlerNS.muHandlers.Lock()
defer handlerNS.muHandlers.Unlock()

last := responses[len(responses)-1]
host := last.Request.URL.Host

handler := a.handlers.get(ctx, host, a.sm, a.session)
handler := handlerNS.get(ctx, host, a.sm, a.session)

for _, c := range auth.ParseAuthHeader(last.Header) {
switch c.Scheme {
case auth.BearerAuth:
var oldScopes []string
if err := invalidAuthorization(c, responses); err != nil {
a.handlers.delete(handler)
handlerNS.delete(handler)

if handler != nil {
oldScopes = handler.common.Scopes
Expand Down Expand Up @@ -199,7 +201,7 @@ func (a *dockerAuthorizer) AddResponses(ctx context.Context, responses []*http.R
}
common.Scopes = parseScopes(append(common.Scopes, oldScopes...)).normalize()

a.handlers.set(host, sessionID, newAuthHandler(host, a.client, c.Scheme, pubKey, common))
handlerNS.set(host, sessionID, newAuthFetcher(host, a.client, c.Scheme, pubKey, common))

return nil
case auth.BasicAuth:
Expand All @@ -209,7 +211,7 @@ func (a *dockerAuthorizer) AddResponses(ctx context.Context, responses []*http.R
}

if username != "" && secret != "" {
a.handlers.set(host, sessionID, newAuthHandler(host, a.client, c.Scheme, nil, auth.TokenOptions{
handlerNS.set(host, sessionID, newAuthFetcher(host, a.client, c.Scheme, nil, auth.TokenOptions{
Username: username,
Secret: secret,
}))
Expand All @@ -226,8 +228,8 @@ type authResult struct {
expires time.Time
}

// authHandler is used to handle auth request per registry server.
type authHandler struct {
// authFetcher is used to process auth request return the token.
type authFetcher struct {
g flightcontrol.Group[*authResult]

client *http.Client
Expand All @@ -250,8 +252,8 @@ type authHandler struct {
authority *[32]byte
}

func newAuthHandler(host string, client *http.Client, scheme auth.AuthenticationScheme, authority *[32]byte, opts auth.TokenOptions) *authHandler {
return &authHandler{
func newAuthFetcher(host string, client *http.Client, scheme auth.AuthenticationScheme, authority *[32]byte, opts auth.TokenOptions) *authFetcher {
return &authFetcher{
host: host,
client: client,
scheme: scheme,
Expand All @@ -262,7 +264,7 @@ func newAuthHandler(host string, client *http.Client, scheme auth.Authentication
}
}

func (ah *authHandler) authorize(ctx context.Context, sm *session.Manager, g session.Group) (string, error) {
func (ah *authFetcher) authorize(ctx context.Context, sm *session.Manager, g session.Group) (string, error) {
switch ah.scheme {
case auth.BasicAuth:
return ah.doBasicAuth()
Expand All @@ -273,7 +275,7 @@ func (ah *authHandler) authorize(ctx context.Context, sm *session.Manager, g ses
}
}

func (ah *authHandler) doBasicAuth() (string, error) {
func (ah *authFetcher) doBasicAuth() (string, error) {
username, secret := ah.common.Username, ah.common.Secret

if username == "" || secret == "" {
Expand All @@ -284,7 +286,7 @@ func (ah *authHandler) doBasicAuth() (string, error) {
return fmt.Sprintf("Basic %s", authHeader), nil
}

func (ah *authHandler) doBearerAuth(ctx context.Context, sm *session.Manager, g session.Group) (token string, err error) {
func (ah *authFetcher) doBearerAuth(ctx context.Context, sm *session.Manager, g session.Group) (token string, err error) {
// copy common tokenOptions
to := ah.common

Expand Down Expand Up @@ -317,7 +319,7 @@ func (ah *authHandler) doBearerAuth(ctx context.Context, sm *session.Manager, g
return res.token, nil
}

func (ah *authHandler) fetchToken(ctx context.Context, sm *session.Manager, g session.Group, to auth.TokenOptions) (r *authResult, err error) {
func (ah *authFetcher) fetchToken(ctx context.Context, sm *session.Manager, g session.Group, to auth.TokenOptions) (r *authResult, err error) {
var issuedAt time.Time
var expires int
var token string
Expand Down
32 changes: 24 additions & 8 deletions util/resolver/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,21 @@ func (p *Pool) gc() {

for k, ns := range p.m {
ns.muHandlers.Lock()
for key, h := range ns.handlers {
for key, h := range ns.fetchers {
if time.Since(h.lastUsed) < 10*time.Minute {
continue
}
parts := strings.SplitN(key, "/", 2)
if len(parts) != 2 {
delete(ns.handlers, key)
delete(ns.fetchers, key)
continue
}
c, err := ns.sm.Get(context.TODO(), parts[1], true)
if c == nil || err != nil {
delete(ns.handlers, key)
delete(ns.fetchers, key)
}
}
if len(ns.handlers) == 0 {
if len(ns.fetchers) == 0 {
delete(p.m, k)
}
ns.muHandlers.Unlock()
Expand All @@ -78,26 +78,26 @@ func (p *Pool) Clear() {
}

// GetResolver gets a resolver for a specified scope from the pool
func (p *Pool) GetResolver(hosts docker.RegistryHosts, ref, scope string, sm *session.Manager, g session.Group) *Resolver {
func (p *Pool) GetResolver(hosts docker.RegistryHosts, ref string, scope ScopeType, sm *session.Manager, g session.Group) *Resolver {
name := ref
named, err := distreference.ParseNormalizedNamed(ref)
if err == nil {
name = named.Name()
}

var key string
if strings.Contains(scope, "push") {
if scope.Push {
// When scope includes "push", index the authHandlerNS cache by session
// id(s) as well to prevent tokens with potential write access to third
// party registries from leaking between client sessions. The key will end
// up looking something like:
// 'wujskoey891qc5cv1edd3yj3p::repository:foo/bar::pull,push'
key = fmt.Sprintf("%s::%s::%s", strings.Join(session.AllSessionIDs(g), ":"), name, scope)
key = fmt.Sprintf("%s::%s::%s", strings.Join(session.AllSessionIDs(g), ":"), name, scope.String())
} else {
// The authHandlerNS is not isolated for pull-only scopes since LLB
// verticies from pulls all end up in the cache anyway and all
// requests/clients have access to the same cache
key = fmt.Sprintf("%s::%s", name, scope)
key = fmt.Sprintf("%s::%s", name, scope.String())
}

p.mu.Lock()
Expand Down Expand Up @@ -143,6 +143,22 @@ func newResolver(hosts docker.RegistryHosts, handler *authHandlerNS, sm *session
return r
}

type ScopeType struct {
Push bool
Insecure bool
}

func (s ScopeType) String() string {
out := "pull"
if s.Push {
out = "push"
}
if s.Insecure {
out += ":insecure"
}
return out
}

// Resolver is a wrapper around remotes.Resolver
type Resolver struct {
remotes.Resolver
Expand Down