Skip to content

Commit c84a500

Browse files
credentials/alts: defer ALTS stream creation until handshake time (#6077)
1 parent 6f44ae8 commit c84a500

File tree

2 files changed

+97
-17
lines changed

2 files changed

+97
-17
lines changed

credentials/alts/internal/handshaker/handshaker.go

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,16 @@ func DefaultServerHandshakerOptions() *ServerHandshakerOptions {
138138
// and server options (server options struct does not exist now. When
139139
// caller can provide endpoints, it should be created.
140140

141-
// altsHandshaker is used to complete a ALTS handshaking between client and
141+
// altsHandshaker is used to complete an ALTS handshake between client and
142142
// server. This handshaker talks to the ALTS handshaker service in the metadata
143143
// server.
144144
type altsHandshaker struct {
145145
// RPC stream used to access the ALTS Handshaker service.
146146
stream altsgrpc.HandshakerService_DoHandshakeClient
147147
// the connection to the peer.
148148
conn net.Conn
149+
// a virtual connection to the ALTS handshaker service.
150+
clientConn *grpc.ClientConn
149151
// client handshake options.
150152
clientOpts *ClientHandshakerOptions
151153
// server handshake options.
@@ -154,39 +156,33 @@ type altsHandshaker struct {
154156
side core.Side
155157
}
156158

157-
// NewClientHandshaker creates a ALTS handshaker for GCP which contains an RPC
158-
// stub created using the passed conn and used to talk to the ALTS Handshaker
159+
// NewClientHandshaker creates a core.Handshaker that performs a client-side
160+
// ALTS handshake by acting as a proxy between the peer and the ALTS handshaker
159161
// service in the metadata server.
160162
func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ClientHandshakerOptions) (core.Handshaker, error) {
161-
stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx)
162-
if err != nil {
163-
return nil, err
164-
}
165163
return &altsHandshaker{
166-
stream: stream,
164+
stream: nil,
167165
conn: c,
166+
clientConn: conn,
168167
clientOpts: opts,
169168
side: core.ClientSide,
170169
}, nil
171170
}
172171

173-
// NewServerHandshaker creates a ALTS handshaker for GCP which contains an RPC
174-
// stub created using the passed conn and used to talk to the ALTS Handshaker
172+
// NewServerHandshaker creates a core.Handshaker that performs a server-side
173+
// ALTS handshake by acting as a proxy between the peer and the ALTS handshaker
175174
// service in the metadata server.
176175
func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ServerHandshakerOptions) (core.Handshaker, error) {
177-
stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx)
178-
if err != nil {
179-
return nil, err
180-
}
181176
return &altsHandshaker{
182-
stream: stream,
177+
stream: nil,
183178
conn: c,
179+
clientConn: conn,
184180
serverOpts: opts,
185181
side: core.ServerSide,
186182
}, nil
187183
}
188184

189-
// ClientHandshake starts and completes a client ALTS handshaking for GCP. Once
185+
// ClientHandshake starts and completes a client ALTS handshake for GCP. Once
190186
// done, ClientHandshake returns a secure connection.
191187
func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
192188
if !acquire() {
@@ -198,6 +194,16 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
198194
return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client handshaker")
199195
}
200196

197+
// TODO(matthewstevenson88): Change unit tests to use public APIs so
198+
// that h.stream can unconditionally be set based on h.clientConn.
199+
if h.stream == nil {
200+
stream, err := altsgrpc.NewHandshakerServiceClient(h.clientConn).DoHandshake(ctx)
201+
if err != nil {
202+
return nil, nil, fmt.Errorf("failed to establish stream to ALTS handshaker service: %v", err)
203+
}
204+
h.stream = stream
205+
}
206+
201207
// Create target identities from service account list.
202208
targetIdentities := make([]*altspb.Identity, 0, len(h.clientOpts.TargetServiceAccounts))
203209
for _, account := range h.clientOpts.TargetServiceAccounts {
@@ -229,7 +235,7 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
229235
return conn, authInfo, nil
230236
}
231237

