Skip to content

Commit d6ccff6

Browse files
arjan-baljanardhankrishna-sai
authored andcommitted
credentials: Add experimental credentials that don't enforce ALPN (grpc#7980)
1 parent 969bdd7 commit d6ccff6

File tree

8 files changed

+1531
-2
lines changed

8 files changed

+1531
-2
lines changed

credentials/tls.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ import (
3232
"google.golang.org/grpc/internal/envconfig"
3333
)
3434

35+
const alpnFailureHelpMessage = "If you upgraded from a grpc-go version earlier than 1.67, your TLS connections may have stopped working due to ALPN enforcement. For more details, see: https://github.com/grpc/grpc-go/issues/434"
36+
3537
var logger = grpclog.Component("credentials")
3638

3739
// TLSInfo contains the auth information for a TLS authenticated connection.
@@ -128,7 +130,7 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon
128130
if np == "" {
129131
if envconfig.EnforceALPNEnabled {
130132
conn.Close()
131-
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property")
133+
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage)
132134
}
133135
logger.Warningf("Allowing TLS connection to server %q with ALPN disabled. TLS connections to servers with ALPN disabled will be disallowed in future grpc-go releases", cfg.ServerName)
134136
}
@@ -158,7 +160,7 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
158160
if cs.NegotiatedProtocol == "" {
159161
if envconfig.EnforceALPNEnabled {
160162
conn.Close()
161-
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property")
163+
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage)
162164
} else if logger.V(2) {
163165
logger.Info("Allowing TLS connection from client with ALPN disabled. TLS connections with ALPN disabled will be disallowed in future grpc-go releases")
164166
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
/*
2+
*
3+
* Copyright 2025 gRPC authors.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*
17+
*/
18+
19+
package credentials
20+
21+
import (
22+
"context"
23+
"crypto/tls"
24+
"net"
25+
"strings"
26+
"testing"
27+
"time"
28+
29+
"google.golang.org/grpc/credentials"
30+
"google.golang.org/grpc/internal/grpctest"
31+
"google.golang.org/grpc/testdata"
32+
)
33+
34+
const defaultTestTimeout = 10 * time.Second
35+
36+
type s struct {
37+
grpctest.Tester
38+
}
39+
40+
func Test(t *testing.T) {
41+
grpctest.RunSubTests(t, s{})
42+
}
43+
44+
func (s) TestTLSOverrideServerName(t *testing.T) {
45+
expectedServerName := "server.name"
46+
c := NewTLSWithALPNDisabled(nil)
47+
c.OverrideServerName(expectedServerName)
48+
if c.Info().ServerName != expectedServerName {
49+
t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
50+
}
51+
}
52+
53+
func (s) TestTLSClone(t *testing.T) {
54+
expectedServerName := "server.name"
55+
c := NewTLSWithALPNDisabled(nil)
56+
c.OverrideServerName(expectedServerName)
57+
cc := c.Clone()
58+
if cc.Info().ServerName != expectedServerName {
59+
t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName)
60+
}
61+
cc.OverrideServerName("")
62+
if c.Info().ServerName != expectedServerName {
63+
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
64+
}
65+
66+
}
67+
68+
type serverHandshake func(net.Conn) (credentials.AuthInfo, error)
69+
70+
func (s) TestClientHandshakeReturnsAuthInfo(t *testing.T) {
71+
tcs := []struct {
72+
name string
73+
address string
74+
}{
75+
{
76+
name: "localhost",
77+
address: "localhost:0",
78+
},
79+
{
80+
name: "ipv4",
81+
address: "127.0.0.1:0",
82+
},
83+
{
84+
name: "ipv6",
85+
address: "[::1]:0",
86+
},
87+
}
88+
89+
for _, tc := range tcs {
90+
t.Run(tc.name, func(t *testing.T) {
91+
done := make(chan credentials.AuthInfo, 1)
92+
lis := launchServerOnListenAddress(t, tlsServerHandshake, done, tc.address)
93+
defer lis.Close()
94+
lisAddr := lis.Addr().String()
95+
clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr)
96+
// wait until server sends serverAuthInfo or fails.
97+
serverAuthInfo, ok := <-done
98+
if !ok {
99+
t.Fatalf("Error at server-side")
100+
}
101+
if !compare(clientAuthInfo, serverAuthInfo) {
102+
t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo)
103+
}
104+
})
105+
}
106+
}
107+
108+
func (s) TestServerHandshakeReturnsAuthInfo(t *testing.T) {
109+
done := make(chan credentials.AuthInfo, 1)
110+
lis := launchServer(t, gRPCServerHandshake, done)
111+
defer lis.Close()
112+
clientAuthInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String())
113+
// wait until server sends serverAuthInfo or fails.
114+
serverAuthInfo, ok := <-done
115+
if !ok {
116+
t.Fatalf("Error at server-side")
117+
}
118+
if !compare(clientAuthInfo, serverAuthInfo) {
119+
t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, clientAuthInfo)
120+
}
121+
}
122+
123+
func (s) TestServerAndClientHandshake(t *testing.T) {
124+
done := make(chan credentials.AuthInfo, 1)
125+
lis := launchServer(t, gRPCServerHandshake, done)
126+
defer lis.Close()
127+
clientAuthInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String())
128+
// wait until server sends serverAuthInfo or fails.
129+
serverAuthInfo, ok := <-done
130+
if !ok {
131+
t.Fatalf("Error at server-side")
132+
}
133+
if !compare(clientAuthInfo, serverAuthInfo) {
134+
t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverAuthInfo, clientAuthInfo)
135+
}
136+
}
137+
138+
func compare(a1, a2 credentials.AuthInfo) bool {
139+
if a1.AuthType() != a2.AuthType() {
140+
return false
141+
}
142+
switch a1.AuthType() {
143+
case "tls":
144+
state1 := a1.(credentials.TLSInfo).State
145+
state2 := a2.(credentials.TLSInfo).State
146+
if state1.Version == state2.Version &&
147+
state1.HandshakeComplete == state2.HandshakeComplete &&
148+
state1.CipherSuite == state2.CipherSuite &&
149+
state1.NegotiatedProtocol == state2.NegotiatedProtocol {
150+
return true
151+
}
152+
return false
153+
default:
154+
return false
155+
}
156+
}
157+
158+
func launchServer(t *testing.T, hs serverHandshake, done chan credentials.AuthInfo) net.Listener {
159+
return launchServerOnListenAddress(t, hs, done, "localhost:0")
160+
}
161+
162+
func launchServerOnListenAddress(t *testing.T, hs serverHandshake, done chan credentials.AuthInfo, address string) net.Listener {
163+
lis, err := net.Listen("tcp", address)
164+
if err != nil {
165+
if strings.Contains(err.Error(), "bind: cannot assign requested address") ||
166+
strings.Contains(err.Error(), "socket: address family not supported by protocol") {
167+
t.Skipf("no support for address %v", address)
168+
}
169+
t.Fatalf("Failed to listen: %v", err)
170+
}
171+
go serverHandle(t, hs, done, lis)
172+
return lis
173+
}
174+
175+
// Is run in a separate goroutine.
176+
func serverHandle(t *testing.T, hs serverHandshake, done chan credentials.AuthInfo, lis net.Listener) {
177+
serverRawConn, err := lis.Accept()
178+
if err != nil {
179+
t.Errorf("Server failed to accept connection: %v", err)
180+
close(done)
181+
return
182+
}
183+
serverAuthInfo, err := hs(serverRawConn)
184+
if err != nil {
185+
t.Errorf("Server failed while handshake. Error: %v", err)
186+
serverRawConn.Close()
187+
close(done)
188+
return
189+
}
190+
done <- serverAuthInfo
191+
}
192+
193+
func clientHandle(t *testing.T, hs func(net.Conn, string) (credentials.AuthInfo, error), lisAddr string) credentials.AuthInfo {
194+
conn, err := net.Dial("tcp", lisAddr)
195+
if err != nil {
196+
t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err)
197+
}
198+
defer conn.Close()
199+
clientAuthInfo, err := hs(conn, lisAddr)
200+
if err != nil {
201+
t.Fatalf("Error on client while handshake. Error: %v", err)
202+
}
203+
return clientAuthInfo
204+
}
205+
206+
// Server handshake implementation in gRPC.
207+
func gRPCServerHandshake(conn net.Conn) (credentials.AuthInfo, error) {
208+
serverTLS, err := NewServerTLSFromFileWithALPNDisabled(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
209+
if err != nil {
210+
return nil, err
211+
}
212+
_, serverAuthInfo, err := serverTLS.ServerHandshake(conn)
213+
if err != nil {
214+
return nil, err
215+
}
216+
return serverAuthInfo, nil
217+
}
218+
219+
// Client handshake implementation in gRPC.
220+
func gRPCClientHandshake(conn net.Conn, lisAddr string) (credentials.AuthInfo, error) {
221+
clientTLS := NewTLSWithALPNDisabled(&tls.Config{InsecureSkipVerify: true})
222+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
223+
defer cancel()
224+
_, authInfo, err := clientTLS.ClientHandshake(ctx, lisAddr, conn)
225+
if err != nil {
226+
return nil, err
227+
}
228+
return authInfo, nil
229+
}
230+
231+
func tlsServerHandshake(conn net.Conn) (credentials.AuthInfo, error) {
232+
cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
233+
if err != nil {
234+
return nil, err
235+
}
236+
serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
237+
serverConn := tls.Server(conn, serverTLSConfig)
238+
err = serverConn.Handshake()
239+
if err != nil {
240+
return nil, err
241+
}
242+
return credentials.TLSInfo{State: serverConn.ConnectionState(), CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}, nil
243+
}
244+
245+
func tlsClientHandshake(conn net.Conn, _ string) (credentials.AuthInfo, error) {
246+
clientTLSConfig := &tls.Config{InsecureSkipVerify: true}
247+
clientConn := tls.Client(conn, clientTLSConfig)
248+
if err := clientConn.Handshake(); err != nil {
249+
return nil, err
250+
}
251+
return credentials.TLSInfo{State: clientConn.ConnectionState(), CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}, nil
252+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
*
3+
* Copyright 2025 gRPC authors.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*
17+
*/
18+
19+
// Package internal defines APIs for parsing SPIFFE ID.
20+
//
21+
// All APIs in this package are experimental.
22+
package internal
23+
24+
import (
25+
"crypto/tls"
26+
"crypto/x509"
27+
"net/url"
28+
29+
"google.golang.org/grpc/grpclog"
30+
)
31+
32+
var logger = grpclog.Component("credentials")
33+
34+
// SPIFFEIDFromState parses the SPIFFE ID from State. If the SPIFFE ID format
35+
// is invalid, return nil with warning.
36+
func SPIFFEIDFromState(state tls.ConnectionState) *url.URL {
37+
if len(state.PeerCertificates) == 0 || len(state.PeerCertificates[0].URIs) == 0 {
38+
return nil
39+
}
40+
return SPIFFEIDFromCert(state.PeerCertificates[0])
41+
}
42+
43+
// SPIFFEIDFromCert parses the SPIFFE ID from x509.Certificate. If the SPIFFE
44+
// ID format is invalid, return nil with warning.
45+
func SPIFFEIDFromCert(cert *x509.Certificate) *url.URL {
46+
if cert == nil || cert.URIs == nil {
47+
return nil
48+
}
49+
var spiffeID *url.URL
50+
for _, uri := range cert.URIs {
51+
if uri == nil || uri.Scheme != "spiffe" || uri.Opaque != "" || (uri.User != nil && uri.User.Username() != "") {
52+
continue
53+
}
54+
// From this point, we assume the uri is intended for a SPIFFE ID.
55+
if len(uri.String()) > 2048 {
56+
logger.Warning("invalid SPIFFE ID: total ID length larger than 2048 bytes")
57+
return nil
58+
}
59+
if len(uri.Host) == 0 || len(uri.Path) == 0 {
60+
logger.Warning("invalid SPIFFE ID: domain or workload ID is empty")
61+
return nil
62+
}
63+
if len(uri.Host) > 255 {
64+
logger.Warning("invalid SPIFFE ID: domain length larger than 255 characters")
65+
return nil
66+
}
67+
// A valid SPIFFE certificate can only have exactly one URI SAN field.
68+
if len(cert.URIs) > 1 {
69+
logger.Warning("invalid SPIFFE ID: multiple URI SANs")
70+
return nil
71+
}
72+
spiffeID = uri
73+
}
74+
return spiffeID
75+
}

0 commit comments

Comments
 (0)