Skip to content

Commit

Permalink
fix(gRPC): append "h2" to next proto in gRPC tlsConfig to enable prot…
Browse files Browse the repository at this point in the history
…ocol negotiation in TLS
  • Loading branch information
ppzqh committed Jan 23, 2024
1 parent bfe6a07 commit 3d8285d
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 3 deletions.
6 changes: 5 additions & 1 deletion client/option_advanced.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/cloudwego/kitex/pkg/proxy"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/trans/netpoll"
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc"
"github.com/cloudwego/kitex/pkg/retry"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/utils"
Expand Down Expand Up @@ -237,7 +238,10 @@ func WithBoundHandler(h remote.BoundHandler) Option {
// WithGRPCTLSConfig sets the TLS config for gRPC client.
func WithGRPCTLSConfig(tlsConfig *tls.Config) Option {
return Option{F: func(o *client.Options, di *utils.Slice) {
if tlsConfig == nil {
panic("invalid TLS config: nil")
}
di.Push("WithGRPCTLSConfig")
o.GRPCConnectOpts.TLSConfig = tlsConfig
o.GRPCConnectOpts.TLSConfig = grpc.TLSConfig(tlsConfig)
}}
}
7 changes: 7 additions & 0 deletions client/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package client

import (
"context"
"crypto/tls"
"fmt"
"reflect"
"testing"
Expand Down Expand Up @@ -701,3 +702,9 @@ func TestWithXDSSuite(t *testing.T) {
test.Assert(t, opt.XDSRouterMiddleware != nil)
test.Assert(t, opt.Resolver != nil)
}

func TestWithGRPCTLSConfig(t *testing.T) {
cfg := &tls.Config{}
opts := client.NewOptions([]client.Option{WithGRPCTLSConfig(cfg)})
test.Assert(t, opts.GRPCConnectOpts != nil)
}
15 changes: 14 additions & 1 deletion pkg/remote/trans/nphttp2/conn_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ func (p *connPool) newTransport(ctx context.Context, dialer remote.Dialer, netwo
return nil, err
}
if opts.TLSConfig != nil {
conn = tls.Client(conn, opts.TLSConfig)
tlsConn, err := newTLSConn(conn, opts.TLSConfig)
if err != nil {
return nil, err
}
conn = tlsConn
}
return grpc.NewClientTransport(
ctx,
Expand Down Expand Up @@ -222,3 +226,12 @@ func (p *connPool) Close() error {
})
return nil
}

// newTLSConn constructs a client-side TLS connection and performs handshake.
func newTLSConn(conn net.Conn, tlsCfg *tls.Config) (net.Conn, error) {
tlsConn := tls.Client(conn, tlsCfg)
if err := tlsConn.Handshake(); err != nil {
return nil, err
}
return tlsConn, nil
}
8 changes: 7 additions & 1 deletion pkg/remote/trans/nphttp2/grpc/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ type http2Client struct {
loopy *loopyWriter
remoteAddr net.Addr
localAddr net.Addr
scheme string

readerDone chan struct{} // sync point to enable testing.
writerDone chan struct{} // sync point to enable testing.
Expand Down Expand Up @@ -115,6 +116,10 @@ type http2Client struct {
func newHTTP2Client(ctx context.Context, conn net.Conn, opts ConnectOptions,
remoteService string, onGoAway func(GoAwayReason), onClose func(),
) (_ *http2Client, err error) {
scheme := "http"
if opts.TLSConfig != nil {
scheme = "https"
}
ctx, cancel := context.WithCancel(ctx)
defer func() {
if err != nil {
Expand Down Expand Up @@ -163,6 +168,7 @@ func newHTTP2Client(ctx context.Context, conn net.Conn, opts ConnectOptions,
cancel: cancel,
remoteAddr: conn.RemoteAddr(),
localAddr: conn.LocalAddr(),
scheme: scheme,
readerDone: make(chan struct{}),
writerDone: make(chan struct{}),
goAway: make(chan struct{}),
Expand Down Expand Up @@ -339,7 +345,7 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te
headerFields := make([]hpack.HeaderField, 0, hfLen)
headerFields = append(headerFields, hpack.HeaderField{Name: ":method", Value: "POST"})
headerFields = append(headerFields, hpack.HeaderField{Name: ":scheme", Value: "http"})
headerFields = append(headerFields, hpack.HeaderField{Name: ":scheme", Value: t.scheme})
headerFields = append(headerFields, hpack.HeaderField{Name: ":path", Value: callHdr.Method})
headerFields = append(headerFields, hpack.HeaderField{Name: ":authority", Value: callHdr.Host})
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(callHdr.ContentSubtype)})
Expand Down
29 changes: 29 additions & 0 deletions pkg/remote/trans/nphttp2/grpc/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -805,3 +805,32 @@ func ContextErr(err error) error {
func IsStreamDoneErr(err error) bool {
return errors.Is(err, errStreamDone)
}

// TLSConfig checks and supplement the tls config provided by user.
func TLSConfig(tlsConfig *tls.Config) *tls.Config {
cfg := tlsConfig.Clone()
// When multiple application protocols are supported on a single server-side port number,
// the client and the server need to negotiate an application protocol for use with each connection.
// For gRPC, "h2" should be appended to "application_layer_protocol_negotiation" field.
cfg.NextProtos = tlsAppendH2ToALPNProtocols(cfg.NextProtos)

// Implementations of HTTP/2 MUST use TLS version 1.2 [TLS12] or higher for HTTP/2 over TLS.
// https://datatracker.ietf.org/doc/html/rfc7540#section-9.2
if cfg.MinVersion == 0 && (cfg.MaxVersion == 0 || cfg.MaxVersion >= tls.VersionTLS12) {
cfg.MinVersion = tls.VersionTLS12
}
return cfg
}

const alpnProtoStrH2 = "h2"

func tlsAppendH2ToALPNProtocols(ps []string) []string {
for _, p := range ps {
if p == alpnProtoStrH2 {
return ps
}
}
ret := make([]string, 0, len(ps)+1)
ret = append(ret, ps...)
return append(ret, alpnProtoStrH2)
}
20 changes: 20 additions & 0 deletions pkg/remote/trans/nphttp2/grpc/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package grpc
import (
"bytes"
"context"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
Expand All @@ -40,6 +41,7 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"

"github.com/cloudwego/kitex/internal/test"
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes"
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc/grpcframe"
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc/testutils"
Expand Down Expand Up @@ -2039,3 +2041,21 @@ func TestHeaderTblSize(t *testing.T) {
t.Fatalf("expected len(limits) = 2 within 10s, got != 2")
}
}

func TestTLSConfig(t *testing.T) {
cfg := &tls.Config{}
newCfg := TLSConfig(cfg)
test.Assert(t, len(cfg.NextProtos) == 0)
test.Assert(t, len(newCfg.NextProtos) == 1)
test.Assert(t, newCfg.NextProtos[0] == alpnProtoStrH2)
test.Assert(t, newCfg.MinVersion == tls.VersionTLS12)
}

func TestTlsAppendH2ToALPNProtocols(t *testing.T) {
var ps []string
appended := tlsAppendH2ToALPNProtocols(ps)
test.Assert(t, len(appended) == 1)
test.Assert(t, appended[0] == alpnProtoStrH2)
appended = tlsAppendH2ToALPNProtocols(appended)
test.Assert(t, len(appended) == 1)
}

0 comments on commit 3d8285d

Please sign in to comment.