Skip to content

Commit

Permalink
Use common interface to fetch secrets in HTTP client config
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Hrabovcak <thespiritxiii@gmail.com>
  • Loading branch information
TheSpiritXIII committed Nov 30, 2023
1 parent 1d8c672 commit 0d08b50
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 151 deletions.
217 changes: 105 additions & 112 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
Expand All @@ -29,6 +30,7 @@ import (
"strings"
"sync"
"time"
"unsafe"

"github.com/mwitkow/go-conntrack"
"golang.org/x/net/http/httpproxy"
Expand Down Expand Up @@ -546,21 +548,17 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT

// If a authorization_credentials is provided, create a round tripper that will set the
// Authorization header correctly on each request.
if cfg.Authorization != nil && len(cfg.Authorization.Credentials) > 0 {
rt = NewAuthorizationCredentialsRoundTripper(cfg.Authorization.Type, cfg.Authorization.Credentials, rt)
} else if cfg.Authorization != nil && len(cfg.Authorization.CredentialsFile) > 0 {
rt = NewAuthorizationCredentialsFileRoundTripper(cfg.Authorization.Type, cfg.Authorization.CredentialsFile, rt)
if cfg.Authorization != nil && (len(cfg.Authorization.Credentials) > 0 || len(cfg.Authorization.CredentialsFile) > 0) {
rt = NewAuthorizationCredentialsRoundTripper(cfg.Authorization.Type, secretFrom(cfg.Authorization.Credentials, cfg.Authorization.CredentialsFile), rt)
}
// Backwards compatibility, be nice with importers who would not have
// called Validate().
if len(cfg.BearerToken) > 0 {
rt = NewAuthorizationCredentialsRoundTripper("Bearer", cfg.BearerToken, rt)
} else if len(cfg.BearerTokenFile) > 0 {
rt = NewAuthorizationCredentialsFileRoundTripper("Bearer", cfg.BearerTokenFile, rt)
if len(cfg.BearerToken) > 0 || len(cfg.BearerTokenFile) > 0 {
rt = NewAuthorizationCredentialsRoundTripper("Bearer", secretFrom(cfg.BearerToken, cfg.BearerTokenFile), rt)
}

if cfg.BasicAuth != nil {
rt = NewBasicAuthRoundTripper(cfg.BasicAuth.Username, cfg.BasicAuth.Password, cfg.BasicAuth.UsernameFile, cfg.BasicAuth.PasswordFile, rt)
rt = NewBasicAuthRoundTripper(secretFrom(Secret(cfg.BasicAuth.Username), cfg.BasicAuth.UsernameFile), secretFrom(cfg.BasicAuth.Password, cfg.BasicAuth.PasswordFile), rt)
}

if cfg.OAuth2 != nil {
Expand All @@ -587,52 +585,67 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.roundTripperSettings(), newRT)
}

type authorizationCredentialsRoundTripper struct {
authType string
authCredentials Secret
rt http.RoundTripper
type secret interface {
fetch() (string, error)
}

// NewAuthorizationCredentialsRoundTripper adds the provided credentials to a
// request unless the authorization header has already been set.
func NewAuthorizationCredentialsRoundTripper(authType string, authCredentials Secret, rt http.RoundTripper) http.RoundTripper {
return &authorizationCredentialsRoundTripper{authType, authCredentials, rt}
type inlineSecret struct {
text string
}

func (rt *authorizationCredentialsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if len(req.Header.Get("Authorization")) == 0 {
req = cloneRequest(req)
req.Header.Set("Authorization", fmt.Sprintf("%s %s", rt.authType, string(rt.authCredentials)))
func (s *inlineSecret) fetch() (string, error) {
return s.text, nil
}

type fileSecret struct {
file string
}

func (s *fileSecret) fetch() (string, error) {
fileBytes, err := os.ReadFile(s.file)
if err != nil {
return "", fmt.Errorf("unable to read file %s: %w", s.file, err)
}
return rt.rt.RoundTrip(req)
return strings.TrimSpace(string(fileBytes)), nil
}

func (rt *authorizationCredentialsRoundTripper) CloseIdleConnections() {
if ci, ok := rt.rt.(closeIdler); ok {
ci.CloseIdleConnections()
func secretFrom(text Secret, file string) secret {
if text != "" {
return &inlineSecret{
text: string(text),
}
}
if file != "" {
return &fileSecret{
file: file,
}
}
return nil
}

type authorizationCredentialsFileRoundTripper struct {
authType string
authCredentialsFile string
rt http.RoundTripper
type authorizationCredentialsRoundTripper struct {
authType string
authCredentials secret
rt http.RoundTripper
}

// NewAuthorizationCredentialsFileRoundTripper adds the authorization
// credentials read from the provided file to a request unless the authorization
// header has already been set. This file is read for every request.
func NewAuthorizationCredentialsFileRoundTripper(authType, authCredentialsFile string, rt http.RoundTripper) http.RoundTripper {
return &authorizationCredentialsFileRoundTripper{authType, authCredentialsFile, rt}
// NewAuthorizationCredentialsRoundTripper adds the authorization credentials
// read from the provided secret to a request unless the authorization header
// has already been set.
func NewAuthorizationCredentialsRoundTripper(authType string, authCredentials secret, rt http.RoundTripper) http.RoundTripper {
return &authorizationCredentialsRoundTripper{authType, authCredentials, rt}
}

func (rt *authorizationCredentialsFileRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
func (rt *authorizationCredentialsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if len(req.Header.Get("Authorization")) == 0 {
b, err := os.ReadFile(rt.authCredentialsFile)
if err != nil {
return nil, fmt.Errorf("unable to read authorization credentials file %s: %s", rt.authCredentialsFile, err)
var authCredentials string
if rt.authCredentials != nil {
var err error
authCredentials, err = rt.authCredentials.fetch()
if err != nil {
return nil, fmt.Errorf("unable to get authorization credentials: %w", err)
}
}
authCredentials := strings.TrimSpace(string(b))

req = cloneRequest(req)
req.Header.Set("Authorization", fmt.Sprintf("%s %s", rt.authType, authCredentials))
Expand All @@ -641,49 +654,43 @@ func (rt *authorizationCredentialsFileRoundTripper) RoundTrip(req *http.Request)
return rt.rt.RoundTrip(req)
}

func (rt *authorizationCredentialsFileRoundTripper) CloseIdleConnections() {
func (rt *authorizationCredentialsRoundTripper) CloseIdleConnections() {
if ci, ok := rt.rt.(closeIdler); ok {
ci.CloseIdleConnections()
}
}

type basicAuthRoundTripper struct {
username string
password Secret
usernameFile string
passwordFile string
rt http.RoundTripper
username secret
password secret
rt http.RoundTripper
}

// NewBasicAuthRoundTripper will apply a BASIC auth authorization header to a request unless it has
// already been set.
func NewBasicAuthRoundTripper(username string, password Secret, usernameFile, passwordFile string, rt http.RoundTripper) http.RoundTripper {
return &basicAuthRoundTripper{username, password, usernameFile, passwordFile, rt}
func NewBasicAuthRoundTripper(username secret, password secret, rt http.RoundTripper) http.RoundTripper {
return &basicAuthRoundTripper{username, password, rt}
}

func (rt *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
var username string
var password string
if len(req.Header.Get("Authorization")) != 0 {
return rt.rt.RoundTrip(req)
}
if rt.usernameFile != "" {
usernameBytes, err := os.ReadFile(rt.usernameFile)
var username string
var password string
if rt.username != nil {
var err error
username, err = rt.username.fetch()
if err != nil {
return nil, fmt.Errorf("unable to read basic auth username file %s: %s", rt.usernameFile, err)
return nil, fmt.Errorf("unable to get basic auth username: %w", err)
}
username = strings.TrimSpace(string(usernameBytes))
} else {
username = rt.username
}
if rt.passwordFile != "" {
passwordBytes, err := os.ReadFile(rt.passwordFile)
if rt.password != nil {
var err error
password, err = rt.password.fetch()
if err != nil {
return nil, fmt.Errorf("unable to read basic auth password file %s: %s", rt.passwordFile, err)
return nil, fmt.Errorf("unable to get basic auth password: %w", err)
}
password = strings.TrimSpace(string(passwordBytes))
} else {
password = string(rt.password)
}
req = cloneRequest(req)
req.SetBasicAuth(username, password)
Expand All @@ -697,20 +704,22 @@ func (rt *basicAuthRoundTripper) CloseIdleConnections() {
}

type oauth2RoundTripper struct {
config *OAuth2
rt http.RoundTripper
next http.RoundTripper
secret string
mtx sync.RWMutex
opts *httpClientOptions
client *http.Client
config *OAuth2
clientSecret secret
rt http.RoundTripper
next http.RoundTripper
secret string
mtx sync.RWMutex
opts *httpClientOptions
client *http.Client
}

func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper {
return &oauth2RoundTripper{
config: config,
next: next,
opts: opts,
config: config,
clientSecret: secretFrom(config.ClientSecret, config.ClientSecretFile),
next: next,
opts: opts,
}
}

Expand All @@ -720,22 +729,18 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
changed bool
)

if rt.config.ClientSecretFile != "" {
data, err := os.ReadFile(rt.config.ClientSecretFile)
if rt.clientSecret != nil {
var err error
secret, err = rt.clientSecret.fetch()
if err != nil {
return nil, fmt.Errorf("unable to read oauth2 client secret file %s: %s", rt.config.ClientSecretFile, err)
return nil, fmt.Errorf("unable to get oauth2 client secret: %w", err)
}
secret = strings.TrimSpace(string(data))
rt.mtx.RLock()
changed = secret != rt.secret
rt.mtx.RUnlock()
}
rt.mtx.RLock()
changed = secret != rt.secret
rt.mtx.RUnlock()

if changed || rt.rt == nil {
if rt.config.ClientSecret != "" {
secret = string(rt.config.ClientSecret)
}

config := &clientcredentials.Config{
ClientID: rt.config.ClientID,
ClientSecret: secret,
Expand Down Expand Up @@ -852,17 +857,14 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {

// If a CA cert is provided then let's read it in so we can validate the
// scrape target's certificate properly.
if len(cfg.CA) > 0 {
if !updateRootCA(tlsConfig, []byte(cfg.CA)) {
return nil, fmt.Errorf("unable to use inline CA cert")
}
} else if len(cfg.CAFile) > 0 {
b, err := readCAFile(cfg.CAFile)
caSecret := secretFrom(Secret(cfg.CA), cfg.CAFile)
if caSecret != nil {
ca, err := caSecret.fetch()
if err != nil {
return nil, err
return nil, fmt.Errorf("unable to get CA cert: %w", err)
}
if !updateRootCA(tlsConfig, b) {
return nil, fmt.Errorf("unable to use specified CA cert %s", cfg.CAFile)
if !updateRootCA(tlsConfig, []byte(ca)) {
return nil, errors.New("unable to use CA cert")
}
}

Expand Down Expand Up @@ -970,45 +972,36 @@ func (c *TLSConfig) roundTripperSettings() TLSRoundTripperSettings {
// getClientCertificate reads the pair of client cert and key from disk and returns a tls.Certificate.
func (c *TLSConfig) getClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
var (
certData, keyData []byte
certData, keyData string
err error
)

if c.CertFile != "" {
certData, err = os.ReadFile(c.CertFile)
certSecret := secretFrom(Secret(c.Cert), c.CertFile)
if certSecret != nil {
certData, err = certSecret.fetch()
if err != nil {
return nil, fmt.Errorf("unable to read specified client cert (%s): %s", c.CertFile, err)
return nil, fmt.Errorf("unable to get client cert: %w", err)
}
} else {
certData = []byte(c.Cert)
}

if c.KeyFile != "" {
keyData, err = os.ReadFile(c.KeyFile)
keySecret := secretFrom(Secret(c.Key), c.KeyFile)
if keySecret != nil {
keyData, err = keySecret.fetch()
if err != nil {
return nil, fmt.Errorf("unable to read specified client key (%s): %s", c.KeyFile, err)
return nil, fmt.Errorf("unable to get client key: %w", err)
}
} else {
keyData = []byte(c.Key)
}

cert, err := tls.X509KeyPair(certData, keyData)
certStr := unsafe.Slice(unsafe.StringData(certData), len(certData))
keyStr := unsafe.Slice(unsafe.StringData(keyData), len(keyData))
cert, err := tls.X509KeyPair(certStr, keyStr)
if err != nil {
return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %s", c.CertFile, c.KeyFile, err)
}

return &cert, nil
}

// readCAFile reads the CA cert file from disk.
func readCAFile(f string) ([]byte, error) {
data, err := os.ReadFile(f)
if err != nil {
return nil, fmt.Errorf("unable to load specified CA cert %s: %s", f, err)
}
return data, nil
}

// updateRootCA parses the given byte slice as a series of PEM encoded certificates and updates tls.Config.RootCAs.
func updateRootCA(cfg *tls.Config, b []byte) bool {
caCertPool := x509.NewCertPool()
Expand Down
Loading

0 comments on commit 0d08b50

Please sign in to comment.