Skip to content

Commit

Permalink
alts: Record network latency and pass it to the handshaker service. (#…
Browse files Browse the repository at this point in the history
…6851)

* alts: Record network latency and pass it to the handshaker service.

* Fix vet.sh warnings.

* Fix protoc version issue.

* Address review comments.
  • Loading branch information
matthewstevenson88 authored Dec 15, 2023
1 parent 45624f0 commit 444749d
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 10 deletions.
9 changes: 8 additions & 1 deletion credentials/alts/internal/handshaker/handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"fmt"
"io"
"net"
"time"

"golang.org/x/sync/semaphore"
grpc "google.golang.org/grpc"
Expand Down Expand Up @@ -308,8 +309,10 @@ func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*al
// the results. Handshaker service takes care of frame parsing, so we read
// whatever received from the network and send it to the handshaker service.
func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) {
var lastWriteTime time.Time
for {
if len(resp.OutFrames) > 0 {
lastWriteTime = time.Now()
if _, err := h.conn.Write(resp.OutFrames); err != nil {
return nil, nil, err
}
Expand All @@ -333,11 +336,15 @@ func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []b
// Append extra bytes from the previous interaction with the
// handshaker service with the current buffer read from conn.
p := append(extra, buf[:n]...)
// Compute the time elapsed since the last write to the peer.
timeElapsed := time.Since(lastWriteTime)
timeElapsedMs := uint32(timeElapsed.Milliseconds())
// From here on, p and extra point to the same slice.
resp, err = h.accessHandshakerService(&altspb.HandshakerReq{
ReqOneof: &altspb.HandshakerReq_Next{
Next: &altspb.NextHandshakeMessageReq{
InBytes: p,
InBytes: p,
NetworkLatencyMs: timeElapsedMs,
},
},
})
Expand Down
28 changes: 23 additions & 5 deletions credentials/alts/internal/handshaker/handshaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"testing"
"time"

Expand Down Expand Up @@ -74,6 +75,9 @@ type testRPCStream struct {
first bool
// useful for testing concurrent calls.
delay time.Duration
// The minimum expected value of the network_latency_ms field in a
// NextHandshakeMessageReq.
minExpectedNetworkLatency time.Duration
}

func (t *testRPCStream) Recv() (*altspb.HandshakerResp, error) {
Expand Down Expand Up @@ -102,6 +106,17 @@ func (t *testRPCStream) Send(req *altspb.HandshakerReq) error {
}
}
} else {
switch req := req.ReqOneof.(type) {
case *altspb.HandshakerReq_Next:
// Compare the network_latency_ms field to the minimum expected network
// latency.
if nl := time.Duration(req.Next.NetworkLatencyMs) * time.Millisecond; nl < t.minExpectedNetworkLatency {
return fmt.Errorf("networkLatency (%v) is smaller than expected min network latency (%v)", nl, t.minExpectedNetworkLatency)
}
default:
return fmt.Errorf("handshake request has unexpected type: %v", req)
}

// Add delay to test concurrent calls.
cleanup := stat.Update()
defer cleanup()
Expand Down Expand Up @@ -133,9 +148,11 @@ func (s) TestClientHandshake(t *testing.T) {
for _, testCase := range []struct {
delay time.Duration
numberOfHandshakes int
readLatency time.Duration
}{
{0 * time.Millisecond, 1},
{100 * time.Millisecond, 10 * int(envconfig.ALTSMaxConcurrentHandshakes)},
{0 * time.Millisecond, 1, time.Duration(0)},
{0 * time.Millisecond, 1, 2 * time.Millisecond},
{100 * time.Millisecond, 10 * int(envconfig.ALTSMaxConcurrentHandshakes), time.Duration(0)},
} {
errc := make(chan error)
stat.Reset()
Expand All @@ -145,16 +162,17 @@ func (s) TestClientHandshake(t *testing.T) {

for i := 0; i < testCase.numberOfHandshakes; i++ {
stream := &testRPCStream{
t: t,
isClient: true,
t: t,
isClient: true,
minExpectedNetworkLatency: testCase.readLatency,
}
// Preload the inbound frames.
f1 := testutil.MakeFrame("ServerInit")
f2 := testutil.MakeFrame("ServerFinished")
in := bytes.NewBuffer(f1)
in.Write(f2)
out := new(bytes.Buffer)
tc := testutil.NewTestConn(in, out)
tc := testutil.NewTestConnWithReadLatency(in, out, testCase.readLatency)
chs := &altsHandshaker{
stream: stream,
conn: tc,
Expand Down
22 changes: 18 additions & 4 deletions credentials/alts/internal/testutil/testutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"io"
"net"
"sync"
"time"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/alts/internal/conn"
Expand Down Expand Up @@ -67,20 +68,33 @@ func (s *Stats) Reset() {
// testConn mimics a net.Conn to the peer.
type testConn struct {
net.Conn
in *bytes.Buffer
out *bytes.Buffer
in *bytes.Buffer
out *bytes.Buffer
readLatency time.Duration
}

// NewTestConn creates a new instance of testConn object.
func NewTestConn(in *bytes.Buffer, out *bytes.Buffer) net.Conn {
return &testConn{
in: in,
out: out,
in: in,
out: out,
readLatency: time.Duration(0),
}
}

// NewTestConnWithReadLatency creates a new instance of testConn object that
// pauses for readLatency before any call to Read() returns.
func NewTestConnWithReadLatency(in *bytes.Buffer, out *bytes.Buffer, readLatency time.Duration) net.Conn {
return &testConn{
in: in,
out: out,
readLatency: readLatency,
}
}

// Read reads from the in buffer.
func (c *testConn) Read(b []byte) (n int, err error) {
time.Sleep(c.readLatency)
return c.in.Read(b)
}

Expand Down

0 comments on commit 444749d

Please sign in to comment.