diff --git a/config/http_config.go b/config/http_config.go index e2178860..5e9d6507 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -20,6 +20,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/json" + "errors" "fmt" "net" "net/http" @@ -29,7 +30,6 @@ import ( "strings" "sync" "time" - "unsafe" conntrack "github.com/mwitkow/go-conntrack" "golang.org/x/net/http/httpproxy" @@ -131,8 +131,12 @@ func (tv *TLSVersion) String() string { type BasicAuth struct { Username string `yaml:"username" json:"username"` UsernameFile string `yaml:"username_file,omitempty" json:"username_file,omitempty"` + // UsernameRef is the name of the secret within the secret manager to use as the username. + UsernameRef string `yaml:"username_ref,omitempty" json:"username_ref,omitempty"` Password Secret `yaml:"password,omitempty" json:"password,omitempty"` PasswordFile string `yaml:"password_file,omitempty" json:"password_file,omitempty"` + // PasswordRef is the name of the secret within the secret manager to use as the password. + PasswordRef string `yaml:"password_ref,omitempty" json:"password_ref,omitempty"` } // SetDirectory joins any relative file paths with dir. @@ -149,6 +153,8 @@ type Authorization struct { Type string `yaml:"type,omitempty" json:"type,omitempty"` Credentials Secret `yaml:"credentials,omitempty" json:"credentials,omitempty"` CredentialsFile string `yaml:"credentials_file,omitempty" json:"credentials_file,omitempty"` + // CredentialsRef is the name of the secret within the secret manager to use as credentials. + CredentialsRef string `yaml:"credentials_ref,omitempty" json:"credentials_ref,omitempty"` } // SetDirectory joins any relative file paths with dir. @@ -225,14 +231,17 @@ func (u URL) MarshalJSON() ([]byte, error) { // OAuth2 is the oauth2 client configuration. type OAuth2 struct { - ClientID string `yaml:"client_id" json:"client_id"` - ClientSecret Secret `yaml:"client_secret" json:"client_secret"` - ClientSecretFile string `yaml:"client_secret_file" json:"client_secret_file"` - Scopes []string `yaml:"scopes,omitempty" json:"scopes,omitempty"` - TokenURL string `yaml:"token_url" json:"token_url"` - EndpointParams map[string]string `yaml:"endpoint_params,omitempty" json:"endpoint_params,omitempty"` - TLSConfig TLSConfig `yaml:"tls_config,omitempty"` - ProxyConfig `yaml:",inline"` + ClientID string `yaml:"client_id" json:"client_id"` + ClientSecret Secret `yaml:"client_secret" json:"client_secret"` + ClientSecretFile string `yaml:"client_secret_file" json:"client_secret_file"` + // ClientSecretRef is the name of the secret within the secret manager to use as the client + // secret. + ClientSecretRef string `yaml:"client_secret_ref" json:"client_secret_ref"` + Scopes []string `yaml:"scopes,omitempty" json:"scopes,omitempty"` + TokenURL string `yaml:"token_url" json:"token_url"` + EndpointParams map[string]string `yaml:"endpoint_params,omitempty" json:"endpoint_params,omitempty"` + TLSConfig TLSConfig `yaml:"tls_config,omitempty"` + ProxyConfig `yaml:",inline"` } // UnmarshalYAML implements the yaml.Unmarshaler interface @@ -330,6 +339,18 @@ func (c *HTTPClientConfig) SetDirectory(dir string) { c.BearerTokenFile = JoinDir(dir, c.BearerTokenFile) } +// nonZeroCount returns the amount of values that are non-zero. +func nonZeroCount[T comparable](values ...T) int { + count := 0 + var zero T + for _, value := range values { + if value != zero { + count += 1 + } + } + return count +} + // Validate validates the HTTPClientConfig to check only one of BearerToken, // BasicAuth and BearerTokenFile is configured. It also validates that ProxyURL // is set if ProxyConnectHeader is set. @@ -341,17 +362,17 @@ func (c *HTTPClientConfig) Validate() error { if (c.BasicAuth != nil || c.OAuth2 != nil) && (len(c.BearerToken) > 0 || len(c.BearerTokenFile) > 0) { return fmt.Errorf("at most one of basic_auth, oauth2, bearer_token & bearer_token_file must be configured") } - if c.BasicAuth != nil && (string(c.BasicAuth.Username) != "" && c.BasicAuth.UsernameFile != "") { - return fmt.Errorf("at most one of basic_auth username & username_file must be configured") + if c.BasicAuth != nil && nonZeroCount(string(c.BasicAuth.Username) != "", c.BasicAuth.UsernameFile != "", c.BasicAuth.UsernameRef != "") > 1 { + return fmt.Errorf("at most one of basic_auth username, username_file & username_ref must be configured") } - if c.BasicAuth != nil && (string(c.BasicAuth.Password) != "" && c.BasicAuth.PasswordFile != "") { - return fmt.Errorf("at most one of basic_auth password & password_file must be configured") + if c.BasicAuth != nil && nonZeroCount(string(c.BasicAuth.Password) != "", c.BasicAuth.PasswordFile != "", c.BasicAuth.PasswordRef != "") > 1 { + return fmt.Errorf("at most one of basic_auth password, password_file & password_ref must be configured") } if c.Authorization != nil { if len(c.BearerToken) > 0 || len(c.BearerTokenFile) > 0 { return fmt.Errorf("authorization is not compatible with bearer_token & bearer_token_file") } - if string(c.Authorization.Credentials) != "" && c.Authorization.CredentialsFile != "" { + if nonZeroCount(string(c.Authorization.Credentials) != "", c.Authorization.CredentialsFile != "", c.Authorization.CredentialsRef != "") > 1 { return fmt.Errorf("at most one of authorization credentials & credentials_file must be configured") } c.Authorization.Type = strings.TrimSpace(c.Authorization.Type) @@ -386,8 +407,8 @@ func (c *HTTPClientConfig) Validate() error { if len(c.OAuth2.TokenURL) == 0 { return fmt.Errorf("oauth2 token_url must be configured") } - if len(c.OAuth2.ClientSecret) > 0 && len(c.OAuth2.ClientSecretFile) > 0 { - return fmt.Errorf("at most one of oauth2 client_secret & client_secret_file must be configured") + if nonZeroCount(len(c.OAuth2.ClientSecret) > 0, len(c.OAuth2.ClientSecretFile) > 0, len(c.OAuth2.ClientSecretRef) > 0) > 1 { + return fmt.Errorf("at most one of oauth2 client_secret, client_secret_file & client_secret_ref must be configured") } } if err := c.ProxyConfig.Validate(); err != nil { @@ -438,50 +459,78 @@ type httpClientOptions struct { idleConnTimeout time.Duration userAgent string host string + secretManager SecretManager } // HTTPClientOption defines an option that can be applied to the HTTP client. -type HTTPClientOption func(options *httpClientOptions) +type HTTPClientOption interface { + applyToHTTPClientOptions(options *httpClientOptions) +} + +type httpClientOptionFunc func(options *httpClientOptions) + +func (f httpClientOptionFunc) applyToHTTPClientOptions(options *httpClientOptions) { + f(options) +} // WithDialContextFunc allows you to override func gets used for the actual dialing. The default is `net.Dialer.DialContext`. func WithDialContextFunc(fn DialContextFunc) HTTPClientOption { - return func(opts *httpClientOptions) { + return httpClientOptionFunc(func(opts *httpClientOptions) { opts.dialContextFunc = fn - } + }) } // WithKeepAlivesDisabled allows to disable HTTP keepalive. func WithKeepAlivesDisabled() HTTPClientOption { - return func(opts *httpClientOptions) { + return httpClientOptionFunc(func(opts *httpClientOptions) { opts.keepAlivesEnabled = false - } + }) } // WithHTTP2Disabled allows to disable HTTP2. func WithHTTP2Disabled() HTTPClientOption { - return func(opts *httpClientOptions) { + return httpClientOptionFunc(func(opts *httpClientOptions) { opts.http2Enabled = false - } + }) } // WithIdleConnTimeout allows setting the idle connection timeout. func WithIdleConnTimeout(timeout time.Duration) HTTPClientOption { - return func(opts *httpClientOptions) { + return httpClientOptionFunc(func(opts *httpClientOptions) { opts.idleConnTimeout = timeout - } + }) } // WithUserAgent allows setting the user agent. func WithUserAgent(ua string) HTTPClientOption { - return func(opts *httpClientOptions) { + return httpClientOptionFunc(func(opts *httpClientOptions) { opts.userAgent = ua - } + }) } // WithHost allows setting the host header. func WithHost(host string) HTTPClientOption { - return func(opts *httpClientOptions) { + return httpClientOptionFunc(func(opts *httpClientOptions) { opts.host = host + }) +} + +type secretManagerOption struct { + secretManager SecretManager +} + +func (s *secretManagerOption) applyToHTTPClientOptions(opts *httpClientOptions) { + opts.secretManager = s.secretManager +} + +func (s *secretManagerOption) applyToTLSConfigOptions(opts *tlsConfigOptions) { + opts.secretManager = s.secretManager +} + +// WithSecretManager allows setting the secret manager. +func WithSecretManager(manager SecretManager) *secretManagerOption { + return &secretManagerOption{ + secretManager: manager, } } @@ -511,9 +560,16 @@ func NewClientFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HTTPClie // given config.HTTPClientConfig and config.HTTPClientOption. // The name is used as go-conntrack metric label. func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HTTPClientOption) (http.RoundTripper, error) { + return NewRoundTripperFromConfigWithContext(context.Background(), cfg, name, optFuncs...) +} + +// NewRoundTripperFromConfig returns a new HTTP RoundTripper configured for the +// given config.HTTPClientConfig and config.HTTPClientOption. +// The name is used as go-conntrack metric label. +func NewRoundTripperFromConfigWithContext(ctx context.Context, cfg HTTPClientConfig, name string, optFuncs ...HTTPClientOption) (http.RoundTripper, error) { opts := defaultHTTPClientOptions - for _, f := range optFuncs { - f(&opts) + for _, opt := range optFuncs { + opt.applyToHTTPClientOptions(&opts) } var dialContext func(ctx context.Context, network, addr string) (net.Conn, error) @@ -562,20 +618,40 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT // If a authorization_credentials is provided, create a round tripper that will set the // Authorization header correctly on each request. if cfg.Authorization != nil { - rt = NewAuthorizationCredentialsRoundTripper(cfg.Authorization.Type, secretFrom(cfg.Authorization.Credentials, cfg.Authorization.CredentialsFile), rt) + credentialsSecret, err := toSecret(opts.secretManager, cfg.Authorization.Credentials, cfg.Authorization.CredentialsFile, cfg.Authorization.CredentialsRef) + if err != nil { + return nil, fmt.Errorf("unable to use credentials: %w", err) + } + rt = NewAuthorizationCredentialsRoundTripper(cfg.Authorization.Type, credentialsSecret, rt) } // Backwards compatibility, be nice with importers who would not have // called Validate(). if len(cfg.BearerToken) > 0 || len(cfg.BearerTokenFile) > 0 { - rt = NewAuthorizationCredentialsRoundTripper("Bearer", secretFrom(cfg.BearerToken, cfg.BearerTokenFile), rt) + bearerSecret, err := toSecret(opts.secretManager, cfg.BearerToken, cfg.BearerTokenFile, "") + if err != nil { + return nil, fmt.Errorf("unable to use bearer token: %w", err) + } + rt = NewAuthorizationCredentialsRoundTripper("Bearer", bearerSecret, rt) } if cfg.BasicAuth != nil { - rt = NewBasicAuthRoundTripper(secretFrom(Secret(cfg.BasicAuth.Username), cfg.BasicAuth.UsernameFile), secretFrom(cfg.BasicAuth.Password, cfg.BasicAuth.PasswordFile), rt) + usernameSecret, err := toSecret(opts.secretManager, Secret(cfg.BasicAuth.Username), cfg.BasicAuth.UsernameFile, cfg.BasicAuth.UsernameRef) + if err != nil { + return nil, fmt.Errorf("unable to use username: %w", err) + } + passwordSecret, err := toSecret(opts.secretManager, cfg.BasicAuth.Password, cfg.BasicAuth.PasswordFile, cfg.BasicAuth.PasswordRef) + if err != nil { + return nil, fmt.Errorf("unable to use password: %w", err) + } + rt = NewBasicAuthRoundTripper(usernameSecret, passwordSecret, rt) } if cfg.OAuth2 != nil { - rt = NewOAuth2RoundTripper(cfg.OAuth2, rt, &opts) + clientSecret, err := toSecret(opts.secretManager, Secret(cfg.OAuth2.ClientSecret), cfg.OAuth2.ClientSecretFile, cfg.OAuth2.ClientSecretRef) + if err != nil { + return nil, fmt.Errorf("unable to use client secret: %w", err) + } + rt = NewOAuth2RoundTripper(clientSecret, cfg.OAuth2, rt, &opts) } if cfg.HTTPHeaders != nil { @@ -594,21 +670,30 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT return rt, nil } - tlsConfig, err := NewTLSConfig(&cfg.TLSConfig) + tlsConfig, err := NewTLSConfig(&cfg.TLSConfig, WithSecretManager(opts.secretManager)) if err != nil { return nil, err } - tlsSettings := cfg.TLSConfig.roundTripperSettings() + tlsSettings, err := cfg.TLSConfig.roundTripperSettings(opts.secretManager) + if err != nil { + return nil, err + } if tlsSettings.CA == nil || tlsSettings.CA.immutable() { // No need for a RoundTripper that reloads the CA file automatically. return newRT(tlsConfig) } - return NewTLSRoundTripper(tlsConfig, tlsSettings, newRT) + return NewTLSRoundTripperWithContext(ctx, tlsConfig, tlsSettings, newRT) +} + +// SecretManager manages secret data mapped to names known as "references" or "refs". +type SecretManager interface { + // Fetch returns the secret data given a secret name indicated by `secretRef`. + Fetch(ctx context.Context, secretRef string) (string, error) } type secret interface { - fetch() (string, error) + fetch(ctx context.Context) (string, error) description() string immutable() bool } @@ -617,7 +702,7 @@ type inlineSecret struct { text string } -func (s *inlineSecret) fetch() (string, error) { +func (s *inlineSecret) fetch(ctx context.Context) (string, error) { return s.text, nil } @@ -633,7 +718,7 @@ type fileSecret struct { file string } -func (s *fileSecret) fetch() (string, error) { +func (s *fileSecret) fetch(ctx context.Context) (string, error) { fileBytes, err := os.ReadFile(s.file) if err != nil { return "", fmt.Errorf("unable to read file %s: %w", s.file, err) @@ -649,18 +734,47 @@ func (s *fileSecret) immutable() bool { return false } -func secretFrom(text Secret, file string) secret { +// refSecret fetches a single secret from a secret manager. +type refSecret struct { + ref string + manager SecretManager +} + +func (s *refSecret) fetch(ctx context.Context) (string, error) { + return s.manager.Fetch(ctx, s.ref) +} + +func (s *refSecret) description() string { + return fmt.Sprintf("ref %s", s.ref) +} + +func (s *refSecret) immutable() bool { + return false +} + +// toSecret returns a secret from one of the given sources, assuming exactly +// one or none of the sources are provided. +func toSecret(secretManager SecretManager, text Secret, file, ref string) (secret, error) { if text != "" { return &inlineSecret{ text: string(text), - } + }, nil } if file != "" { return &fileSecret{ file: file, + }, nil + } + if ref != "" { + if secretManager == nil { + return nil, errors.New("cannot use secret ref without manager") } + return &refSecret{ + ref: ref, + manager: secretManager, + }, nil } - return nil + return nil, nil } type authorizationCredentialsRoundTripper struct { @@ -681,7 +795,7 @@ func (rt *authorizationCredentialsRoundTripper) RoundTrip(req *http.Request) (*h var authCredentials string if rt.authCredentials != nil { var err error - authCredentials, err = rt.authCredentials.fetch() + authCredentials, err = rt.authCredentials.fetch(req.Context()) if err != nil { return nil, fmt.Errorf("unable to read authorization credentials: %w", err) } @@ -720,14 +834,14 @@ func (rt *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, e var password string if rt.username != nil { var err error - username, err = rt.username.fetch() + username, err = rt.username.fetch(req.Context()) if err != nil { return nil, fmt.Errorf("unable to read basic auth username: %w", err) } } if rt.password != nil { var err error - password, err = rt.password.fetch() + password, err = rt.password.fetch(req.Context()) if err != nil { return nil, fmt.Errorf("unable to read basic auth password: %w", err) } @@ -745,21 +859,21 @@ func (rt *basicAuthRoundTripper) CloseIdleConnections() { type oauth2RoundTripper struct { config *OAuth2 - clientSecret secret rt http.RoundTripper next http.RoundTripper - secret string + clientSecret secret + lastSecret string mtx sync.RWMutex opts *httpClientOptions client *http.Client } -func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper { +func NewOAuth2RoundTripper(clientSecret secret, config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper { return &oauth2RoundTripper{ config: config, - clientSecret: secretFrom(config.ClientSecret, config.ClientSecretFile), next: next, opts: opts, + clientSecret: clientSecret, } } @@ -769,18 +883,28 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro changed bool ) - if rt.clientSecret != nil { - var err error - secret, err = rt.clientSecret.fetch() - if err != nil { - return nil, fmt.Errorf("unable to read oauth2 client secret: %w", err) + // Fetch the secret if it's our first run or always if the secret can change. + if rt.rt == nil || (rt.clientSecret != nil && !rt.clientSecret.immutable()) { + if rt.clientSecret != nil { + var err error + secret, err = rt.clientSecret.fetch(req.Context()) + if err != nil { + return nil, fmt.Errorf("unable to read oauth2 client secret: %w", err) + } + + if !rt.clientSecret.immutable() { + rt.mtx.RLock() + changed = secret != rt.lastSecret + rt.mtx.RUnlock() + } + } + + if rt.rt == nil { + changed = true } } - rt.mtx.RLock() - changed = secret != rt.secret - rt.mtx.RUnlock() - if changed || rt.rt == nil { + if changed { config := &clientcredentials.Config{ ClientID: rt.config.ClientID, ClientSecret: secret, @@ -789,7 +913,7 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro EndpointParams: mapToValues(rt.config.EndpointParams), } - tlsConfig, err := NewTLSConfig(&rt.config.TLSConfig) + tlsConfig, err := NewTLSConfig(&rt.config.TLSConfig, WithSecretManager(rt.opts.secretManager)) if err != nil { return nil, err } @@ -809,11 +933,14 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro } var t http.RoundTripper - tlsSettings := rt.config.TLSConfig.roundTripperSettings() + tlsSettings, err := rt.config.TLSConfig.roundTripperSettings(rt.opts.secretManager) + if err != nil { + return nil, err + } if tlsSettings.CA == nil || tlsSettings.CA.immutable() { t, _ = tlsTransport(tlsConfig) } else { - t, err = NewTLSRoundTripper(tlsConfig, tlsSettings, tlsTransport) + t, err = NewTLSRoundTripperWithContext(req.Context(), tlsConfig, tlsSettings, tlsTransport) if err != nil { return nil, err } @@ -828,7 +955,7 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro tokenSource := config.TokenSource(ctx) rt.mtx.Lock() - rt.secret = secret + rt.lastSecret = secret rt.rt = &oauth2.Transport{ Base: rt.next, Source: tokenSource, @@ -878,8 +1005,27 @@ func cloneRequest(r *http.Request) *http.Request { return r2 } +type tlsConfigOptions struct { + secretManager SecretManager +} + +// TLSConfigOption defines an option that can be applied to the HTTP client. +type TLSConfigOption interface { + applyToTLSConfigOptions(options *tlsConfigOptions) +} + +// NewTLSConfig creates a new tls.Config from the given TLSConfig. +func NewTLSConfig(cfg *TLSConfig, optFuncs ...TLSConfigOption) (*tls.Config, error) { + return NewTLSConfigWithContext(context.Background(), cfg, optFuncs...) +} + // NewTLSConfig creates a new tls.Config from the given TLSConfig. -func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) { +func NewTLSConfigWithContext(ctx context.Context, cfg *TLSConfig, optFuncs ...TLSConfigOption) (*tls.Config, error) { + opts := tlsConfigOptions{} + for _, opt := range optFuncs { + opt.applyToTLSConfigOptions(&opts) + } + if err := cfg.Validate(); err != nil { return nil, err } @@ -898,9 +1044,12 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) { // If a CA cert is provided then let's read it in so we can validate the // scrape target's certificate properly. - caSecret := secretFrom(Secret(cfg.CA), cfg.CAFile) + caSecret, err := toSecret(opts.secretManager, Secret(cfg.CA), cfg.CAFile, cfg.CARef) + if err != nil { + return nil, fmt.Errorf("unable to use CA cert: %w", err) + } if caSecret != nil { - ca, err := caSecret.fetch() + ca, err := caSecret.fetch(ctx) if err != nil { return nil, fmt.Errorf("unable to read CA cert: %w", err) } @@ -916,10 +1065,16 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) { // If a client cert & key is provided then configure TLS config accordingly. if cfg.usingClientCert() && cfg.usingClientKey() { // Verify that client cert and key are valid. - if _, err := cfg.getClientCertificate(nil); err != nil { + if _, err := cfg.getClientCertificate(ctx, opts.secretManager); err != nil { return nil, err } - tlsConfig.GetClientCertificate = cfg.getClientCertificate + tlsConfig.GetClientCertificate = func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { + var ctx context.Context + if cri != nil { + ctx = cri.Context() + } + return cfg.getClientCertificate(ctx, opts.secretManager) + } } return tlsConfig, nil @@ -939,6 +1094,15 @@ type TLSConfig struct { CertFile string `yaml:"cert_file,omitempty" json:"cert_file,omitempty"` // The client key file for the targets. KeyFile string `yaml:"key_file,omitempty" json:"key_file,omitempty"` + // CARef is the name of the secret within the secret manager to use as the CA cert for the + // targets. + CARef string `yaml:"ca_ref,omitempty" json:"ca_ref,omitempty"` + // CertRef is the name of the secret within the secret manager to use as the client cert for + // the targets. + CertRef string `yaml:"cert_ref,omitempty" json:"cert_ref,omitempty"` + // KeyRef is the name of the secret within the secret manager to use as the client key for + // the targets. + KeyRef string `yaml:"key_ref,omitempty" json:"key_ref,omitempty"` // Used to verify the hostname for the targets. ServerName string `yaml:"server_name,omitempty" json:"server_name,omitempty"` // Disable target certificate validation. @@ -972,13 +1136,13 @@ func (c *TLSConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { // file-based fields for the TLS CA, client certificate, and client key are // used. func (c *TLSConfig) Validate() error { - if len(c.CA) > 0 && len(c.CAFile) > 0 { - return fmt.Errorf("at most one of ca and ca_file must be configured") + if nonZeroCount(len(c.CA) > 0, len(c.CAFile) > 0, len(c.CARef) > 0) > 1 { + return fmt.Errorf("at most one of ca, ca_file & ca_ref must be configured") } - if len(c.Cert) > 0 && len(c.CertFile) > 0 { - return fmt.Errorf("at most one of cert and cert_file must be configured") + if nonZeroCount(len(c.Cert) > 0, len(c.CertFile) > 0, len(c.CertRef) > 0) > 1 { + return fmt.Errorf("at most one of cert, cert_file & cert_ref must be configured") } - if len(c.Key) > 0 && len(c.KeyFile) > 0 { + if nonZeroCount(len(c.Key) > 0, len(c.KeyFile) > 0, len(c.KeyRef) > 0) > 1 { return fmt.Errorf("at most one of key and key_file must be configured") } @@ -992,47 +1156,63 @@ func (c *TLSConfig) Validate() error { } func (c *TLSConfig) usingClientCert() bool { - return len(c.Cert) > 0 || len(c.CertFile) > 0 + return len(c.Cert) > 0 || len(c.CertFile) > 0 || len(c.CertRef) > 0 } func (c *TLSConfig) usingClientKey() bool { - return len(c.Key) > 0 || len(c.KeyFile) > 0 + return len(c.Key) > 0 || len(c.KeyFile) > 0 || len(c.KeyRef) > 0 } -func (c *TLSConfig) roundTripperSettings() TLSRoundTripperSettings { - return TLSRoundTripperSettings{ - CA: secretFrom(Secret(c.CA), c.CAFile), - Cert: secretFrom(Secret(c.Cert), c.CertFile), - Key: secretFrom(c.Key, c.KeyFile), +func (c *TLSConfig) roundTripperSettings(secretManager SecretManager) (TLSRoundTripperSettings, error) { + ca, err := toSecret(secretManager, Secret(c.CA), c.CAFile, c.CARef) + if err != nil { + return TLSRoundTripperSettings{}, err } + cert, err := toSecret(secretManager, Secret(c.Cert), c.CertFile, c.CertRef) + if err != nil { + return TLSRoundTripperSettings{}, err + } + key, err := toSecret(secretManager, c.Key, c.KeyFile, c.KeyRef) + if err != nil { + return TLSRoundTripperSettings{}, err + } + return TLSRoundTripperSettings{ + CA: ca, + Cert: cert, + Key: key, + }, nil } -// getClientCertificate reads the pair of client cert and key from disk and returns a tls.Certificate. -func (c *TLSConfig) getClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { +// getClientCertificate reads the pair of client cert and key and returns a tls.Certificate. +func (c *TLSConfig) getClientCertificate(ctx context.Context, secretManager SecretManager) (*tls.Certificate, error) { var ( certData, keyData string err error ) - certSecret := secretFrom(Secret(c.Cert), c.CertFile) + certSecret, err := toSecret(secretManager, Secret(c.Cert), c.CertFile, c.CertRef) + if err != nil { + return nil, fmt.Errorf("unable to use client cert: %w", err) + } if certSecret != nil { - certData, err = certSecret.fetch() + certData, err = certSecret.fetch(ctx) if err != nil { return nil, fmt.Errorf("unable to read specified client cert: %w", err) } } - keySecret := secretFrom(Secret(c.Key), c.KeyFile) + keySecret, err := toSecret(secretManager, Secret(c.Key), c.KeyFile, c.KeyRef) + if err != nil { + return nil, fmt.Errorf("unable to use client key: %w", err) + } if keySecret != nil { - keyData, err = keySecret.fetch() + keyData, err = keySecret.fetch(ctx) if err != nil { return nil, fmt.Errorf("unable to read specified client key: %w", err) } } - certStr := unsafe.Slice(unsafe.StringData(certData), len(certData)) - keyStr := unsafe.Slice(unsafe.StringData(keyData), len(keyData)) - cert, err := tls.X509KeyPair(certStr, keyStr) + cert, err := tls.X509KeyPair([]byte(certData), []byte(keyData)) if err != nil { return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %w", certSecret.description(), keySecret.description(), err) } @@ -1076,6 +1256,15 @@ func NewTLSRoundTripper( cfg *tls.Config, settings TLSRoundTripperSettings, newRT func(*tls.Config) (http.RoundTripper, error), +) (http.RoundTripper, error) { + return NewTLSRoundTripperWithContext(context.Background(), cfg, settings, newRT) +} + +func NewTLSRoundTripperWithContext( + ctx context.Context, + cfg *tls.Config, + settings TLSRoundTripperSettings, + newRT func(*tls.Config) (http.RoundTripper, error), ) (http.RoundTripper, error) { t := &tlsRoundTripper{ settings: settings, @@ -1088,7 +1277,7 @@ func NewTLSRoundTripper( return nil, err } t.rt = rt - _, t.hashCAData, t.hashCertData, t.hashKeyData, err = t.getTLSDataWithHash() + _, t.hashCAData, t.hashCertData, t.hashKeyData, err = t.getTLSDataWithHash(ctx) if err != nil { return nil, err } @@ -1096,11 +1285,11 @@ func NewTLSRoundTripper( return t, nil } -func (t *tlsRoundTripper) getTLSDataWithHash() ([]byte, []byte, []byte, []byte, error) { +func (t *tlsRoundTripper) getTLSDataWithHash(ctx context.Context) ([]byte, []byte, []byte, []byte, error) { var caBytes, certBytes, keyBytes []byte if t.settings.CA != nil { - ca, err := t.settings.CA.fetch() + ca, err := t.settings.CA.fetch(ctx) if err != nil { return nil, nil, nil, nil, fmt.Errorf("unable to read CA cert: %w", err) } @@ -1108,7 +1297,7 @@ func (t *tlsRoundTripper) getTLSDataWithHash() ([]byte, []byte, []byte, []byte, } if t.settings.Cert != nil { - cert, err := t.settings.Cert.fetch() + cert, err := t.settings.Cert.fetch(ctx) if err != nil { return nil, nil, nil, nil, fmt.Errorf("unable to read client cert: %w", err) } @@ -1116,7 +1305,7 @@ func (t *tlsRoundTripper) getTLSDataWithHash() ([]byte, []byte, []byte, []byte, } if t.settings.Key != nil { - key, err := t.settings.Key.fetch() + key, err := t.settings.Key.fetch(ctx) if err != nil { return nil, nil, nil, nil, fmt.Errorf("unable to read client key: %w", err) } @@ -1140,7 +1329,7 @@ func (t *tlsRoundTripper) getTLSDataWithHash() ([]byte, []byte, []byte, []byte, // RoundTrip implements the http.RoundTrip interface. func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - caData, caHash, certHash, keyHash, err := t.getTLSDataWithHash() + caData, caHash, certHash, keyHash, err := t.getTLSDataWithHash(req.Context()) if err != nil { return nil, err } diff --git a/config/http_config_test.go b/config/http_config_test.go index 085ecb75..5b03a0b5 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -81,11 +81,11 @@ var invalidHTTPClientConfigs = []struct { }, { httpClientConfigFile: "testdata/http.conf.basic-auth.too-much.bad.yaml", - errMsg: "at most one of basic_auth password & password_file must be configured", + errMsg: "at most one of basic_auth password, password_file & password_ref must be configured", }, { httpClientConfigFile: "testdata/http.conf.basic-auth.bad-username.yaml", - errMsg: "at most one of basic_auth username & username_file must be configured", + errMsg: "at most one of basic_auth username, username_file & username_ref must be configured", }, { httpClientConfigFile: "testdata/http.conf.mix-bearer-and-creds.bad.yaml", @@ -109,7 +109,7 @@ var invalidHTTPClientConfigs = []struct { }, { httpClientConfigFile: "testdata/http.conf.oauth2-secret-and-file-set.bad.yml", - errMsg: "at most one of oauth2 client_secret & client_secret_file must be configured", + errMsg: "at most one of oauth2 client_secret, client_secret_file & client_secret_ref must be configured", }, { httpClientConfigFile: "testdata/http.conf.oauth2-no-client-id.bad.yaml", @@ -892,7 +892,7 @@ func TestTLSConfigInvalidCA(t *testing.T) { ServerName: "", InsecureSkipVerify: false, }, - errorMessage: "at most one of cert and cert_file must be configured", + errorMessage: "at most one of cert, cert_file & cert_ref must be configured", }, { configTLSConfig: TLSConfig{ @@ -934,7 +934,7 @@ func TestBasicAuthNoPassword(t *testing.T) { t.Fatalf("Error casting to basic auth transport, %v", client.Transport) } - if username, _ := rt.username.fetch(); username != "user" { + if username, _ := rt.username.fetch(context.Background()); username != "user" { t.Errorf("Bad HTTP client username: %s", username) } if rt.password != nil { @@ -960,7 +960,7 @@ func TestBasicAuthNoUsername(t *testing.T) { if rt.username != nil { t.Errorf("Got unexpected username") } - if password, _ := rt.password.fetch(); password != "secret" { + if password, _ := rt.password.fetch(context.Background()); password != "secret" { t.Errorf("Unexpected HTTP client password: %s", password) } } @@ -980,14 +980,84 @@ func TestBasicAuthPasswordFile(t *testing.T) { t.Fatalf("Error casting to basic auth transport, %v", client.Transport) } - if username, _ := rt.username.fetch(); username != "user" { + if username, _ := rt.username.fetch(context.Background()); username != "user" { t.Errorf("Bad HTTP client username: %s", username) } - if password, _ := rt.password.fetch(); password != "foobar" { + if password, _ := rt.password.fetch(context.Background()); password != "foobar" { t.Errorf("Bad HTTP client password: %s", password) } } +type secretManager struct { + data map[string]string +} + +func (m *secretManager) Fetch(ctx context.Context, secretRef string) (string, error) { + secretData, ok := m.data[secretRef] + if !ok { + return "", fmt.Errorf("unknown secret %s", secretRef) + } + return secretData, nil +} + +func TestBasicAuthSecretManager(t *testing.T) { + cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.basic-auth.ref.yaml") + if err != nil { + t.Fatalf("Error loading HTTP client config: %v", err) + } + manager := secretManager{ + data: map[string]string{ + "admin": "user", + "pass": "foobar", + }, + } + client, err := NewClientFromConfig(*cfg, "test", WithSecretManager(&manager)) + if err != nil { + t.Fatalf("Error creating HTTP Client: %v", err) + } + + rt, ok := client.Transport.(*basicAuthRoundTripper) + if !ok { + t.Fatalf("Error casting to basic auth transport, %v", client.Transport) + } + + if username, _ := rt.username.fetch(context.Background()); username != "user" { + t.Errorf("Bad HTTP client username: %s", username) + } + if password, _ := rt.password.fetch(context.Background()); password != "foobar" { + t.Errorf("Bad HTTP client password: %s", password) + } +} + +func TestBasicAuthSecretManagerNotFound(t *testing.T) { + cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.basic-auth.ref.yaml") + if err != nil { + t.Fatalf("Error loading HTTP client config: %v", err) + } + manager := secretManager{ + data: map[string]string{ + "admin1": "user", + "foobar": "pass", + }, + } + client, err := NewClientFromConfig(*cfg, "test", WithSecretManager(&manager)) + if err != nil { + t.Fatalf("Error creating HTTP Client: %v", err) + } + + rt, ok := client.Transport.(*basicAuthRoundTripper) + if !ok { + t.Fatalf("Error casting to basic auth transport, %v", client.Transport) + } + + if _, err := rt.username.fetch(context.Background()); !strings.Contains(err.Error(), "unknown secret admin") { + t.Errorf("Unexpected error message: %s", err) + } + if _, err := rt.password.fetch(context.Background()); !strings.Contains(err.Error(), "unknown secret pass") { + t.Errorf("Unexpected error message: %s", err) + } +} + func TestBasicUsernameFile(t *testing.T) { cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.basic-auth.username-file.good.yaml") if err != nil { @@ -1003,10 +1073,10 @@ func TestBasicUsernameFile(t *testing.T) { t.Fatalf("Error casting to basic auth transport, %v", client.Transport) } - if username, _ := rt.username.fetch(); username != "testuser" { + if username, _ := rt.username.fetch(context.Background()); username != "testuser" { t.Errorf("Bad HTTP client username: %s", username) } - if password, _ := rt.password.fetch(); password != "foobar" { + if password, _ := rt.password.fetch(context.Background()); password != "foobar" { t.Errorf("Bad HTTP client passwordFile: %s", password) } } @@ -1396,7 +1466,7 @@ func TestTLSRoundTripperRaces(t *testing.T) { func TestHideHTTPClientConfigSecrets(t *testing.T) { c, _, err := LoadHTTPConfigFile("testdata/http.conf.good.yml") if err != nil { - t.Errorf("Error parsing %s: %s", "testdata/http.conf.good.yml", err) + t.Fatalf("Error parsing %s: %s", "testdata/http.conf.good.yml", err) } // String method must not reveal authentication credentials. @@ -1557,7 +1627,7 @@ endpoint_params: t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig) } - rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions) + rt := NewOAuth2RoundTripper(&inlineSecret{text: string(expectedConfig.ClientSecret)}, &expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions) client := http.Client{ Transport: rt, @@ -1727,7 +1797,7 @@ endpoint_params: t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig) } - rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions) + rt := NewOAuth2RoundTripper(&inlineSecret{text: string(expectedConfig.ClientSecret)}, &expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions) client := http.Client{ Transport: rt, diff --git a/config/testdata/http.conf.basic-auth.ref.yaml b/config/testdata/http.conf.basic-auth.ref.yaml new file mode 100644 index 00000000..68a7b3f8 --- /dev/null +++ b/config/testdata/http.conf.basic-auth.ref.yaml @@ -0,0 +1,3 @@ +basic_auth: + username_ref: admin + password_ref: pass \ No newline at end of file