Skip to content

Commit

Permalink
client: Add grpc authority header integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
serathius committed Sep 30, 2021
1 parent 6e04e8a commit 58d2b12
Show file tree
Hide file tree
Showing 6 changed files with 295 additions and 10 deletions.
69 changes: 69 additions & 0 deletions pkg/grpc_testing/recorder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright 2021 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package grpc_testing

import (
"context"
"sync"

"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)

type GrpcRecorder struct {
mux sync.RWMutex
requests []RequestInfo
}

type RequestInfo struct {
FullMethod string
Authority string
}

func (ri *GrpcRecorder) UnaryInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
ri.record(toRequestInfo(ctx, info))
resp, err := handler(ctx, req)
return resp, err
}
}

func (ri *GrpcRecorder) RecordedRequests() []RequestInfo {
ri.mux.RLock()
defer ri.mux.RUnlock()
reqs := make([]RequestInfo, len(ri.requests))
copy(reqs, ri.requests)
return reqs
}

func toRequestInfo(ctx context.Context, info *grpc.UnaryServerInfo) RequestInfo {
req := RequestInfo{
FullMethod: info.FullMethod,
}
md, ok := metadata.FromIncomingContext(ctx)
if ok {
as := md.Get(":authority")
if len(as) != 0 {
req.Authority = as[0]
}
}
return req
}

