diff --git a/security/advancedtls/advancedtls.go b/security/advancedtls/advancedtls.go index f794c090d5e4..aa6cd8b07d72 100644 --- a/security/advancedtls/advancedtls.go +++ b/security/advancedtls/advancedtls.go @@ -35,10 +35,10 @@ import ( credinternal "google.golang.org/grpc/internal/credentials" ) -// VerificationFuncParams contains parameters available to users when -// implementing CustomVerificationFunc. +// HandshakeVerificationInfo contains information about a handshake needed for +// verification for use when implementing the `PostHandshakeVerificationFunc` // The fields in this struct are read-only. -type VerificationFuncParams struct { +type HandshakeVerificationInfo struct { // The target server name that the client connects to when establishing the // connection. This field is only meaningful for client side. On server side, // this field would be an empty string. @@ -54,17 +54,36 @@ type VerificationFuncParams struct { Leaf *x509.Certificate } -// VerificationResults contains the information about results of -// CustomVerificationFunc. -// VerificationResults is an empty struct for now. It may be extended in the +// VerificationFuncParams contains parameters available to users when +// implementing CustomVerificationFunc. +// The fields in this struct are read-only. +// +// Deprecated: use HandshakeVerificationInfo instead. +type VerificationFuncParams = HandshakeVerificationInfo + +// PostHandshakeVerificationResults contains the information about results of +// PostHandshakeVerificationFunc. +// PostHandshakeVerificationResults is an empty struct for now. It may be extended in the // future to include more information. -type VerificationResults struct{} +type PostHandshakeVerificationResults struct{} + +// Deprecated: use PostHandshakeVerificationResults instead. +type VerificationResults = PostHandshakeVerificationResults + +// PostHandshakeVerificationFunc is the function defined by users to perform +// custom verification checks after chain building and regular handshake +// verification has been completed. +// PostHandshakeVerificationFunc should return (nil, error) if the authorization +// should fail, with the error containing information on why it failed. +type PostHandshakeVerificationFunc func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) // CustomVerificationFunc is the function defined by users to perform custom // verification check. // CustomVerificationFunc returns nil if the authorization fails; otherwise // returns an empty struct. -type CustomVerificationFunc func(params *VerificationFuncParams) (*VerificationResults, error) +// +// Deprecated: use PostHandshakeVerificationFunc instead. +type CustomVerificationFunc = PostHandshakeVerificationFunc // GetRootCAsParams contains the parameters available to users when // implementing GetRootCAs. @@ -167,11 +186,18 @@ type ClientOptions struct { // IdentityOptions is OPTIONAL on client side. This field only needs to be // set if mutual authentication is required on server side. IdentityOptions IdentityCertificateOptions + // AdditionalPeerVerification is a custom verification check after certificate signature + // check. + // If this is set, we will perform this customized check after doing the + // normal check(s) indicated by setting VerificationType. + AdditionalPeerVerification PostHandshakeVerificationFunc // VerifyPeer is a custom verification check after certificate signature // check. // If this is set, we will perform this customized check after doing the - // normal check(s) indicated by setting VType. - VerifyPeer CustomVerificationFunc + // normal check(s) indicated by setting VerificationType. + // + // Deprecated: use AdditionalPeerVerification instead. + VerifyPeer PostHandshakeVerificationFunc // RootOptions is OPTIONAL on client side. If not set, we will try to use the // default trust certificates in users' OS system. RootOptions RootCertificateOptions @@ -206,11 +232,18 @@ type ClientOptions struct { type ServerOptions struct { // IdentityOptions is REQUIRED on server side. IdentityOptions IdentityCertificateOptions + // AdditionalPeerVerification is a custom verification check after certificate signature + // check. + // If this is set, we will perform this customized check after doing the + // normal check(s) indicated by setting VerificationType. + AdditionalPeerVerification PostHandshakeVerificationFunc // VerifyPeer is a custom verification check after certificate signature // check. // If this is set, we will perform this customized check after doing the - // normal check(s) indicated by setting VType. - VerifyPeer CustomVerificationFunc + // normal check(s) indicated by setting VerificationType. + // + // Deprecated: use AdditionalPeerVerification instead. + VerifyPeer PostHandshakeVerificationFunc // RootOptions is OPTIONAL on server side. This field only needs to be set if // mutual authentication is required(RequireClientCert is true). RootOptions RootCertificateOptions @@ -239,13 +272,18 @@ type ServerOptions struct { } func (o *ClientOptions) config() (*tls.Config, error) { + // TODO(gtcooke94) Remove this block when o.VerifyPeer is remoed. + // VerifyPeer is deprecated, but do this to aid the transitory migration time. + if o.AdditionalPeerVerification == nil { + o.AdditionalPeerVerification = o.VerifyPeer + } // TODO(gtcooke94). VType is deprecated, eventually remove this block. This // will ensure that users still explicitly setting `VType` will get the // setting to the right place. if o.VType != CertAndHostVerification { o.VerificationType = o.VType } - if o.VerificationType == SkipVerification && o.VerifyPeer == nil { + if o.VerificationType == SkipVerification && o.AdditionalPeerVerification == nil { return nil, fmt.Errorf("client needs to provide custom verification mechanism if choose to skip default verification") } // Make sure users didn't specify more than one fields in @@ -321,13 +359,18 @@ func (o *ClientOptions) config() (*tls.Config, error) { } func (o *ServerOptions) config() (*tls.Config, error) { + // TODO(gtcooke94) Remove this block when o.VerifyPeer is remoed. + // VerifyPeer is deprecated, but do this to aid the transitory migration time. + if o.AdditionalPeerVerification == nil { + o.AdditionalPeerVerification = o.VerifyPeer + } // TODO(gtcooke94). VType is deprecated, eventually remove this block. This // will ensure that users still explicitly setting `VType` will get the // setting to the right place. if o.VType != CertAndHostVerification { o.VerificationType = o.VType } - if o.RequireClientCert && o.VerificationType == SkipVerification && o.VerifyPeer == nil { + if o.RequireClientCert && o.VerificationType == SkipVerification && o.AdditionalPeerVerification == nil { return nil, fmt.Errorf("server needs to provide custom verification mechanism if choose to skip default verification, but require client certificate(s)") } // Make sure users didn't specify more than one fields in @@ -416,7 +459,7 @@ func (o *ServerOptions) config() (*tls.Config, error) { // using TLS. type advancedTLSCreds struct { config *tls.Config - verifyFunc CustomVerificationFunc + verifyFunc PostHandshakeVerificationFunc getRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error) isClient bool verificationType VerificationType @@ -579,7 +622,7 @@ func buildVerifyFunc(c *advancedTLSCreds, } // Perform custom verification check if specified. if c.verifyFunc != nil { - _, err := c.verifyFunc(&VerificationFuncParams{ + _, err := c.verifyFunc(&HandshakeVerificationInfo{ ServerName: serverName, RawCerts: rawCerts, VerifiedChains: chains, @@ -602,7 +645,7 @@ func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error) config: conf, isClient: true, getRootCAs: o.RootOptions.GetRootCertificates, - verifyFunc: o.VerifyPeer, + verifyFunc: o.AdditionalPeerVerification, verificationType: o.VerificationType, revocationConfig: o.RevocationConfig, } @@ -621,7 +664,7 @@ func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error) config: conf, isClient: false, getRootCAs: o.RootOptions.GetRootCertificates, - verifyFunc: o.VerifyPeer, + verifyFunc: o.AdditionalPeerVerification, verificationType: o.VerificationType, revocationConfig: o.RevocationConfig, } diff --git a/security/advancedtls/advancedtls_integration_test.go b/security/advancedtls/advancedtls_integration_test.go index 4a6457f101dc..ca080c956b97 100644 --- a/security/advancedtls/advancedtls_integration_test.go +++ b/security/advancedtls/advancedtls_integration_test.go @@ -143,13 +143,13 @@ func (s) TestEnd2End(t *testing.T) { clientGetCert func(*tls.CertificateRequestInfo) (*tls.Certificate, error) clientRoot *x509.CertPool clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error) - clientVerifyFunc CustomVerificationFunc + clientVerifyFunc PostHandshakeVerificationFunc clientVerificationType VerificationType serverCert []tls.Certificate serverGetCert func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) serverRoot *x509.CertPool serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error) - serverVerifyFunc CustomVerificationFunc + serverVerifyFunc PostHandshakeVerificationFunc serverVerificationType VerificationType }{ // Test Scenarios: @@ -175,8 +175,8 @@ func (s) TestEnd2End(t *testing.T) { } }, clientRoot: cs.ClientTrust1, - clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { - return &VerificationResults{}, nil + clientVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) { + return &PostHandshakeVerificationResults{}, nil }, clientVerificationType: CertVerification, serverCert: []tls.Certificate{cs.ServerCert1}, @@ -188,8 +188,8 @@ func (s) TestEnd2End(t *testing.T) { return &GetRootCAsResults{TrustCerts: cs.ServerTrust2}, nil } }, - serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { - return &VerificationResults{}, nil + serverVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) { + return &PostHandshakeVerificationResults{}, nil }, serverVerificationType: CertVerification, }, @@ -216,8 +216,8 @@ func (s) TestEnd2End(t *testing.T) { return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil } }, - clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { - return &VerificationResults{}, nil + clientVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) { + return &PostHandshakeVerificationResults{}, nil }, clientVerificationType: CertVerification, serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) { @@ -229,8 +229,8 @@ func (s) TestEnd2End(t *testing.T) { } }, serverRoot: cs.ServerTrust1, - serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { - return &VerificationResults{}, nil + serverVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) { + return &PostHandshakeVerificationResults{}, nil }, serverVerificationType: CertVerification, }, @@ -258,7 +258,7 @@ func (s) TestEnd2End(t *testing.T) { return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil } }, - clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { + clientVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) { if len(params.RawCerts) == 0 { return nil, fmt.Errorf("no peer certs") } @@ -280,7 +280,7 @@ func (s) TestEnd2End(t *testing.T) { } } if authzCheck { - return &VerificationResults{}, nil + return &PostHandshakeVerificationResults{}, nil } return nil, fmt.Errorf("custom authz check fails") }, @@ -294,8 +294,8 @@ func (s) TestEnd2End(t *testing.T) { } }, serverRoot: cs.ServerTrust1, - serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { - return &VerificationResults{}, nil + serverVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) { + return &PostHandshakeVerificationResults{}, nil }, serverVerificationType: CertVerification, }, @@ -314,16 +314,16 @@ func (s) TestEnd2End(t *testing.T) { desc: "TestServerCustomVerification", clientCert: []tls.Certificate{cs.ClientCert1}, clientRoot: cs.ClientTrust1, - clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { - return &VerificationResults{}, nil + clientVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) { + return &PostHandshakeVerificationResults{}, nil }, clientVerificationType: CertVerification, serverCert: []tls.Certificate{cs.ServerCert1}, serverRoot: cs.ServerTrust1, - serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) { + serverVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) { switch stage.read() { case 0, 2: - return &VerificationResults{}, nil + return &PostHandshakeVerificationResults{}, nil case 1: return nil, fmt.Errorf("custom authz check fails") default: @@ -345,9 +345,9 @@ func (s) TestEnd2End(t *testing.T) { RootCACerts: test.serverRoot, GetRootCertificates: test.serverGetRoot, }, - RequireClientCert: true, - VerifyPeer: test.serverVerifyFunc, - VerificationType: test.serverVerificationType, + RequireClientCert: true, + AdditionalPeerVerification: test.serverVerifyFunc, + VerificationType: test.serverVerificationType, } serverTLSCreds, err := NewServerCreds(serverOptions) if err != nil { @@ -368,7 +368,7 @@ func (s) TestEnd2End(t *testing.T) { Certificates: test.clientCert, GetIdentityCertificatesForClient: test.clientGetCert, }, - VerifyPeer: test.clientVerifyFunc, + AdditionalPeerVerification: test.clientVerifyFunc, RootOptions: RootCertificateOptions{ RootCACerts: test.clientRoot, GetRootCertificates: test.clientGetRoot, @@ -635,8 +635,8 @@ func (s) TestPEMFileProviderEnd2End(t *testing.T) { RootProvider: serverRootProvider, }, RequireClientCert: true, - VerifyPeer: func(params *VerificationFuncParams) (*VerificationResults, error) { - return &VerificationResults{}, nil + AdditionalPeerVerification: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) { + return &PostHandshakeVerificationResults{}, nil }, VerificationType: CertVerification, } @@ -658,8 +658,8 @@ func (s) TestPEMFileProviderEnd2End(t *testing.T) { IdentityOptions: IdentityCertificateOptions{ IdentityProvider: clientIdentityProvider, }, - VerifyPeer: func(params *VerificationFuncParams) (*VerificationResults, error) { - return &VerificationResults{}, nil + AdditionalPeerVerification: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) { + return &PostHandshakeVerificationResults{}, nil }, RootOptions: RootCertificateOptions{ RootProvider: clientRootProvider, diff --git a/security/advancedtls/advancedtls_test.go b/security/advancedtls/advancedtls_test.go index 0439363d2406..ab881b458978 100644 --- a/security/advancedtls/advancedtls_test.go +++ b/security/advancedtls/advancedtls_test.go @@ -369,7 +369,7 @@ func (s) TestClientServerHandshake(t *testing.T) { getRootCAsForClient := func(params *GetRootCAsParams) (*GetRootCAsResults, error) { return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil } - clientVerifyFuncGood := func(params *VerificationFuncParams) (*VerificationResults, error) { + clientVerifyFuncGood := func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) { if params.ServerName == "" { return nil, errors.New("client side server name should have a value") } @@ -378,15 +378,15 @@ func (s) TestClientServerHandshake(t *testing.T) { return nil, errors.New("client side params parsing error") } - return &VerificationResults{}, nil + return &PostHandshakeVerificationResults{}, nil } - verifyFuncBad := func(params *VerificationFuncParams) (*VerificationResults, error) { + verifyFuncBad := func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) { return nil, fmt.Errorf("custom verification function failed") } getRootCAsForServer := func(params *GetRootCAsParams) (*GetRootCAsResults, error) { return &GetRootCAsResults{TrustCerts: cs.ServerTrust1}, nil } - serverVerifyFunc := func(params *VerificationFuncParams) (*VerificationResults, error) { + serverVerifyFunc := func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) { if params.ServerName != "" { return nil, errors.New("server side server name should not have a value") } @@ -395,7 +395,7 @@ func (s) TestClientServerHandshake(t *testing.T) { return nil, errors.New("server side params parsing error") } - return &VerificationResults{}, nil + return &PostHandshakeVerificationResults{}, nil } getRootCAsForServerBad := func(params *GetRootCAsParams) (*GetRootCAsResults, error) { return nil, fmt.Errorf("bad root certificate reloading") @@ -431,7 +431,7 @@ func (s) TestClientServerHandshake(t *testing.T) { clientGetCert func(*tls.CertificateRequestInfo) (*tls.Certificate, error) clientRoot *x509.CertPool clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error) - clientVerifyFunc CustomVerificationFunc + clientVerifyFunc PostHandshakeVerificationFunc clientVerificationType VerificationType clientRootProvider certprovider.Provider clientIdentityProvider certprovider.Provider @@ -442,7 +442,7 @@ func (s) TestClientServerHandshake(t *testing.T) { serverGetCert func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) serverRoot *x509.CertPool serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error) - serverVerifyFunc CustomVerificationFunc + serverVerifyFunc PostHandshakeVerificationFunc serverVerificationType VerificationType serverRootProvider certprovider.Provider serverIdentityProvider certprovider.Provider @@ -822,10 +822,10 @@ func (s) TestClientServerHandshake(t *testing.T) { GetRootCertificates: test.serverGetRoot, RootProvider: test.serverRootProvider, }, - RequireClientCert: test.serverMutualTLS, - VerifyPeer: test.serverVerifyFunc, - VerificationType: test.serverVerificationType, - RevocationConfig: test.serverRevocationConfig, + RequireClientCert: test.serverMutualTLS, + AdditionalPeerVerification: test.serverVerifyFunc, + VerificationType: test.serverVerificationType, + RevocationConfig: test.serverRevocationConfig, } go func(done chan credentials.AuthInfo, lis net.Listener, serverOptions *ServerOptions) { serverRawConn, err := lis.Accept() @@ -861,7 +861,7 @@ func (s) TestClientServerHandshake(t *testing.T) { GetIdentityCertificatesForClient: test.clientGetCert, IdentityProvider: test.clientIdentityProvider, }, - VerifyPeer: test.clientVerifyFunc, + AdditionalPeerVerification: test.clientVerifyFunc, RootOptions: RootCertificateOptions{ RootCACerts: test.clientRoot, GetRootCertificates: test.clientGetRoot, diff --git a/security/advancedtls/examples/credential_reloading_from_files/client/main.go b/security/advancedtls/examples/credential_reloading_from_files/client/main.go index 212bcdf48359..810a8bf5780b 100644 --- a/security/advancedtls/examples/credential_reloading_from_files/client/main.go +++ b/security/advancedtls/examples/credential_reloading_from_files/client/main.go @@ -76,8 +76,8 @@ func main() { IdentityOptions: advancedtls.IdentityCertificateOptions{ IdentityProvider: identityProvider, }, - VerifyPeer: func(params *advancedtls.VerificationFuncParams) (*advancedtls.VerificationResults, error) { - return &advancedtls.VerificationResults{}, nil + AdditionalPeerVerification: func(params *advancedtls.HandshakeVerificationInfo) (*advancedtls.PostHandshakeVerificationResults, error) { + return &advancedtls.PostHandshakeVerificationResults{}, nil }, RootOptions: advancedtls.RootCertificateOptions{ RootProvider: rootProvider, diff --git a/security/advancedtls/examples/credential_reloading_from_files/server/main.go b/security/advancedtls/examples/credential_reloading_from_files/server/main.go index 030f2b7772b3..16d389860ccb 100644 --- a/security/advancedtls/examples/credential_reloading_from_files/server/main.go +++ b/security/advancedtls/examples/credential_reloading_from_files/server/main.go @@ -84,10 +84,10 @@ func main() { RootProvider: rootProvider, }, RequireClientCert: true, - VerifyPeer: func(params *advancedtls.VerificationFuncParams) (*advancedtls.VerificationResults, error) { + AdditionalPeerVerification: func(params *advancedtls.HandshakeVerificationInfo) (*advancedtls.PostHandshakeVerificationResults, error) { // This message is to show the certificate under the hood is actually reloaded. fmt.Printf("Client common name: %s.\n", params.Leaf.Subject.CommonName) - return &advancedtls.VerificationResults{}, nil + return &advancedtls.PostHandshakeVerificationResults{}, nil }, VerificationType: advancedtls.CertVerification, }