Skip to content

Commit

Permalink
NewClientCertificateCredential requires certificate bytes instead of …
Browse files Browse the repository at this point in the history
…path (#15604)
  • Loading branch information
chlowell authored Sep 24, 2021
1 parent 6ff9912 commit b72b333
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 165 deletions.
10 changes: 10 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@
```
* Removed `ExcludeAzureCLICredential`, `ExcludeEnvironmentCredential`, and `ExcludeMSICredential`
from `DefaultAzureCredentialOptions`
* `NewClientCertificateCredential` requires the bytes of a certificate instead of
a path to a certificate file:
```go
// before
cred, err := NewClientCertificateCredential("tenant", "client-id", "/cert.pem", nil)

// after
certData, err := os.ReadFile("/cert.pem")
cred, err := NewClientCertificateCredential("tenant", "client-id", certData, nil)
```

### Features Added
* Added connection configuration options to `DefaultAzureCredentialOptions`
Expand Down
41 changes: 11 additions & 30 deletions sdk/azidentity/client_certificate_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ import (
"encoding/base64"
"encoding/pem"
"errors"
"io/ioutil"
"os"
"strings"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
Expand Down Expand Up @@ -56,35 +53,18 @@ type ClientCertificateCredential struct {
// NewClientCertificateCredential creates an instance of ClientCertificateCredential with the details needed to authenticate against Azure Active Directory with the specified certificate.
// tenantID: The Azure Active Directory tenant (directory) ID of the service principal.
// clientID: The client (application) ID of the service principal.
// certificatePath: The path to the client certificate used to authenticate the client. Supported formats are PEM and PFX.
// certData: The bytes of a certificate in PEM or PKCS12 format, including the private key.
// options: ClientCertificateCredentialOptions that can be used to provide additional configurations for the credential, such as the certificate password.
func NewClientCertificateCredential(tenantID string, clientID string, certificatePath string, options *ClientCertificateCredentialOptions) (*ClientCertificateCredential, error) {
func NewClientCertificateCredential(tenantID string, clientID string, certData []byte, options *ClientCertificateCredentialOptions) (*ClientCertificateCredential, error) {
if !validTenantID(tenantID) {
return nil, &CredentialUnavailableError{credentialType: "Client Certificate Credential", message: tenantIDValidationErr}
}
_, err := os.Stat(certificatePath)
if err != nil {
credErr := &CredentialUnavailableError{credentialType: "Client Certificate Credential", message: "Certificate file not found in path: " + certificatePath}
logCredentialError(credErr.credentialType, credErr)
return nil, credErr
}
certData, err := ioutil.ReadFile(certificatePath)
if err != nil {
credErr := &CredentialUnavailableError{credentialType: "Client Certificate Credential", message: err.Error()}
logCredentialError(credErr.credentialType, credErr)
return nil, credErr
}
if options == nil {
options = &ClientCertificateCredentialOptions{}
}
var cert *certContents
certificatePath = strings.ToUpper(certificatePath)
if strings.HasSuffix(certificatePath, ".PEM") {
cert, err = extractFromPEMFile(certData, options.Password, options.SendCertificateChain)
} else if strings.HasSuffix(certificatePath, ".PFX") {
cert, err = extractFromPFXFile(certData, options.Password, options.SendCertificateChain)
} else {
err = errors.New("only PEM and PFX files are supported")
cert, err := loadPEMCert(certData, options.Password, options.SendCertificateChain)
if err != nil {
cert, err = loadPKCS12Cert(certData, options.Password, options.SendCertificateChain)
}
if err != nil {
credErr := &CredentialUnavailableError{credentialType: "Client Certificate Credential", message: err.Error()}
Expand Down Expand Up @@ -174,7 +154,7 @@ func newCertContents(blocks []*pem.Block, fromPEM bool, sendCertificateChain boo
return &cc, nil
}

func extractFromPEMFile(certData []byte, password string, sendCertificateChain bool) (*certContents, error) {
func loadPEMCert(certData []byte, password string, sendCertificateChain bool) (*certContents, error) {
// TODO: wire up support for password
blocks := []*pem.Block{}
// read all of the PEM blocks
Expand All @@ -187,19 +167,20 @@ func extractFromPEMFile(certData []byte, password string, sendCertificateChain b
blocks = append(blocks, block)
}
if len(blocks) == 0 {
return nil, errors.New("didn't find any blocks in PEM file")
return nil, errors.New("didn't find any PEM blocks")
}
return newCertContents(blocks, true, sendCertificateChain)
}

func extractFromPFXFile(certData []byte, password string, sendCertificateChain bool) (*certContents, error) {
// convert PFX binary data to PEM blocks
func loadPKCS12Cert(certData []byte, password string, sendCertificateChain bool) (*certContents, error) {
// convert data to PEM blocks
blocks, err := pkcs12.ToPEM(certData, password)
if err != nil {
return nil, err
}
if len(blocks) == 0 {
return nil, errors.New("didn't find any blocks in PFX file")
// not mentioning PKCS12 in this message because we end up here when certData is garbage
return nil, errors.New("didn't find any certificate content")
}
return newCertContents(blocks, false, sendCertificateChain)
}
Expand Down
182 changes: 99 additions & 83 deletions sdk/azidentity/client_certificate_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io/ioutil"
"net/http"
"net/url"
"os"
"strings"
"testing"

Expand All @@ -17,13 +18,22 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
)

const (
certificatePath = "testdata/certificate.pem"
wrongCertificatePath = "wrong_certificate_path.pem"
)
var pemCert, _ = os.ReadFile("testdata/certificate.pem")
var pkcs12Cert, _ = os.ReadFile("testdata/certificate.pfx")
var pkcs12CertEncrypted, _ = os.ReadFile("testdata/certificate_encrypted_key.pfx")

var allCertTests = []struct {
name string
certData []byte
password string
}{
{"pem", pemCert, ""},
{"pkcs12", pkcs12Cert, ""},
{"pkcs12Encrypted", pkcs12CertEncrypted, "password"},
}

func TestClientCertificateCredential_InvalidTenantID(t *testing.T) {
cred, err := NewClientCertificateCredential(badTenantID, clientID, certificatePath, nil)
cred, err := NewClientCertificateCredential(badTenantID, clientID, pemCert, nil)
if err == nil {
t.Fatal("Expected an error but received none")
}
Expand All @@ -37,7 +47,7 @@ func TestClientCertificateCredential_InvalidTenantID(t *testing.T) {
}

func TestClientCertificateCredential_CreateAuthRequestSuccess(t *testing.T) {
cred, err := NewClientCertificateCredential(tenantID, clientID, certificatePath, nil)
cred, err := NewClientCertificateCredential(tenantID, clientID, pemCert, nil)
if err != nil {
t.Fatalf("Failed to instantiate credential")
}
Expand Down Expand Up @@ -83,7 +93,7 @@ func TestClientCertificateCredential_CreateAuthRequestSuccess(t *testing.T) {
func TestClientCertificateCredential_CreateAuthRequestSuccess_withCertificateChain(t *testing.T) {
opts := ClientCertificateCredentialOptions{}
opts.SendCertificateChain = true
cred, err := NewClientCertificateCredential(tenantID, clientID, certificatePath, &opts)
cred, err := NewClientCertificateCredential(tenantID, clientID, pemCert, &opts)
if err != nil {
t.Fatalf("Failed to instantiate credential")
}
Expand Down Expand Up @@ -116,11 +126,7 @@ func TestClientCertificateCredential_CreateAuthRequestSuccess_withCertificateCha
t.Fatalf("Wrong client assertion type assigned to request")
}
// create a client assertion for comparison with the one in the request
certData, err := ioutil.ReadFile(certificatePath)
if err != nil {
t.Fatalf("Failed to read certificate: %v", err)
}
cert, err := extractFromPEMFile(certData, "", true)
cert, err := loadPEMCert(pemCert, "", true)
if err != nil {
t.Fatalf("Failed extract data from PEM file: %v", err)
}
Expand Down Expand Up @@ -153,71 +159,73 @@ func TestClientCertificateCredential_CreateAuthRequestSuccess_withCertificateCha
}

func TestClientCertificateCredential_GetTokenSuccess(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
options := ClientCertificateCredentialOptions{}
options.AuthorityHost = AuthorityHost(srv.URL())
options.HTTPClient = srv
cred, err := NewClientCertificateCredential(tenantID, clientID, certificatePath, &options)
if err != nil {
t.Fatalf("Expected an empty error but received: %s", err.Error())
}
_, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}})
if err != nil {
t.Fatalf("Expected an empty error but received: %s", err.Error())
for _, test := range allCertTests {
t.Run(test.name, func(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
options := ClientCertificateCredentialOptions{}
options.AuthorityHost = AuthorityHost(srv.URL())
options.HTTPClient = srv
options.Password = test.password
cred, err := NewClientCertificateCredential(tenantID, clientID, test.certData, &options)
if err != nil {
t.Fatalf("Expected an empty error but received: %s", err.Error())
}
_, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}})
if err != nil {
t.Fatalf("Expected an empty error but received: %s", err.Error())
}
})
}
}

func TestClientCertificateCredential_GetTokenSuccess_withCertificateChain(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
options := ClientCertificateCredentialOptions{}
options.AuthorityHost = AuthorityHost(srv.URL())
options.SendCertificateChain = true
options.HTTPClient = srv
cred, err := NewClientCertificateCredential(tenantID, clientID, certificatePath, &options)
if err != nil {
t.Fatalf("Expected an empty error but received: %s", err.Error())
}
_, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}})
if err != nil {
t.Fatalf("Expected an empty error but received: %s", err.Error())
for _, test := range allCertTests {
t.Run(test.name, func(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
options := ClientCertificateCredentialOptions{}
options.AuthorityHost = AuthorityHost(srv.URL())
options.SendCertificateChain = true
options.HTTPClient = srv
options.Password = test.password
cred, err := NewClientCertificateCredential(tenantID, clientID, test.certData, &options)
if err != nil {
t.Fatalf("Expected an empty error but received: %s", err.Error())
}
_, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}})
if err != nil {
t.Fatalf("Expected an empty error but received: %s", err.Error())
}
})
}
}

func TestClientCertificateCredential_GetTokenInvalidCredentials(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.SetResponse(mock.WithStatusCode(http.StatusUnauthorized))
options := ClientCertificateCredentialOptions{}
options.AuthorityHost = AuthorityHost(srv.URL())
options.HTTPClient = srv
cred, err := NewClientCertificateCredential(tenantID, clientID, certificatePath, &options)
if err != nil {
t.Fatalf("Did not expect an error but received one: %v", err)
}
_, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}})
if err == nil {
t.Fatalf("Expected to receive a nil error, but received: %v", err)
}
var authFailed *AuthenticationFailedError
if !errors.As(err, &authFailed) {
t.Fatalf("Expected: AuthenticationFailedError, Received: %T", err)
}
}

func TestClientCertificateCredential_WrongCertificatePath(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.SetResponse(mock.WithStatusCode(http.StatusUnauthorized))
options := ClientCertificateCredentialOptions{}
options.AuthorityHost = AuthorityHost(srv.URL())
options.HTTPClient = srv
_, err := NewClientCertificateCredential(tenantID, clientID, wrongCertificatePath, &options)
if err == nil {
t.Fatalf("Expected an error but did not receive one")
for _, test := range allCertTests {
t.Run(test.name, func(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.SetResponse(mock.WithStatusCode(http.StatusUnauthorized))
options := ClientCertificateCredentialOptions{}
options.AuthorityHost = AuthorityHost(srv.URL())
options.HTTPClient = srv
options.Password = test.password
cred, err := NewClientCertificateCredential(tenantID, clientID, test.certData, &options)
if err != nil {
t.Fatalf("Did not expect an error but received one: %v", err)
}
_, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}})
if err == nil {
t.Fatalf("Expected to receive a nil error, but received: %v", err)
}
var authFailed *AuthenticationFailedError
if !errors.As(err, &authFailed) {
t.Fatalf("Expected: AuthenticationFailedError, Received: %T", err)
}
})
}
}

Expand All @@ -228,7 +236,11 @@ func TestClientCertificateCredential_GetTokenCheckPrivateKeyBlocks(t *testing.T)
options := ClientCertificateCredentialOptions{}
options.AuthorityHost = AuthorityHost(srv.URL())
options.HTTPClient = srv
cred, err := NewClientCertificateCredential(tenantID, clientID, "testdata/certificate_formatB.pem", &options)
certData, err := os.ReadFile("testdata/certificate_formatB.pem")
if err != nil {
t.Fatalf("Failed to read certificate file: %s", err.Error())
}
cred, err := NewClientCertificateCredential(tenantID, clientID, certData, &options)
if err != nil {
t.Fatalf("Expected an empty error but received: %s", err.Error())
}
Expand All @@ -238,44 +250,48 @@ func TestClientCertificateCredential_GetTokenCheckPrivateKeyBlocks(t *testing.T)
}
}

func TestClientCertificateCredential_GetTokenCheckCertificateBlocks(t *testing.T) {
func TestClientCertificateCredential_NoData(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
options := ClientCertificateCredentialOptions{}
options.AuthorityHost = AuthorityHost(srv.URL())
options.HTTPClient = srv
cred, err := NewClientCertificateCredential(tenantID, clientID, "testdata/certificate_formatA.pem", &options)
if err != nil {
t.Fatalf("Expected an empty error but received: %s", err.Error())
}
_, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}})
if err != nil {
t.Fatalf("Expected an empty error but received: %s", err.Error())
_, err := NewClientCertificateCredential(tenantID, clientID, []byte{}, &options)
if err == nil {
t.Fatalf("Expected an error but received nil")
}
}

func TestClientCertificateCredential_GetTokenEmptyCertificate(t *testing.T) {
func TestClientCertificateCredential_NoCertificate(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
options := ClientCertificateCredentialOptions{}
options.AuthorityHost = AuthorityHost(srv.URL())
options.HTTPClient = srv
_, err := NewClientCertificateCredential(tenantID, clientID, "testdata/certificate_empty.pem", &options)
certData, err := os.ReadFile("testdata/certificate_empty.pem")
if err != nil {
t.Fatalf("Failed to read certificate file: %s", err.Error())
}
_, err = NewClientCertificateCredential(tenantID, clientID, certData, &options)
if err == nil {
t.Fatalf("Expected an error but received nil")
}
}

func TestClientCertificateCredential_GetTokenNoPrivateKey(t *testing.T) {
func TestClientCertificateCredential_NoPrivateKey(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
options := ClientCertificateCredentialOptions{}
options.AuthorityHost = AuthorityHost(srv.URL())
options.HTTPClient = srv
_, err := NewClientCertificateCredential(tenantID, clientID, "testdata/certificate_nokey.pem", &options)
certData, err := os.ReadFile("testdata/certificate_nokey.pem")
if err != nil {
t.Fatalf("Failed to read certificate file: %s", err.Error())
}
_, err = NewClientCertificateCredential(tenantID, clientID, certData, &options)
if err == nil {
t.Fatalf("Expected an error but received nil")
}
Expand All @@ -289,7 +305,7 @@ func TestBearerPolicy_ClientCertificateCredential(t *testing.T) {
options := ClientCertificateCredentialOptions{}
options.AuthorityHost = AuthorityHost(srv.URL())
options.HTTPClient = srv
cred, err := NewClientCertificateCredential(tenantID, clientID, certificatePath, &options)
cred, err := NewClientCertificateCredential(tenantID, clientID, pemCert, &options)
if err != nil {
t.Fatalf("Did not expect an error but received: %v", err)
}
Expand Down
8 changes: 6 additions & 2 deletions sdk/azidentity/environment_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,13 @@ func NewEnvironmentCredential(options *EnvironmentCredentialOptions) (*Environme
}
return &EnvironmentCredential{cred: cred}, nil
}
if clientCertificate := os.Getenv("AZURE_CLIENT_CERTIFICATE_PATH"); clientCertificate != "" {
if certPath := os.Getenv("AZURE_CLIENT_CERTIFICATE_PATH"); certPath != "" {
log.Write(LogCredential, "Azure Identity => NewEnvironmentCredential() invoking ClientCertificateCredential")
cred, err := NewClientCertificateCredential(tenantID, clientID, clientCertificate, &ClientCertificateCredentialOptions{AuthorityHost: options.AuthorityHost, HTTPClient: options.HTTPClient, Retry: options.Retry, Telemetry: options.Telemetry, Logging: options.Logging})
certData, err := os.ReadFile(certPath)
if err != nil {
return nil, &CredentialUnavailableError{credentialType: "Environment Credential", message: "Failed to read certificate file: " + err.Error()}
}
cred, err := NewClientCertificateCredential(tenantID, clientID, certData, &ClientCertificateCredentialOptions{AuthorityHost: options.AuthorityHost, HTTPClient: options.HTTPClient, Retry: options.Retry, Telemetry: options.Telemetry, Logging: options.Logging})
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit b72b333

Please sign in to comment.