Skip to content

Commit

Permalink
[rpc_md_in_address] balancer: set RPC metadata in address attributes,…
Browse files Browse the repository at this point in the history
… instead of Metadata field

This metadata will be sent with all RPCs on the created SubConn
  • Loading branch information
menghanl committed Nov 17, 2020
1 parent 707e298 commit 83fe271
Show file tree
Hide file tree
Showing 7 changed files with 250 additions and 17 deletions.
8 changes: 3 additions & 5 deletions balancer/grpclb/grpclb_remote_balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal/backoff"
"google.golang.org/grpc/internal/channelz"
imetadata "google.golang.org/grpc/internal/metadata"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
Expand Down Expand Up @@ -76,10 +77,7 @@ func (lb *lbBalancer) processServerList(l *lbpb.ServerList) {
// net.SplitHostPort() will return too many colons error.
ipStr = fmt.Sprintf("[%s]", ipStr)
}
addr := resolver.Address{
Addr: fmt.Sprintf("%s:%d", ipStr, s.Port),
Metadata: &md,
}
addr := imetadata.Set(resolver.Address{Addr: fmt.Sprintf("%s:%d", ipStr, s.Port)}, md)
if logger.V(2) {
logger.Infof("lbBalancer: server list entry[%d]: ipStr:|%s|, port:|%d|, load balancer token:|%v|",
i, ipStr, s.Port, s.LoadBalanceToken)
Expand Down Expand Up @@ -164,7 +162,7 @@ func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address, fallback
// Create new SubConns.
for _, addr := range backendAddrs {
addrWithoutMD := addr
addrWithoutMD.Metadata = nil
addrWithoutMD.Attributes = nil
addrsSet[addrWithoutMD] = struct{}{}
lb.backendAddrsWithoutMetadata = append(lb.backendAddrsWithoutMetadata, addrWithoutMD)

Expand Down
2 changes: 1 addition & 1 deletion balancer/grpclb/grpclb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.E
if !ok {
return nil, status.Error(codes.Internal, "failed to receive metadata")
}
if !s.fallback && (md == nil || md["lb-token"][0] != lbToken) {
if !s.fallback && (md == nil || len(md["lb-token"]) == 0 || md["lb-token"][0] != lbToken) {
return nil, status.Errorf(codes.Internal, "received unexpected metadata: %v", md)
}
grpc.SetTrailer(ctx, metadata.Pairs(testmdkey, s.addr))
Expand Down
2 changes: 1 addition & 1 deletion balancer/grpclb/grpclb_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func (ccc *lbCacheClientConn) NewSubConn(addrs []resolver.Address, opts balancer
return nil, fmt.Errorf("grpclb calling NewSubConn with addrs of length %v", len(addrs))
}
addrWithoutMD := addrs[0]
addrWithoutMD.Metadata = nil
addrWithoutMD.Attributes = nil

ccc.mu.Lock()
defer ccc.mu.Unlock()
Expand Down
55 changes: 55 additions & 0 deletions internal/metadata/metadata.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package metadata

import (
"google.golang.org/grpc/attributes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
)

type mdKeyType string

const mdKey = mdKeyType("grpc.internal.address.metadata")

// Get returns the metadata of addr.
func Get(addr resolver.Address) metadata.MD {
attrs := addr.Attributes
if attrs == nil {
return nil
}
md, ok := attrs.Value(mdKey).(metadata.MD)
if !ok {
return nil
}
return md
}

// Set sets (overrides) the metadata in addr.
//
// When a SubConn is created with this address, the RPCs sent on it will all
// have this metadata.
func Set(addr resolver.Address, md metadata.MD) resolver.Address {
if addr.Attributes == nil {
addr.Attributes = attributes.New(mdKey, md)
return addr
}
addr.Attributes = addr.Attributes.WithValues(mdKey, md)
return addr
}
86 changes: 86 additions & 0 deletions internal/metadata/metadata_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package metadata

import (
"testing"

"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/attributes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
)

func TestGet(t *testing.T) {
tests := []struct {
name string
addr resolver.Address
want metadata.MD
}{
{
name: "not set",
addr: resolver.Address{},
want: nil,
},
{
name: "not set",
addr: resolver.Address{
Attributes: attributes.New(mdKey, metadata.Pairs("k", "v")),
},
want: metadata.Pairs("k", "v"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := Get(tt.addr); !cmp.Equal(got, tt.want) {
t.Errorf("Get() = %v, want %v", got, tt.want)
}
})
}
}

func TestSet(t *testing.T) {
tests := []struct {
name string
addr resolver.Address
md metadata.MD
}{
{
name: "unset before",
addr: resolver.Address{},
md: metadata.Pairs("k", "v"),
},
{
name: "set before",
addr: resolver.Address{
Attributes: attributes.New(mdKey, metadata.Pairs("bef", "ore")),
},
md: metadata.Pairs("k", "v"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
newAddr := Set(tt.addr, tt.md)
newMD := Get(newAddr)
if !cmp.Equal(newMD, tt.md) {
t.Errorf("md after Set() = %v, want %v", newMD, tt.md)
}
})
}
}
24 changes: 14 additions & 10 deletions internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/internal/grpcutil"
imetadata "google.golang.org/grpc/internal/metadata"
"google.golang.org/grpc/internal/transport/networktype"

"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -60,7 +61,7 @@ type http2Client struct {
cancel context.CancelFunc
ctxDone <-chan struct{} // Cache the ctx.Done() chan.
userAgent string
md interface{}
md metadata.MD
conn net.Conn // underlying communication channel
loopy *loopyWriter
remoteAddr net.Addr
Expand Down Expand Up @@ -268,7 +269,6 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
ctxDone: ctx.Done(), // Cache Done chan.
cancel: cancel,
userAgent: opts.UserAgent,
md: addr.Metadata,
conn: conn,
remoteAddr: conn.RemoteAddr(),
localAddr: conn.LocalAddr(),
Expand Down Expand Up @@ -296,6 +296,12 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
keepaliveEnabled: keepaliveEnabled,
bufferPool: newBufferPool(),
}

if md, ok := addr.Metadata.(*metadata.MD); ok {
t.md = *md
} else if md := imetadata.Get(addr); md != nil {
t.md = md
}
t.controlBuf = newControlBuffer(t.ctxDone)
if opts.InitialWindowSize >= defaultWindowSize {
t.initialWindowSize = opts.InitialWindowSize
Expand Down Expand Up @@ -512,14 +518,12 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
}
}
}
if md, ok := t.md.(*metadata.MD); ok {
for k, vv := range *md {
if isReservedHeader(k) {
continue
}
for _, v := range vv {
headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
for k, vv := range t.md {
if isReservedHeader(k) {
continue
}
for _, v := range vv {
headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
}
return headerFields, nil
Expand Down
90 changes: 90 additions & 0 deletions test/balancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"google.golang.org/grpc/internal/balancerload"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpcutil"
imetadata "google.golang.org/grpc/internal/metadata"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
Expand Down Expand Up @@ -543,6 +544,95 @@ func (s) TestAddressAttributesInNewSubConn(t *testing.T) {
}
}

// TestMetadataInAddressAttributes verifies that the metadata added to
// address.Attributes will be sent with the RPCs.
func (s) TestMetadataInAddressAttributes(t *testing.T) {
const (
testMDKey = "test-md"
testMDValue = "test-md-value"
mdBalancerName = "metadata-balancer"
)

// Register a stub balancer which adds metadata to the first address that it
// receives and then calls NewSubConn on it.
bf := stub.BalancerFuncs{
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
addrs := ccs.ResolverState.Addresses
if len(addrs) == 0 {
return nil
}

// Only use the first address.
sc, err := bd.ClientConn.NewSubConn([]resolver.Address{
imetadata.Set(addrs[0], metadata.Pairs(testMDKey, testMDValue)),
}, balancer.NewSubConnOptions{})
if err != nil {
return err
}
sc.Connect()
return nil
},
UpdateSubConnState: func(bd *stub.BalancerData, sc balancer.SubConn, state balancer.SubConnState) {
bd.ClientConn.UpdateState(balancer.State{ConnectivityState: state.ConnectivityState, Picker: &aiPicker{result: balancer.PickResult{SubConn: sc}, err: state.ConnectionError}})
},
}
stub.Register(mdBalancerName, bf)
t.Logf("Registered balancer %s...", mdBalancerName)

r := manual.NewBuilderWithScheme("whatever")
t.Logf("Registered manual resolver with scheme %s...", r.Scheme())

testMDChan := make(chan []string, 1)
ss := &stubServer{
emptyCall: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
md, ok := metadata.FromIncomingContext(ctx)
if ok {
select {
case <-testMDChan:
default:
}
testMDChan <- md[testMDKey]
}
return &testpb.Empty{}, nil
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

dopts := []grpc.DialOption{
grpc.WithInsecure(),
grpc.WithResolvers(r),
grpc.WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, mdBalancerName)),
}
cc, err := grpc.Dial(r.Scheme()+":///test.server", dopts...)
if err != nil {
t.Fatal(err)
}
defer cc.Close()
tc := testpb.NewTestServiceClient(cc)
t.Log("Created a ClientConn...")

state := resolver.State{Addresses: []resolver.Address{{Addr: ss.address}}}
r.UpdateState(state)
t.Logf("Pushing resolver state update: %v through the manual resolver", state)

// The RPC should succeed with the expected md.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
}
t.Log("Made an RPC which succeeded...")

// The server should receive the test metadata.
md1 := <-testMDChan
if len(md1) == 0 || md1[0] != testMDValue {
t.Fatalf("got md: %v, want %v", md1, []string{testMDValue})
}
}

// TestServersSwap creates two servers and verifies the client switches between
// them when the name resolver reports the first and then the second.
func (s) TestServersSwap(t *testing.T) {
Expand Down

0 comments on commit 83fe271

Please sign in to comment.