232-
// ServerHandshake starts and completes a server ALTS handshaking for GCP. Once
238+
// ServerHandshake starts and completes a server ALTS handshake for GCP. Once
233239
// done, ServerHandshake returns a secure connection.
234240
func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
235241
if !acquire() {
@@ -241,6 +247,16 @@ func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credent
241247
return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server handshaker")
242248
}
243249

250+
// TODO(matthewstevenson88): Change unit tests to use public APIs so
251+
// that h.stream can unconditionally be set based on h.clientConn.
252+
if h.stream == nil {
253+
stream, err := altsgrpc.NewHandshakerServiceClient(h.clientConn).DoHandshake(ctx)
254+
if err != nil {
255+
return nil, nil, fmt.Errorf("failed to establish stream to ALTS handshaker service: %v", err)
256+
}
257+
h.stream = stream
258+
}
259+
244260
p := make([]byte, frameLimit)
245261
n, err := h.conn.Read(p)
246262
if err != nil {

credentials/alts/internal/handshaker/handshaker_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ import (
2525
"testing"
2626
"time"
2727

28+
"github.com/google/go-cmp/cmp"
29+
"github.com/google/go-cmp/cmp/cmpopts"
2830
grpc "google.golang.org/grpc"
2931
core "google.golang.org/grpc/credentials/alts/internal"
3032
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
@@ -283,3 +285,65 @@ func (s) TestPeerNotResponding(t *testing.T) {
283285
t.Errorf("ClientHandshake() = %v, want %v", got, want)
284286
}
285287
}
288+
289+
func (s) TestNewClientHandshaker(t *testing.T) {
290+
conn := testutil.NewTestConn(nil, nil)
291+
clientConn := &grpc.ClientConn{}
292+
opts := &ClientHandshakerOptions{}
293+
hs, err := NewClientHandshaker(context.Background(), clientConn, conn, opts)
294+
if err != nil {
295+
t.Errorf("NewClientHandshaker returned unexpected error: %v", err)
296+
}
297+
expectedHs := &altsHandshaker{
298+
stream: nil,
299+
conn: conn,
300+
clientConn: clientConn,
301+
clientOpts: opts,
302+
serverOpts: nil,
303+
side: core.ClientSide,
304+
}
305+
cmpOpts := []cmp.Option{
306+
cmp.AllowUnexported(altsHandshaker{}),
307+
cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"),
308+
}
309+
if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) {
310+
t.Errorf("NewClientHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want)
311+
}
312+
if hs.(*altsHandshaker).stream != nil {
313+
t.Errorf("NewClientHandshaker() returned handshaker with non-nil stream")
314+
}
315+
if hs.(*altsHandshaker).clientConn != clientConn {
316+
t.Errorf("NewClientHandshaker() returned handshaker with unexpected clientConn")
317+
}
318+
}
319+
320+
func (s) TestNewServerHandshaker(t *testing.T) {
321+
conn := testutil.NewTestConn(nil, nil)
322+
clientConn := &grpc.ClientConn{}
323+
opts := &ServerHandshakerOptions{}
324+
hs, err := NewServerHandshaker(context.Background(), clientConn, conn, opts)
325+
if err != nil {
326+
t.Errorf("NewServerHandshaker returned unexpected error: %v", err)
327+
}
328+
expectedHs := &altsHandshaker{
329+
stream: nil,
330+
conn: conn,
331+
clientConn: clientConn,
332+
clientOpts: nil,
333+
serverOpts: opts,
334+
side: core.ServerSide,
335+
}
336+
cmpOpts := []cmp.Option{
337+
cmp.AllowUnexported(altsHandshaker{}),
338+
cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"),
339+
}
340+
if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) {
341+
t.Errorf("NewServerHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want)
342+
}
343+
if hs.(*altsHandshaker).stream != nil {
344+
t.Errorf("NewServerHandshaker() returned handshaker with non-nil stream")
345+
}
346+
if hs.(*altsHandshaker).clientConn != clientConn {
347+
t.Errorf("NewServerHandshaker() returned handshaker with unexpected clientConn")
348+
}
349+
}

0 commit comments

Comments
 (0)