Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use default STS endpoint #2044

Merged
merged 2 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,7 @@ func (c *Client) CredContext() *credentials.CredContext {
httpClient = http.DefaultClient
}
return &credentials.CredContext{
Client: httpClient,
Client: httpClient,
Endpoint: c.endpointURL.String(),
}
}
33 changes: 21 additions & 12 deletions pkg/credentials/assume_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,6 @@ type STSAssumeRoleOptions struct {
// NewSTSAssumeRole returns a pointer to a new
// Credentials object wrapping the STSAssumeRole.
func NewSTSAssumeRole(stsEndpoint string, opts STSAssumeRoleOptions) (*Credentials, error) {
if stsEndpoint == "" {
return nil, errors.New("STS endpoint cannot be empty")
}
if opts.AccessKey == "" || opts.SecretKey == "" {
return nil, errors.New("AssumeRole credentials access/secretkey is mandatory")
}
Expand Down Expand Up @@ -220,12 +217,30 @@ func getAssumeRoleCredentials(clnt *http.Client, endpoint string, opts STSAssume
return a, nil
}

func (m *STSAssumeRole) retrieve(cc *CredContext) (Value, error) {
// RetrieveWithCredContext retrieves credentials from the MinIO service.
// Error will be returned if the request fails, optional cred context.
func (m *STSAssumeRole) RetrieveWithCredContext(cc *CredContext) (Value, error) {
if cc == nil {
cc = defaultCredContext
}

client := m.Client
if client == nil {
client = cc.Client
}
a, err := getAssumeRoleCredentials(client, m.STSEndpoint, m.Options)
if client == nil {
client = defaultCredContext.Client
}

stsEndpoint := m.STSEndpoint
if stsEndpoint == "" {
stsEndpoint = cc.Endpoint
}
if stsEndpoint == "" {
return Value{}, errors.New("STS endpoint unknown")
}

a, err := getAssumeRoleCredentials(client, stsEndpoint, m.Options)
if err != nil {
return Value{}, err
}
Expand All @@ -242,14 +257,8 @@ func (m *STSAssumeRole) retrieve(cc *CredContext) (Value, error) {
}, nil
}

// RetrieveWithCredContext retrieves credentials from the MinIO service.
// Error will be returned if the request fails, optional cred context.
func (m *STSAssumeRole) RetrieveWithCredContext(cc *CredContext) (Value, error) {
return m.retrieve(cc)
}

// Retrieve retrieves credentials from the MinIO service.
// Error will be returned if the request fails.
func (m *STSAssumeRole) Retrieve() (Value, error) {
return m.retrieve(defaultCredContext)
return m.RetrieveWithCredContext(nil)
}
2 changes: 1 addition & 1 deletion pkg/credentials/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (c *Chain) RetrieveWithCredContext(cc *CredContext) (Value, error) {
// to IsExpired() will return the expired state of the cached provider.
func (c *Chain) Retrieve() (Value, error) {
for _, p := range c.Providers {
creds, _ := p.RetrieveWithCredContext(defaultCredContext)
creds, _ := p.Retrieve()
// Always prioritize non-anonymous providers, if any.
if creds.AccessKeyID == "" && creds.SecretAccessKey == "" {
continue
Expand Down
13 changes: 12 additions & 1 deletion pkg/credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ type Provider interface {

// Retrieve returns nil if it successfully retrieved the value.
// Error is returned if the value were not obtainable, or empty.
//
// Deprecated: Retrieve() exists for historical compatibility and should not
// be used. To get new credentials use the RetrieveWithCredContext function
// to ensure the proper context (i.e. HTTP client) will be used.
Retrieve() (Value, error)

// IsExpired returns if the credentials are no longer valid, and need
Expand All @@ -77,6 +81,10 @@ type CredContext struct {
// Client specifies the HTTP client that should be used if an HTTP
// request is to be made to fetch the credentials.
Client *http.Client

// Endpoint specifies the MinIO endpoint that will be used if no
// explicit endpoint is provided.
Endpoint string
}

// A Expiry provides shared expiration logic to be used by credentials
Expand Down Expand Up @@ -169,7 +177,7 @@ func New(provider Provider) *Credentials {
// used. To get new credentials use the Credentials.GetWithContext function
// to ensure the proper context (i.e. HTTP client) will be used.
func (c *Credentials) Get() (Value, error) {
return c.GetWithContext(defaultCredContext)
return c.GetWithContext(nil)
}

// GetWithContext returns the credentials value, or error if the
Expand All @@ -185,6 +193,9 @@ func (c *Credentials) GetWithContext(cc *CredContext) (Value, error) {
if c == nil {
return Value{}, nil
}
if cc == nil {
cc = defaultCredContext
}

c.Lock()
defer c.Unlock()
Expand Down
2 changes: 1 addition & 1 deletion pkg/credentials/env_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (e *EnvAWS) Retrieve() (Value, error) {
return e.retrieve()
}

// RetrieveWithContext is like Retrieve (no-op input of Cred Context)
// RetrieveWithCredContext is like Retrieve (no-op input of Cred Context)
func (e *EnvAWS) RetrieveWithCredContext(_ *CredContext) (Value, error) {
return e.retrieve()
}
Expand Down
21 changes: 14 additions & 7 deletions pkg/credentials/iam_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,12 @@ func NewIAM(endpoint string) *Credentials {
})
}

func (m *IAM) retrieve(cc *CredContext) (Value, error) {
// RetrieveWithCredContext is like Retrieve with Cred Context
func (m *IAM) RetrieveWithCredContext(cc *CredContext) (Value, error) {
if cc == nil {
cc = defaultCredContext
}

token := os.Getenv("AWS_CONTAINER_AUTHORIZATION_TOKEN")
if token == "" {
token = m.Container.AuthorizationToken
Expand Down Expand Up @@ -143,8 +148,15 @@ func (m *IAM) retrieve(cc *CredContext) (Value, error) {
if client == nil {
client = cc.Client
}
if client == nil {
client = defaultCredContext.Client
}

endpoint := m.Endpoint
if endpoint == "" {
endpoint = cc.Endpoint
}

switch {
case identityFile != "":
if len(endpoint) == 0 {
Expand Down Expand Up @@ -228,12 +240,7 @@ func (m *IAM) retrieve(cc *CredContext) (Value, error) {
// Error will be returned if the request fails, or unable to extract
// the desired
func (m *IAM) Retrieve() (Value, error) {
return m.retrieve(defaultCredContext)
}

// RetrieveWithCredContext is like Retrieve with Cred Context
func (m *IAM) RetrieveWithCredContext(cc *CredContext) (Value, error) {
return m.retrieve(cc)
return m.RetrieveWithCredContext(nil)
}

// A ec2RoleCredRespBody provides the shape for unmarshaling credential
Expand Down
31 changes: 20 additions & 11 deletions pkg/credentials/sts_client_grants.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,6 @@ type STSClientGrants struct {
// NewSTSClientGrants returns a pointer to a new
// Credentials object wrapping the STSClientGrants.
func NewSTSClientGrants(stsEndpoint string, getClientGrantsTokenExpiry func() (*ClientGrantsToken, error)) (*Credentials, error) {
if stsEndpoint == "" {
return nil, errors.New("STS endpoint cannot be empty")
}
if getClientGrantsTokenExpiry == nil {
return nil, errors.New("Client grants access token and expiry retrieval function should be defined")
}
Expand Down Expand Up @@ -160,12 +157,29 @@ func getClientGrantsCredentials(clnt *http.Client, endpoint string,
return a, nil
}

func (m *STSClientGrants) retrieve(cc *CredContext) (Value, error) {
// RetrieveWithCredContext is like Retrieve() with cred context
func (m *STSClientGrants) RetrieveWithCredContext(cc *CredContext) (Value, error) {
if cc == nil {
cc = defaultCredContext
}

client := m.Client
if client == nil {
client = cc.Client
}
a, err := getClientGrantsCredentials(client, m.STSEndpoint, m.GetClientGrantsTokenExpiry)
if client == nil {
client = defaultCredContext.Client
}

stsEndpoint := m.STSEndpoint
if stsEndpoint == "" {
stsEndpoint = cc.Endpoint
}
if stsEndpoint == "" {
return Value{}, errors.New("STS endpoint unknown")
}

a, err := getClientGrantsCredentials(client, stsEndpoint, m.GetClientGrantsTokenExpiry)
if err != nil {
return Value{}, err
}
Expand All @@ -182,13 +196,8 @@ func (m *STSClientGrants) retrieve(cc *CredContext) (Value, error) {
}, nil
}

// RetrieveWithCredContext is like Retrieve() with cred context
func (m *STSClientGrants) RetrieveWithCredContext(cc *CredContext) (Value, error) {
return m.retrieve(cc)
}

// Retrieve retrieves credentials from the MinIO service.
// Error will be returned if the request fails.
func (m *STSClientGrants) Retrieve() (Value, error) {
return m.retrieve(defaultCredContext)
return m.RetrieveWithCredContext(nil)
}
27 changes: 19 additions & 8 deletions pkg/credentials/sts_custom_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,21 @@ type CustomTokenIdentity struct {
RequestedExpiry time.Duration
}

func (c *CustomTokenIdentity) retrieve(cc *CredContext) (value Value, err error) {
u, err := url.Parse(c.STSEndpoint)
// RetrieveWithCredContext with Retrieve optionally cred context
func (c *CustomTokenIdentity) RetrieveWithCredContext(cc *CredContext) (value Value, err error) {
if cc == nil {
cc = defaultCredContext
}

stsEndpoint := c.STSEndpoint
if stsEndpoint == "" {
stsEndpoint = cc.Endpoint
}
if stsEndpoint == "" {
return Value{}, errors.New("STS endpoint unknown")
}

u, err := url.Parse(stsEndpoint)
if err != nil {
return value, err
}
Expand All @@ -97,6 +110,9 @@ func (c *CustomTokenIdentity) retrieve(cc *CredContext) (value Value, err error)
if client == nil {
client = cc.Client
}
if client == nil {
client = defaultCredContext.Client
}

resp, err := client.Do(req)
if err != nil {
Expand Down Expand Up @@ -126,12 +142,7 @@ func (c *CustomTokenIdentity) retrieve(cc *CredContext) (value Value, err error)

// Retrieve - to satisfy Provider interface; fetches credentials from MinIO.
func (c *CustomTokenIdentity) Retrieve() (value Value, err error) {
return c.retrieve(defaultCredContext)
}

// RetrieveWithCredContext with Retrieve optionally cred context
func (c *CustomTokenIdentity) RetrieveWithCredContext(cc *CredContext) (value Value, err error) {
return c.retrieve(cc)
return c.RetrieveWithCredContext(nil)
}

// NewCustomTokenCredentials - returns credentials using the
Expand Down
30 changes: 21 additions & 9 deletions pkg/credentials/sts_ldap_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package credentials
import (
"bytes"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -120,8 +121,22 @@ func NewLDAPIdentityWithSessionPolicy(stsEndpoint, ldapUsername, ldapPassword, p
}), nil
}

func (k *LDAPIdentity) retrieve(cc *CredContext) (value Value, err error) {
u, err := url.Parse(k.STSEndpoint)
// RetrieveWithCredContext gets the credential by calling the MinIO STS API for
// LDAP on the configured stsEndpoint.
func (k *LDAPIdentity) RetrieveWithCredContext(cc *CredContext) (value Value, err error) {
if cc == nil {
cc = defaultCredContext
}

stsEndpoint := k.STSEndpoint
if stsEndpoint == "" {
stsEndpoint = cc.Endpoint
}
if stsEndpoint == "" {
return Value{}, errors.New("STS endpoint unknown")
}

u, err := url.Parse(stsEndpoint)
if err != nil {
return value, err
}
Expand Down Expand Up @@ -149,6 +164,9 @@ func (k *LDAPIdentity) retrieve(cc *CredContext) (value Value, err error) {
if client == nil {
client = cc.Client
}
if client == nil {
client = defaultCredContext.Client
}

resp, err := client.Do(req)
if err != nil {
Expand Down Expand Up @@ -194,11 +212,5 @@ func (k *LDAPIdentity) retrieve(cc *CredContext) (value Value, err error) {
// Retrieve gets the credential by calling the MinIO STS API for
// LDAP on the configured stsEndpoint.
func (k *LDAPIdentity) Retrieve() (value Value, err error) {
return k.retrieve(defaultCredContext)
}

// RetrieveWithCredContext gets the credential by calling the MinIO STS API for
// LDAP on the configured stsEndpoint.
func (k *LDAPIdentity) RetrieveWithCredContext(cc *CredContext) (value Value, err error) {
return k.retrieve(cc)
return k.RetrieveWithCredContext(defaultCredContext)
}
33 changes: 19 additions & 14 deletions pkg/credentials/sts_tls_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,6 @@ type STSCertificateIdentity struct {
// to the given STS endpoint with the given TLS certificate and retrieves and
// rotates S3 credentials.
func NewSTSCertificateIdentity(endpoint string, certificate tls.Certificate, options ...CertificateIdentityOption) (*Credentials, error) {
if endpoint == "" {
return nil, errors.New("STS endpoint cannot be empty")
}
if _, err := url.Parse(endpoint); err != nil {
return nil, err
}
identity := &STSCertificateIdentity{
STSEndpoint: endpoint,
Certificate: certificate,
Expand All @@ -102,8 +96,21 @@ func NewSTSCertificateIdentity(endpoint string, certificate tls.Certificate, opt
return New(identity), nil
}

func (i *STSCertificateIdentity) retrieve(cc *CredContext) (Value, error) {
endpointURL, err := url.Parse(i.STSEndpoint)
// RetrieveWithCredContext is Retrieve with cred context
func (i *STSCertificateIdentity) RetrieveWithCredContext(cc *CredContext) (Value, error) {
if cc == nil {
cc = defaultCredContext
}

stsEndpoint := i.STSEndpoint
if stsEndpoint == "" {
stsEndpoint = cc.Endpoint
}
if stsEndpoint == "" {
return Value{}, errors.New("STS endpoint unknown")
}

endpointURL, err := url.Parse(stsEndpoint)
if err != nil {
return Value{}, err
}
Expand All @@ -130,6 +137,9 @@ func (i *STSCertificateIdentity) retrieve(cc *CredContext) (Value, error) {
if client == nil {
client = cc.Client
}
if client == nil {
client = defaultCredContext.Client
}

tr, ok := client.Transport.(*http.Transport)
if !ok {
Expand Down Expand Up @@ -192,14 +202,9 @@ func (i *STSCertificateIdentity) retrieve(cc *CredContext) (Value, error) {
}, nil
}

// RetrieveWithCredContext is Retrieve with cred context
func (i *STSCertificateIdentity) RetrieveWithCredContext(cc *CredContext) (Value, error) {
return i.retrieve(cc)
}

// Retrieve fetches a new set of S3 credentials from the configured STS API endpoint.
func (i *STSCertificateIdentity) Retrieve() (Value, error) {
return i.retrieve(defaultCredContext)
return i.RetrieveWithCredContext(defaultCredContext)
}

// Expiration returns the expiration time of the current S3 credentials.
Expand Down
Loading
Loading