func (ri *GrpcRecorder) record(r RequestInfo) {
ri.mux.Lock()
defer ri.mux.Unlock()
ri.requests = append(ri.requests, r)
}
2 changes: 1 addition & 1 deletion server/embed/etcd.go
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ func (e *Etcd) servePeers() (err error) {

for _, p := range e.Peers {
u := p.Listener.Addr().String()
gs := v3rpc.Server(e.Server, peerTLScfg)
gs := v3rpc.Server(e.Server, peerTLScfg, nil)
m := cmux.New(p.Listener)
go gs.Serve(m.Match(cmux.HTTP2()))
srv := &http.Server{
Expand Down
4 changes: 2 additions & 2 deletions server/embed/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (sctx *serveCtx) serve(
}()

if sctx.insecure {
gs = v3rpc.Server(s, nil, gopts...)
gs = v3rpc.Server(s, nil, nil, gopts...)
v3electionpb.RegisterElectionServer(gs, servElection)
v3lockpb.RegisterLockServer(gs, servLock)
if sctx.serviceRegister != nil {
Expand Down Expand Up @@ -148,7 +148,7 @@ func (sctx *serveCtx) serve(
if tlsErr != nil {
return tlsErr
}
gs = v3rpc.Server(s, tlscfg, gopts...)
gs = v3rpc.Server(s, tlscfg, nil, gopts...)
v3electionpb.RegisterElectionServer(gs, servElection)
v3lockpb.RegisterLockServer(gs, servLock)
if sctx.serviceRegister != nil {
Expand Down
6 changes: 4 additions & 2 deletions server/etcdserver/api/v3rpc/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,21 @@ const (
maxSendBytes = math.MaxInt32
)

func Server(s *etcdserver.EtcdServer, tls *tls.Config, gopts ...grpc.ServerOption) *grpc.Server {
func Server(s *etcdserver.EtcdServer, tls *tls.Config, interceptor grpc.UnaryServerInterceptor, gopts ...grpc.ServerOption) *grpc.Server {
var opts []grpc.ServerOption
opts = append(opts, grpc.CustomCodec(&codec{}))
if tls != nil {
bundle := credentials.NewBundle(credentials.Config{TLSConfig: tls})
opts = append(opts, grpc.Creds(bundle.TransportCredentials()))
}

chainUnaryInterceptors := []grpc.UnaryServerInterceptor{
newLogUnaryInterceptor(s),
newUnaryInterceptor(s),
grpc_prometheus.UnaryServerInterceptor,
}
if interceptor != nil {
chainUnaryInterceptors = append(chainUnaryInterceptors, interceptor)
}

chainStreamInterceptors := []grpc.StreamServerInterceptor{
newStreamInterceptor(s),
Expand Down
42 changes: 37 additions & 5 deletions tests/integration/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"go.etcd.io/etcd/client/pkg/v3/types"
"go.etcd.io/etcd/client/v2"
"go.etcd.io/etcd/client/v3"
"go.etcd.io/etcd/pkg/v3/grpc_testing"
"go.etcd.io/etcd/raft/v3"
"go.etcd.io/etcd/server/v3/config"
"go.etcd.io/etcd/server/v3/embed"
Expand Down Expand Up @@ -602,6 +603,8 @@ type member struct {

isLearner bool
closed bool

grpcServerRecorder *grpc_testing.GrpcRecorder
}

func (m *member) GRPCURL() string { return m.grpcURL }
Expand Down Expand Up @@ -733,7 +736,7 @@ func mustNewMember(t testutil.TB, mcfg memberConfig) *member {
m.WarningApplyDuration = embed.DefaultWarningApplyDuration

m.V2Deprecation = config.V2_DEPR_DEFAULT

m.grpcServerRecorder = &grpc_testing.GrpcRecorder{}
m.Logger = memberLogger(t, mcfg.name)
t.Cleanup(func() {
// if we didn't cleanup the logger, the consecutive test
Expand Down Expand Up @@ -945,8 +948,8 @@ func (m *member) Launch() error {
return err
}
}
m.grpcServer = v3rpc.Server(m.s, tlscfg, m.grpcServerOpts...)
m.grpcServerPeer = v3rpc.Server(m.s, peerTLScfg)
m.grpcServer = v3rpc.Server(m.s, tlscfg, m.grpcServerRecorder.UnaryInterceptor(), m.grpcServerOpts...)
m.grpcServerPeer = v3rpc.Server(m.s, peerTLScfg, m.grpcServerRecorder.UnaryInterceptor())
m.serverClient = v3client.New(m.s)
lockpb.RegisterLockServer(m.grpcServer, v3lock.NewLockServer(m.serverClient))
epb.RegisterElectionServer(m.grpcServer, v3election.NewElectionServer(m.serverClient))
Expand Down Expand Up @@ -1081,6 +1084,10 @@ func (m *member) Launch() error {
return nil
}

func (m *member) RecordedRequests() []grpc_testing.RequestInfo {
return m.grpcServerRecorder.RecordedRequests()
}

func (m *member) WaitOK(t testutil.TB) {
m.WaitStarted(t)
for m.s.Leader() == 0 {
Expand Down Expand Up @@ -1370,8 +1377,9 @@ func (p SortableMemberSliceByPeerURLs) Swap(i, j int) { p[i], p[j] = p[j], p[i]
type ClusterV3 struct {
*cluster

mu sync.Mutex
clients []*clientv3.Client
mu sync.Mutex
clients []*clientv3.Client
clusterClient *clientv3.Client
}

// NewClusterV3 returns a launched cluster with a grpc client connection
Expand Down Expand Up @@ -1417,6 +1425,11 @@ func (c *ClusterV3) Terminate(t testutil.TB) {
t.Error(err)
}
}
if c.clusterClient != nil {
if err := c.clusterClient.Close(); err != nil {
t.Error(err)
}
}
c.mu.Unlock()
c.cluster.Terminate(t)
}
Expand All @@ -1429,6 +1442,25 @@ func (c *ClusterV3) Client(i int) *clientv3.Client {
return c.clients[i]
}

func (c *ClusterV3) ClusterClient() (client *clientv3.Client, err error) {
if c.clusterClient == nil {
endpoints := []string{}
for _, m := range c.Members {
endpoints = append(endpoints, m.grpcURL)
}
cfg := clientv3.Config{
Endpoints: endpoints,
DialTimeout: 5 * time.Second,
DialOptions: []grpc.DialOption{grpc.WithBlock()},
}
c.clusterClient, err = newClientV3(cfg, cfg.Logger)
if err != nil {
return nil, err
}
}
return c.clusterClient, nil
}

// NewClientV3 creates a new grpc client connection to the member
func (c *ClusterV3) NewClientV3(memberIndex int) (*clientv3.Client, error) {
return NewClientV3(c.Members[memberIndex])
Expand Down
Loading

0 comments on commit 58d2b12

Please sign in to comment.