Skip to content

Commit

Permalink
advancedtls: Add SNI logic to ServerOptions.GetCertificate (#3697)
Browse files Browse the repository at this point in the history
* Add SNI support in the user-provided GetCertificate callback
  • Loading branch information
cindyxue authored Jul 28, 2020
1 parent 8b7764b commit dfc0c05
Show file tree
Hide file tree
Showing 13 changed files with 638 additions and 30 deletions.
20 changes: 13 additions & 7 deletions security/advancedtls/advancedtls.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ type ClientOptions struct {
// Certificates or GetClientCertificate indicates the certificates sent from
// the server to the client to prove server's identities. The rules for setting
// these two fields are:
// Either Certificates or GetCertificate must be set; the other will be ignored.
// Either Certificates or GetCertificates must be set; the other will be ignored.
type ServerOptions struct {
// If field Certificates is set, field GetClientCertificate will be ignored.
// The server will use Certificates every time when asked for a certificate,
Expand All @@ -166,7 +166,7 @@ type ServerOptions struct {
// invoke this function every time asked to present certificates to the
// client when a new connection is established. This is known as peer
// certificate reloading.
GetCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error)
GetCertificates func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
// VerifyPeer is a custom verification check after certificate signature
// check.
// If this is set, we will perform this customized check after doing the
Expand Down Expand Up @@ -210,8 +210,8 @@ func (o *ClientOptions) config() (*tls.Config, error) {
}

func (o *ServerOptions) config() (*tls.Config, error) {
if o.Certificates == nil && o.GetCertificate == nil {
return nil, fmt.Errorf("either Certificates or GetCertificate must be specified")
if o.Certificates == nil && o.GetCertificates == nil {
return nil, fmt.Errorf("either Certificates or GetCertificates must be specified")
}
if o.RequireClientCert && o.VType == SkipVerification && o.VerifyPeer == nil {
return nil, fmt.Errorf(
Expand All @@ -234,9 +234,15 @@ func (o *ServerOptions) config() (*tls.Config, error) {
clientAuth = tls.RequireAnyClientCert
}
config := &tls.Config{
ClientAuth: clientAuth,
Certificates: o.Certificates,
GetCertificate: o.GetCertificate,
ClientAuth: clientAuth,
Certificates: o.Certificates,
}
if o.GetCertificates != nil {
// GetCertificate is only able to perform SNI logic for go1.10 and above.
// It will return the first certificate in o.GetCertificates for go1.9.
config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return buildGetCertificates(clientHello, o)
}
}
if clientCAs != nil {
config.ClientCAs = clientCAs
Expand Down
18 changes: 9 additions & 9 deletions security/advancedtls/advancedtls_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func TestEnd2End(t *testing.T) {
clientVerifyFunc CustomVerificationFunc
clientVType VerificationType
serverCert []tls.Certificate
serverGetCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)
serverGetCert func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
serverRoot *x509.CertPool
serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
serverVerifyFunc CustomVerificationFunc
Expand Down Expand Up @@ -271,12 +271,12 @@ func TestEnd2End(t *testing.T) {
return &VerificationResults{}, nil
},
clientVType: CertVerification,
serverGetCert: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
switch stage.read() {
case 0:
return &cs.serverPeer1, nil
return []*tls.Certificate{&cs.serverPeer1}, nil
default:
return &cs.serverPeer2, nil
return []*tls.Certificate{&cs.serverPeer2}, nil
}
},
serverRoot: cs.serverTrust1,
Expand Down Expand Up @@ -336,12 +336,12 @@ func TestEnd2End(t *testing.T) {
return nil, fmt.Errorf("custom authz check fails")
},
clientVType: CertVerification,
serverGetCert: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
switch stage.read() {
case 0:
return &cs.serverPeer1, nil
return []*tls.Certificate{&cs.serverPeer1}, nil
default:
return &cs.serverPeer2, nil
return []*tls.Certificate{&cs.serverPeer2}, nil
}
},
serverRoot: cs.serverTrust1,
Expand Down Expand Up @@ -388,8 +388,8 @@ func TestEnd2End(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
// Start a server using ServerOptions in another goroutine.
serverOptions := &ServerOptions{
Certificates: test.serverCert,
GetCertificate: test.serverGetCert,
Certificates: test.serverCert,
GetCertificates: test.serverGetCert,
RootCertificateOptions: RootCertificateOptions{
RootCACerts: test.serverRoot,
GetRootCAs: test.serverGetRoot,
Expand Down
26 changes: 13 additions & 13 deletions security/advancedtls/advancedtls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func TestClientServerHandshake(t *testing.T) {
clientExpectHandshakeError bool
serverMutualTLS bool
serverCert []tls.Certificate
serverGetCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)
serverGetCert func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
serverRoot *x509.CertPool
serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
serverVerifyFunc CustomVerificationFunc
Expand Down Expand Up @@ -279,8 +279,8 @@ func TestClientServerHandshake(t *testing.T) {
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
serverGetCert: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &serverPeerCert, nil
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return []*tls.Certificate{&serverPeerCert}, nil
},
serverGetRoot: getRootCAsForServer,
serverVerifyFunc: serverVerifyFunc,
Expand All @@ -300,8 +300,8 @@ func TestClientServerHandshake(t *testing.T) {
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
serverGetCert: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &serverPeerCert, nil
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return []*tls.Certificate{&serverPeerCert}, nil
},
serverGetRoot: getRootCAsForServer,
serverVerifyFunc: serverVerifyFunc,
Expand All @@ -322,8 +322,8 @@ func TestClientServerHandshake(t *testing.T) {
clientVType: CertVerification,
clientExpectHandshakeError: true,
serverMutualTLS: true,
serverGetCert: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &serverPeerCert, nil
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return []*tls.Certificate{&serverPeerCert}, nil
},
serverGetRoot: getRootCAsForServer,
serverVerifyFunc: serverVerifyFunc,
Expand All @@ -344,8 +344,8 @@ func TestClientServerHandshake(t *testing.T) {
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
serverGetCert: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &clientPeerCert, nil
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return []*tls.Certificate{&clientPeerCert}, nil
},
serverGetRoot: getRootCAsForServer,
serverVerifyFunc: serverVerifyFunc,
Expand All @@ -366,8 +366,8 @@ func TestClientServerHandshake(t *testing.T) {
clientVType: CertVerification,
clientExpectHandshakeError: true,
serverMutualTLS: true,
serverGetCert: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &serverPeerCert, nil
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return []*tls.Certificate{&serverPeerCert}, nil
},
serverGetRoot: getRootCAsForClient,
serverVerifyFunc: serverVerifyFunc,
Expand Down Expand Up @@ -402,8 +402,8 @@ func TestClientServerHandshake(t *testing.T) {
}
// Start a server using ServerOptions in another goroutine.
serverOptions := &ServerOptions{
Certificates: test.serverCert,
GetCertificate: test.serverGetCert,
Certificates: test.serverCert,
GetCertificates: test.serverGetCert,
RootCertificateOptions: RootCertificateOptions{
RootCACerts: test.serverRoot,
GetRootCAs: test.serverGetRoot,
Expand Down
2 changes: 1 addition & 1 deletion security/advancedtls/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ go 1.13

require (
github.com/golang/protobuf v1.3.5 // indirect
github.com/google/go-cmp v0.4.0 // indirect
github.com/google/go-cmp v0.4.0
golang.org/x/net v0.0.0-20200602114024-627f9648deb9 // indirect
golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 // indirect
golang.org/x/text v0.3.3 // indirect
Expand Down
53 changes: 53 additions & 0 deletions security/advancedtls/sni_110.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// +build go1.10

/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package advancedtls

import (
"crypto/tls"
"fmt"
)

// buildGetCertificates returns the certificate that matches the SNI field
// for the given ClientHelloInfo, defaulting to the first element of o.GetCertificates.
func buildGetCertificates(clientHello *tls.ClientHelloInfo, o *ServerOptions) (*tls.Certificate, error) {
if o.GetCertificates == nil {
return nil, fmt.Errorf("function GetCertificates must be specified")
}
certificates, err := o.GetCertificates(clientHello)
if err != nil {
return nil, err
}
if len(certificates) == 0 {
return nil, fmt.Errorf("no certificates configured")
}
// If users pass in only one certificate, return that certificate.
if len(certificates) == 1 {
return certificates[0], nil
}
// Choose the SNI certificate using SupportsCertificate.
for _, cert := range certificates {
if err := clientHello.SupportsCertificate(cert); err == nil {
return cert, nil
}
}
// If nothing matches, return the first certificate.
return certificates[0], nil
}
41 changes: 41 additions & 0 deletions security/advancedtls/sni_before_110.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// +build !go1.10

/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package advancedtls

import (
"crypto/tls"
"fmt"
)

// buildGetCertificates returns the first element of o.GetCertificates.
func buildGetCertificates(clientHello *tls.ClientHelloInfo, o *ServerOptions) (*tls.Certificate, error) {
if o.GetCertificates == nil {
return nil, fmt.Errorf("function GetCertificates must be specified")
}
certificates, err := o.GetCertificates(clientHello)
if err != nil {
return nil, err
}
if len(certificates) == 0 {
return nil, fmt.Errorf("no certificates configured")
}
return certificates[0], nil
}
108 changes: 108 additions & 0 deletions security/advancedtls/sni_test_110.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// +build go1.10

/*
*
* Copyright 2019 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package advancedtls

import (
"crypto/tls"
"testing"

"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/security/advancedtls/testdata"
)

// TestGetCertificatesSNI tests SNI logic for go1.10 and above.
func TestGetCertificatesSNI(t *testing.T) {
// Load server certificates for setting the serverGetCert callback function.
serverCert1, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), testdata.Path("server_key_1.pem"))
if err != nil {
t.Fatalf("tls.LoadX509KeyPair(server_cert_1.pem, server_key_1.pem) failed: %v", err)
}
serverCert2, err := tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"), testdata.Path("server_key_2.pem"))
if err != nil {
t.Fatalf("tls.LoadX509KeyPair(server_cert_2.pem, server_key_2.pem) failed: %v", err)
}
serverCert3, err := tls.LoadX509KeyPair(testdata.Path("server_cert_3.pem"), testdata.Path("server_key_3.pem"))
if err != nil {
t.Fatalf("tls.LoadX509KeyPair(server_cert_3.pem, server_key_3.pem) failed: %v", err)
}

tests := []struct {
desc string
serverGetCert func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
serverName string
wantCert tls.Certificate
}{
{
desc: "Select serverCert1",
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil
},
// "foo.bar.com" is the common name on server certificate server_cert_1.pem.
serverName: "foo.bar.com",
wantCert: serverCert1,
},
{
desc: "Select serverCert2",
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil
},
// "foo.bar.server2.com" is the common name on server certificate server_cert_2.pem.
serverName: "foo.bar.server2.com",
wantCert: serverCert2,
},
{
desc: "Select serverCert3",
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil
},
// "google.com" is one of the DNS names on server certificate server_cert_3.pem.
serverName: "google.com",
wantCert: serverCert3,
},
}
for _, test := range tests {
test := test
t.Run(test.desc, func(t *testing.T) {
serverOptions := &ServerOptions{
GetCertificates: test.serverGetCert,
}
serverConfig, err := serverOptions.config()
if err != nil {
t.Fatalf("serverOptions.config() failed: %v", err)
}
pointFormatUncompressed := uint8(0)
clientHello := &tls.ClientHelloInfo{
CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA},
ServerName: test.serverName,
SupportedCurves: []tls.CurveID{tls.CurveP256},
SupportedPoints: []uint8{pointFormatUncompressed},
SupportedVersions: []uint16{tls.VersionTLS10},
}
gotCertificate, err := serverConfig.GetCertificate(clientHello)
if err != nil {
t.Fatalf("serverConfig.GetCertificate(clientHello) failed: %v", err)
}
if !cmp.Equal(gotCertificate, test.wantCert, cmp.AllowUnexported(tls.Certificate{})) {
t.Errorf("GetCertificates() = %v, want %v", gotCertificate, test.wantCert)
}
})
}
}
Loading

0 comments on commit dfc0c05

Please sign in to comment.