Skip to content

Commit

Permalink
Review comments #1.
Browse files Browse the repository at this point in the history
  • Loading branch information
easwars committed Apr 18, 2020
1 parent 63be37a commit 3e8e161
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 48 deletions.
6 changes: 0 additions & 6 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1272,12 +1272,6 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne
copts.ChannelzParentID = ac.channelzID
}

// gRPC, resolver, balancer etc. can specify arbitrary data in the
// Attributes field of resolver.Address, which is shoved into connectCtx
// that is passed to the transport layer. The transport layer passes the
// same context to the credential handshaker. This makes is possible for
// address specific arbitrary data to reach the credential handshaker.
connectCtx = credentials.WithAddressInfo(connectCtx, credentials.AddressInfo{Attr: addr.Attributes})
newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, onPrefaceReceipt, onGoAway, onClose)
if err != nil {
// newTr is either nil, or closed.
Expand Down
31 changes: 16 additions & 15 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,30 +194,31 @@ func RequestInfoFromContext(ctx context.Context) (ri RequestInfo, ok bool) {
return
}

// AddressInfo contains address related data attached to the context passed to
// ClientHandshake. This makes it possible to pass arbitrary data to the
// handshaker from gRPC, resolver, balancer etc. Individual credential
// implementations control the actual format of the data that they are willing
// to receive.
// ClientHandshakeInfo holds data to be passed to ClientHandshake. This makes
// it possible to pass arbitrary data to the handshaker from gRPC, resolver,
// balancer etc. Individual credential implementations control the actual
// format of the data that they are willing to receive.
//
// This API is experimental.
type AddressInfo struct {
type ClientHandshakeInfo struct {
// Attr is a generic key/value store.
Attr *attributes.Attributes
}

// addressInfoKey is a struct used as the key to store AddressInfo in a context.
type addressInfoKey struct{}
// clientHandshakeInfoKey is a struct used as the key to store
// ClientHandshakeInfo in a context.
type clientHandshakeInfoKey struct{}

// WithAddressInfo returns a copy of parent with ai stored as a value.
func WithAddressInfo(parent context.Context, ai AddressInfo) context.Context {
return context.WithValue(parent, addressInfoKey{}, ai)
// WithClientHandshakeInfo returns a copy of parent with chi stored as a value.
func WithClientHandshakeInfo(parent context.Context, chi ClientHandshakeInfo) context.Context {
return context.WithValue(parent, clientHandshakeInfoKey{}, chi)
}

// AddressInfoFromContext returns the AddressInfo stored in ctx.
func AddressInfoFromContext(ctx context.Context) AddressInfo {
ai, _ := ctx.Value(addressInfoKey{}).(AddressInfo)
return ai
// ClientHandshakeInfoFromContext returns the ClientHandshakeInfo struct stored
// in ctx.
func ClientHandshakeInfoFromContext(ctx context.Context) ClientHandshakeInfo {
chi, _ := ctx.Value(clientHandshakeInfoKey{}).(ClientHandshakeInfo)
return chi
}

// CheckSecurityLevel checks if a connection's security level is greater than or equal to the specified one.
Expand Down
5 changes: 5 additions & 0 deletions internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,11 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
}
}
if transportCreds != nil {
// gRPC, resolver, balancer etc. can specify arbitrary data in the
// Attributes field of resolver.Address, which is shoved into connectCtx
// and passed to the credential handshaker. This makes it possible for
// address specific arbitrary data to reach the credential handshaker.
connectCtx = credentials.WithClientHandshakeInfo(connectCtx, credentials.ClientHandshakeInfo{Attr: addr.Attributes})
scheme = "https"
conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, conn)
if err != nil {
Expand Down
54 changes: 54 additions & 0 deletions internal/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,12 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/attributes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/internal/testutils"
Expand Down Expand Up @@ -1816,3 +1819,54 @@ func (s) TestHeaderTblSize(t *testing.T) {
t.Fatalf("expected len(limits) = 2 within 10s, got != 2")
}
}

// attrTransportCreds is a transport credential implementation which stores
// Attributes from the ClientHandshakeInfo struct passed in the context locally
// for the test to inspect.
type attrTransportCreds struct {
credentials.TransportCredentials
attr *attributes.Attributes
}

func (ac *attrTransportCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
ai := credentials.ClientHandshakeInfoFromContext(ctx)
ac.attr = ai.Attr
return rawConn, nil, nil
}
func (ac *attrTransportCreds) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}
func (ac *attrTransportCreds) Clone() credentials.TransportCredentials {
return nil
}

// TestClientHandshakeInfo adds attributes to the resolver.Address passes to
// NewClientTransport and verifies that these attributes are received by the
// transport credential handshaker.
func (s) TestClientHandshakeInfo(t *testing.T) {
server := setUpServerOnly(t, 0, &ServerConfig{}, pingpong)
defer server.stop()

const (
testAttrKey = "foo"
testAttrVal = "bar"
)
addr := resolver.Address{
Addr: "localhost:" + server.port,
Attributes: attributes.New(testAttrKey, testAttrVal),
}
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
defer cancel()
creds := &attrTransportCreds{}

tr, err := NewClientTransport(ctx, context.Background(), addr, ConnectOptions{TransportCredentials: creds}, func() {}, func(GoAwayReason) {}, func() {})
if err != nil {
t.Fatalf("NewClientTransport(): %v", err)
}
defer tr.Close()

wantAttr := attributes.New(testAttrKey, testAttrVal)
if gotAttr := creds.attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) {
t.Fatalf("received attributes %v in creds, want %v", gotAttr, wantAttr)
}
}
55 changes: 28 additions & 27 deletions test/balancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,27 +353,27 @@ func (s) TestNonGRPCLBBalancerGetsNoGRPCLBAddress(t *testing.T) {
}
}

