diff --git a/credentials/credentials.go b/credentials/credentials.go index a851560456b4..88aff94596a1 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -36,9 +36,6 @@ import ( "google.golang.org/grpc/credentials/internal" ) -// alpnProtoStr are the specified application level protocols for gRPC. -var alpnProtoStr = []string{"h2"} - // PerRPCCredentials defines the common interface for the credentials which need to // attach security information to every RPC (e.g., oauth2). type PerRPCCredentials interface { @@ -208,10 +205,23 @@ func (c *tlsCreds) OverrideServerName(serverNameOverride string) error { return nil } +const alpnProtoStrH2 = "h2" + +func appendH2ToNextProtos(ps []string) []string { + for _, p := range ps { + if p == alpnProtoStrH2 { + return ps + } + } + ret := make([]string, 0, len(ps)+1) + ret = append(ret, ps...) + return append(ret, alpnProtoStrH2) +} + // NewTLS uses c to construct a TransportCredentials based on TLS. func NewTLS(c *tls.Config) TransportCredentials { tc := &tlsCreds{cloneTLSConfig(c)} - tc.config.NextProtos = alpnProtoStr + tc.config.NextProtos = appendH2ToNextProtos(tc.config.NextProtos) return tc } diff --git a/credentials/credentials_test.go b/credentials/credentials_test.go index cb091de08092..b15458636f1a 100644 --- a/credentials/credentials_test.go +++ b/credentials/credentials_test.go @@ -22,6 +22,7 @@ import ( "context" "crypto/tls" "net" + "reflect" "testing" "google.golang.org/grpc/testdata" @@ -204,3 +205,39 @@ func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) { } return TLSInfo{State: clientConn.ConnectionState()}, nil } + +func TestAppendH2ToNextProtos(t *testing.T) { + tests := []struct { + name string + ps []string + want []string + }{ + { + name: "empty", + ps: nil, + want: []string{"h2"}, + }, + { + name: "only h2", + ps: []string{"h2"}, + want: []string{"h2"}, + }, + { + name: "with h2", + ps: []string{"alpn", "h2"}, + want: []string{"alpn", "h2"}, + }, + { + name: "no h2", + ps: []string{"alpn"}, + want: []string{"alpn", "h2"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := appendH2ToNextProtos(tt.ps); !reflect.DeepEqual(got, tt.want) { + t.Errorf("appendH2ToNextProtos() = %v, want %v", got, tt.want) + } + }) + } +}