rpc: add client and server unary interceptors to transmit qos levels
This commit adds a simple middleware to transmit qos levels from a client
context to a server using grpc metadata headers.

Release note: None
ajwerner committed Jul 1, 2019
1 parent f18a447 commit 2f1b961
Showing 3 changed files with 280 additions and 4 deletions.
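As a rough illustration of the commit message, the sketch below shows how a caller might tag a context with a qos level before issuing an RPC: the new client interceptor encodes the level into the "c_qos" metadata header and the server interceptor restores it on the handler's context. The helper function and its InternalClient argument are hypothetical; the qos API calls and the Level fields are the ones exercised by the diff and test below.

package example // illustrative sketch only, not part of this commit

import (
	"context"

	"github.com/cockroachdb/cockroach/pkg/qos"
	"github.com/cockroachdb/cockroach/pkg/roachpb"
)

// sendLowPriorityBatch is a hypothetical caller: it attaches a qos level to
// the context and issues a unary RPC. With the interceptors added by this
// commit installed on both ends, the level travels to the server in the
// "c_qos" gRPC metadata header.
func sendLowPriorityBatch(ctx context.Context, client roachpb.InternalClient) error {
	ctx = qos.ContextWithLevel(ctx, qos.Level{Class: qos.ClassLow, Shard: 3})
	_, err := client.Batch(ctx, &roachpb.BatchRequest{})
	return err
}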
12 changes: 8 additions & 4 deletions pkg/rpc/context.go
@@ -67,6 +67,7 @@ const (
defaultWindowSize = 65535
initialWindowSize = defaultWindowSize * 32 // for an RPC
initialConnWindowSize = initialWindowSize * 16 // for a connection
+ clientQosLevelKey = "c_qos"
)

// sourceAddr is the environment-provided local address for outgoing
@@ -265,7 +266,7 @@ func NewServerWithInterceptor(
return handler(srv, stream)
}
}

+ unaryInterceptor = qosServerInterceptor(unaryInterceptor)
if unaryInterceptor != nil {
opts = append(opts, grpc.UnaryInterceptor(unaryInterceptor))
}
@@ -640,17 +641,20 @@ func (ctx *Context) GRPCDialOptions() ([]grpc.DialOption, error) {
dialOpts = append(dialOpts, grpc.WithDefaultCallOptions(grpc.UseCompressor((snappyCompressor{}).Name())))
}

+ var unaryInterceptor grpc.UnaryClientInterceptor
if tracer := ctx.AmbientCtx.Tracer; tracer != nil {
// We use a SpanInclusionFunc to circumvent the interceptor's work when
// tracing is disabled. Otherwise, the interceptor causes an increase in
// the number of packets (even with an empty context!). See #17177.
- interceptor := otgrpc.OpenTracingClientInterceptor(
+ unaryInterceptor = otgrpc.OpenTracingClientInterceptor(
tracer,
otgrpc.IncludingSpans(otgrpc.SpanInclusionFunc(spanInclusionFuncForClient)),
)
- dialOpts = append(dialOpts, grpc.WithUnaryInterceptor(interceptor))
}

+ unaryInterceptor = qosClientInterceptor(unaryInterceptor)
+ if unaryInterceptor != nil {
+ dialOpts = append(dialOpts, grpc.WithUnaryInterceptor(unaryInterceptor))
+ }
return dialOpts, nil
}

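For orientation before the new file, here is a minimal, hedged sketch of installing the new client interceptor on a standalone gRPC dial; inside CockroachDB the equivalent wiring happens in (*Context).GRPCDialOptions above. The helper, the plaintext credentials, and the target are assumptions for the sketch, and it would have to live in package rpc because qosClientInterceptor is unexported.

package rpc

import (
	"google.golang.org/grpc"
)

// dialWithQos is an illustrative helper, not part of this commit: it installs
// the qos client interceptor directly on a gRPC connection, with no previous
// interceptor to chain.
func dialWithQos(target string) (*grpc.ClientConn, error) {
	return grpc.Dial(
		target,
		grpc.WithInsecure(), // plaintext, for the sketch only
		grpc.WithUnaryInterceptor(qosClientInterceptor(nil /* prevUnaryInterceptor */)),
	)
}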
76 changes: 76 additions & 0 deletions pkg/rpc/qos.go
@@ -0,0 +1,76 @@
// Copyright 2019 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package rpc

import (
"context"
"time"

"github.com/cockroachdb/cockroach/pkg/qos"
"github.com/cockroachdb/cockroach/pkg/util/log"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)

func qosClientInterceptor(
prevUnaryInterceptor grpc.UnaryClientInterceptor,
) grpc.UnaryClientInterceptor {
return func(
goCtx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
invoker grpc.UnaryInvoker, opts ...grpc.CallOption,
) error {
// Add a qos level header if the goCtx contains a qos level.
if l, haveLevel := qos.LevelFromContext(goCtx); haveLevel {
goCtx = metadata.AppendToOutgoingContext(goCtx, clientQosLevelKey, l.EncodeString())
}
// Chain the previous interceptor if there is one.
if prevUnaryInterceptor != nil {
return prevUnaryInterceptor(goCtx, method, req, reply, cc, invoker, opts...)
}
return invoker(goCtx, method, req, reply, cc, opts...)
}
}

func qosServerInterceptor(
prevUnaryInterceptor grpc.UnaryServerInterceptor,
) grpc.UnaryServerInterceptor {
warnTooManyEvery := log.Every(time.Second)
errMalformedQosLevelEvery := log.Every(time.Second)
return func(
goCtx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler,
) (interface{}, error) {
if md, ok := metadata.FromIncomingContext(goCtx); ok {
if v := md.Get(clientQosLevelKey); len(v) > 0 {
// We don't expect more than one item; gRPC does not copy metadata
// from one incoming RPC to an outgoing RPC, so there should be a
// single qos level in the context put there by the interceptor on the
// client before calling this RPC. Nevertheless, having two is only
// logged and is not treated as an error.
if len(v) > 1 && warnTooManyEvery.ShouldLog() {
log.Warningf(goCtx, "unexpected multiple qos levels in client metadata: %s", v)
}
// If a qos level header exists but is malformed, it is ignored and a
// message is logged with the corresponding error.
// TODO(ajwerner): consider if this behavior should be less lenient for
// malformed headers.
if l, err := qos.DecodeString(v[0]); err == nil {
goCtx = qos.ContextWithLevel(goCtx, l)
} else if errMalformedQosLevelEvery.ShouldLog() {
log.Errorf(goCtx, "malformed qos level %s header: %v", clientQosLevelKey, err)
}
}
}
if prevUnaryInterceptor != nil {
return prevUnaryInterceptor(goCtx, req, info, handler)
}
return handler(goCtx, req)
}
}
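To show the server-side effect of the interceptor above, here is a minimal sketch of a handler observing the propagated level; the handler itself and the idea of using the level for prioritization are hypothetical, only qos.LevelFromContext comes from this change.

package example // illustrative sketch only

import (
	"context"

	"github.com/cockroachdb/cockroach/pkg/qos"
	"github.com/cockroachdb/cockroach/pkg/roachpb"
)

// handleBatch is a hypothetical handler. When the incoming RPC carried a
// "c_qos" header, qosServerInterceptor has already decoded it and put the
// level back on the context, so qos.LevelFromContext reports ok == true.
func handleBatch(ctx context.Context, ba *roachpb.BatchRequest) (*roachpb.BatchResponse, error) {
	if l, ok := qos.LevelFromContext(ctx); ok {
		// Hypothetical use: the level could drive prioritization or throttling.
		_ = l
	}
	return &roachpb.BatchResponse{}, nil
}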
196 changes: 196 additions & 0 deletions pkg/rpc/qos_test.go
@@ -0,0 +1,196 @@
// Copyright 2019 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package rpc

import (
"context"
"strings"
"testing"
"time"

"github.com/cockroachdb/cockroach/pkg/qos"
"github.com/cockroachdb/cockroach/pkg/roachpb"
"github.com/cockroachdb/cockroach/pkg/testutils"
"github.com/cockroachdb/cockroach/pkg/util"
"github.com/cockroachdb/cockroach/pkg/util/hlc"
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/netutil"
"github.com/cockroachdb/cockroach/pkg/util/stop"
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
"github.com/cockroachdb/cockroach/pkg/util/uuid"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)

type internalServerFunc func(
context.Context, *roachpb.BatchRequest,
) (*roachpb.BatchResponse, error)

func (f internalServerFunc) Batch(
ctx context.Context, ba *roachpb.BatchRequest,
) (*roachpb.BatchResponse, error) {
return f(ctx, ba)
}

func (f internalServerFunc) RangeFeed(
_ *roachpb.RangeFeedRequest, _ roachpb.Internal_RangeFeedServer,
) error {
panic("unimplemented")
}

// TestQosMiddleware tests that qos levels get properly transmitted.
func TestQosMiddleware(t *testing.T) {
defer leaktest.AfterTest(t)()

// Can't be zero because that'd be an empty offset.
clock := hlc.NewClock(timeutil.Unix(0, 1).UnixNano, time.Nanosecond)
stopper := stop.NewStopper()
defer stopper.Stop(context.TODO())

// Shared cluster ID by all RPC peers (this ensures that the peers
// don't talk to servers from unrelated tests by accident).
clusterID := uuid.MakeV4()

serverCtx := newTestContext(clusterID, clock, stopper)
const serverNodeID = 1
serverCtx.NodeID.Set(context.TODO(), serverNodeID)
s := newTestServer(t, serverCtx,
grpc.UnaryInterceptor(qosServerInterceptor(nil /* prevUnaryInterceptor */)))

heartbeat := &ManualHeartbeatService{
ready: make(chan error),
stopper: stopper,
clock: clock,
remoteClockMonitor: serverCtx.RemoteClocks,
version: serverCtx.version,
nodeID: &serverCtx.NodeID,
}
RegisterHeartbeatServer(s, heartbeat)
type qosLevel struct {
qos.Level
ok bool
}
qosLevelChan := make(chan qosLevel, 1)
batchFunc := internalServerFunc(func(
ctx context.Context, ba *roachpb.BatchRequest,
) (*roachpb.BatchResponse, error) {
l, ok := qos.LevelFromContext(ctx)
qosLevelChan <- qosLevel{l, ok}
return &roachpb.BatchResponse{}, nil
})
roachpb.RegisterInternalServer(s, batchFunc)

ln, err := netutil.ListenAndServeGRPC(serverCtx.Stopper, s, util.TestAddr)
if err != nil {
t.Fatal(err)
}
remoteAddr := ln.Addr().String()

clientCtx := newTestContext(clusterID, clock, stopper)
// Make the interval shorter to speed up the test.
clientCtx.heartbeatInterval = 1 * time.Millisecond
go func() { heartbeat.ready <- nil }()
conn, err := clientCtx.GRPCDialNode(remoteAddr, serverNodeID).Connect(context.Background())
if err != nil {
t.Fatal(err)
}

// Wait for the connection & successful heartbeat.
testutils.SucceedsSoon(t, func() error {
err := clientCtx.TestingConnHealth(remoteAddr, serverNodeID)
if err != nil && err != ErrNotHeartbeated {
t.Fatal(err)
}
return err
})
clientConn := roachpb.NewInternalClient(conn)

t.Run("without", func(t *testing.T) {
ctx := context.Background()
if _, err = clientConn.Batch(ctx, &roachpb.BatchRequest{}); err != nil {
t.Fatal(err)
}
got := <-qosLevelChan
if got.ok {
t.Fatalf("received context should not have contained a qos level")
}
})
t.Run("with", func(t *testing.T) {
l := qos.Level{Class: qos.ClassLow, Shard: 12}
ctx := qos.ContextWithLevel(context.Background(), l)
if _, err = clientConn.Batch(ctx, &roachpb.BatchRequest{}); err != nil {
t.Fatal(err)
}
got := <-qosLevelChan
if !got.ok {
t.Fatalf("expected context to have contained a qos level")
} else if got.Level != l {
t.Fatalf("expected context to have qos level %v, got %v", l, got.Level)
}
})
t.Run("malformed header", func(t *testing.T) {
ctxWithMalformedHeader := metadata.AppendToOutgoingContext(context.Background(),
clientQosLevelKey, "foo")
var entry *log.Entry
log.Intercept(ctxWithMalformedHeader, func(le log.Entry) {
if entry == nil && le.Severity == log.Severity_ERROR {
entry = &le
}
})
if _, err = clientConn.Batch(ctxWithMalformedHeader, &roachpb.BatchRequest{}); err != nil {
t.Fatal(err)
}
log.Intercept(ctxWithMalformedHeader, nil)
got := <-qosLevelChan
if got.ok {
t.Fatalf("expected context to not have contained a qos level")
}
const msg = "malformed qos level c_qos header: "
if entry == nil {
t.Fatalf("expected a log entry to be captured")
} else if !strings.Contains(entry.Message, msg) {
t.Fatalf("Found log entry %q, expected a log entrying containing %q", entry.Message, msg)
}
})
t.Run("extra headers", func(t *testing.T) {
l1 := qos.Level{Class: qos.ClassLow, Shard: 1}
l2 := qos.Level{Class: qos.ClassHigh, Shard: 2}
ctxWithDuplicateInHeader := metadata.AppendToOutgoingContext(context.Background(),
clientQosLevelKey, l1.EncodeString())
ctxWithDuplicateInHeader = metadata.AppendToOutgoingContext(ctxWithDuplicateInHeader,
clientQosLevelKey, l2.EncodeString())
var entry *log.Entry
log.Intercept(ctxWithDuplicateInHeader, func(le log.Entry) {
if entry == nil && le.Severity == log.Severity_WARNING {
entry = &le
}
})
if _, err = clientConn.Batch(ctxWithDuplicateInHeader, &roachpb.BatchRequest{}); err != nil {
t.Fatal(err)
}
log.Intercept(ctxWithDuplicateInHeader, nil)
const msg = "unexpected multiple qos levels"
if entry == nil {
t.Fatalf("expected a log entry to be captured")
} else if !strings.Contains(entry.Message, msg) {
t.Fatalf("Found log entry %q, expected a log entrying containing %q", entry.Message, msg)
}
// Check that the used level corresponds to the first value, which in this
// case is l1.
got := <-qosLevelChan
if !got.ok {
t.Fatalf("expected context to have contained a qos level")
} else if got.Level != l1 {
t.Fatalf("expected context to have qos level %v, got %v", l1, got.Level)
}
})
}