const aiBalancerName = "addrInfo-attribute-balancer"
const attrBalancerName = "attribute-balancer"

// aiBalancerBuilder builds a balancer and passes the attribute key and value
// attrBalancerBuilder builds a balancer and passes the attribute key and value
// with which it was configured at creation time by the test.
type aiBalancerBuilder struct {
type attrBalancerBuilder struct {
attrKey string
attrVal string
}

func (bb *aiBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
return &aiBalancer{cc: cc, attrKey: bb.attrKey, attrVal: bb.attrVal}
func (bb *attrBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
return &attrBalancer{cc: cc, attrKey: bb.attrKey, attrVal: bb.attrVal}
}

func (bb *aiBalancerBuilder) Name() string {
return aiBalancerName
func (bb *attrBalancerBuilder) Name() string {
return attrBalancerName
}

// aiBalancer receives an attribute key and value which it adds to the address
// that it calls NewSubConn on. This key/value pair reaches the credential
// handshaker and the test verifies the same.
type aiBalancer struct {
// attrBalancer receives an attribute key and value which it adds to the
// address that it calls NewSubConn on. This key/value pair reaches the
// credential handshaker and the test verifies the same.
type attrBalancer struct {
balancer.Balancer
cc balancer.ClientConn
attrKey string
Expand All @@ -382,7 +382,7 @@ type aiBalancer struct {

// UpdateClientConnState adds an attribute with the configured key/value to the
// addresses received and invokes NewSubConn.
func (b *aiBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
func (b *attrBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
addrs := ccs.ResolverState.Addresses
if len(addrs) == 0 {
return nil
Expand All @@ -399,12 +399,12 @@ func (b *aiBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
return nil
}

func (b *aiBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
func (b *attrBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
b.cc.UpdateState(balancer.State{ConnectivityState: state.ConnectivityState, Picker: &aiPicker{result: balancer.PickResult{SubConn: sc}, err: state.ConnectionError}})
}

func (b *aiBalancer) ResolverError(error) {}
func (b *aiBalancer) Close() {}
func (b *attrBalancer) ResolverError(error) {}
func (b *attrBalancer) Close() {}

type aiPicker struct {
result balancer.PickResult
Expand All @@ -415,22 +415,23 @@ func (aip *aiPicker) Pick(_ balancer.PickInfo) (balancer.PickResult, error) {
return aip.result, aip.err
}

// addrInfoTransportCreds is a transport credential implementation which stores
// the Attributes struct passed in the context locally for the test to inspect.
type addrInfoTransportCreds struct {
// attrTransportCreds is a transport credential implementation which stores
// Attributes from the ClientHandshakeInfo struct passed in the context locally
// for the test to inspect.
type attrTransportCreds struct {
credentials.TransportCredentials
attr *attributes.Attributes
}

func (ac *addrInfoTransportCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
ai := credentials.AddressInfoFromContext(ctx)
func (ac *attrTransportCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
ai := credentials.ClientHandshakeInfoFromContext(ctx)
ac.attr = ai.Attr
return rawConn, nil, nil
}
func (ac *addrInfoTransportCreds) Info() credentials.ProtocolInfo {
func (ac *attrTransportCreds) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}
func (ac *addrInfoTransportCreds) Clone() credentials.TransportCredentials {
func (ac *attrTransportCreds) Clone() credentials.TransportCredentials {
return nil
}

Expand All @@ -444,9 +445,9 @@ func (s) TestAddressAttributesInNewSubConn(t *testing.T) {
testAttrVal = "bar"
)

balancer.Register(&aiBalancerBuilder{attrKey: testAttrKey, attrVal: testAttrVal})
defer internal.BalancerUnregister(aiBalancerName)
t.Logf("Registered balancer %s...", aiBalancerName)
balancer.Register(&attrBalancerBuilder{attrKey: testAttrKey, attrVal: testAttrVal})
defer internal.BalancerUnregister(attrBalancerName)
t.Logf("Registered balancer %s...", attrBalancerName)

r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
Expand All @@ -463,10 +464,10 @@ func (s) TestAddressAttributesInNewSubConn(t *testing.T) {
defer s.Stop()
t.Logf("Started gRPC server at %s...", lis.Addr().String())

creds := &addrInfoTransportCreds{}
creds := &attrTransportCreds{}
dopts := []grpc.DialOption{
grpc.WithTransportCredentials(creds),
grpc.WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, aiBalancerName)),
grpc.WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, attrBalancerName)),
}
cc, err := grpc.Dial(r.Scheme()+":///test.server", dopts...)
if err != nil {
Expand Down

0 comments on commit 3e8e161

Please sign in to comment.