diff --git a/clientconn.go b/clientconn.go
index 64a7982fad1f..53375ed9b7c5 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -1241,7 +1241,20 @@ func (ac *addrConn) transportMonitor() {
 // Block until we receive a goaway or an error occurs.
 select {
 case <-t.GoAway():
+ done := t.Error()
+ cleanup := t.Close
+ // Since this transport will be orphaned (won't have a transportMonitor)
+ // we need to launch a goroutine to keep track of clientConn.Close()
+ // happening since it might not be noticed by any other goroutine for a while.
+ go func() {
+ <-done
+ cleanup()
+ }()
 case <-t.Error():
+ // In case this is triggered because clientConn.Close()
+ // was called, we want to immediately close the transport
+ // since no other goroutine might notice it for a while.
+ t.Close()
 case <-cdeadline:
 ac.mu.Lock()
 // This implies that client received server preface.
diff --git a/test/end2end_test.go b/test/end2end_test.go
index 3df8bfff0a86..ffd13961223a 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -748,7 +748,6 @@ type lazyConn struct {
 func (l *lazyConn) Write(b []byte) (int, error) {
 if atomic.LoadInt32(&(l.beLazy)) == 1 {
- // The sleep duration here needs to less than the leakCheck deadline.
 time.Sleep(time.Second)
 }
 return l.Conn.Write(b)
@@ -963,7 +962,7 @@ func testServerGoAwayPendingRPC(t *testing.T, e env) {
 }
 // The existing RPC should be still good to proceed.
 if err := stream.Send(req); err != nil {
- t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, req, err)
+ t.Fatalf("%v.Send(_) = %v, want <nil>", stream, err)
 }
 if _, err := stream.Recv(); err != nil {
 t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
 }
@@ -3053,7 +3052,6 @@ func testMultipleSetHeaderStreamingRPCError(t *testing.T, e env) {
 if !reflect.DeepEqual(header, expectedHeader) {
 t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
 }
-
 if err := stream.CloseSend(); err != nil {
 t.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil)
 }
@@ -3156,44 +3154,6 @@ func testRetry(t *testing.T, e env) {
 }
 }
-func TestRPCTimeout(t *testing.T) {
- defer leakcheck.Check(t)
- for _, e := range listTestEnv() {
- testRPCTimeout(t, e)
- }
-}
-
-// TODO(zhaoq): Have a better test coverage of timeout and cancellation mechanism.
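Note on the clientconn.go hunk above: after a GOAWAY the old transport keeps serving in-flight RPCs but no longer has a transportMonitor, so nothing would react to clientConn.Close(). The patch parks a watcher goroutine on the transport's error channel. A minimal standalone sketch of that pattern (watchOrphaned, errCh, and closeFn are illustrative names, not the real API):

```go
// watchOrphaned closes an orphaned transport once its error channel
// fires; t.Error() also fires when clientConn.Close() is called.
func watchOrphaned(errCh <-chan struct{}, closeFn func() error) {
	go func() {
		<-errCh   // transport error or clientConn.Close()
		closeFn() // release the connection promptly
	}()
}
```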
-func testRPCTimeout(t *testing.T, e env) { - te := newTest(t, e) - te.startServer(&testServer{security: e.security, unaryCallSleepTime: 50 * time.Millisecond}) - defer te.tearDown() - - cc := te.clientConn() - tc := testpb.NewTestServiceClient(cc) - - const argSize = 2718 - const respSize = 314 - - payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize) - if err != nil { - t.Fatal(err) - } - - req := &testpb.SimpleRequest{ - ResponseType: testpb.PayloadType_COMPRESSABLE, - ResponseSize: respSize, - Payload: payload, - } - for i := -1; i <= 10; i++ { - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(i)*time.Millisecond) - if _, err := tc.UnaryCall(ctx, req); status.Code(err) != codes.DeadlineExceeded { - t.Fatalf("TestService/UnaryCallv(_, _) = _, %v; want , error code: %s", err, codes.DeadlineExceeded) - } - cancel() - } -} - func TestCancel(t *testing.T) { defer leakcheck.Check(t) for _, e := range listTestEnv() { @@ -3687,7 +3647,7 @@ func testClientStreaming(t *testing.T, e env, sizes []int) { Payload: payload, } if err := stream.Send(req); err != nil { - t.Fatalf("%v.Send(%v) = %v, want ", stream, req, err) + t.Fatalf("%v.Send(_) = %v, want ", stream, err) } sum += s } @@ -5078,7 +5038,7 @@ func TestTapTimeout(t *testing.T) { ss := &stubServer{ emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { <-ctx.Done() - return &testpb.Empty{}, nil + return nil, status.Errorf(codes.Canceled, ctx.Err().Error()) }, } if err := ss.Start(sopts); err != nil { @@ -6218,3 +6178,40 @@ func TestFailFastRPCErrorOnBadCertificates(t *testing.T) { } te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg) } + +func TestRPCTimeout(t *testing.T) { + defer leakcheck.Check(t) + for _, e := range listTestEnv() { + testRPCTimeout(t, e) + } +} + +func testRPCTimeout(t *testing.T, e env) { + te := newTest(t, e) + te.startServer(&testServer{security: e.security, unaryCallSleepTime: 500 * time.Millisecond}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + + const argSize = 2718 + const respSize = 314 + + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize) + if err != nil { + t.Fatal(err) + } + + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE, + ResponseSize: respSize, + Payload: payload, + } + for i := -1; i <= 10; i++ { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(i)*time.Millisecond) + if _, err := tc.UnaryCall(ctx, req); status.Code(err) != codes.DeadlineExceeded { + t.Fatalf("TestService/UnaryCallv(_, _) = _, %v; want , error code: %s", err, codes.DeadlineExceeded) + } + cancel() + } +} diff --git a/transport/controlbuf.go b/transport/controlbuf.go new file mode 100644 index 000000000000..ba35100c171a --- /dev/null +++ b/transport/controlbuf.go @@ -0,0 +1,758 @@ +/* + * + * Copyright 2014 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package transport + +import ( + "bytes" + "fmt" + "runtime" + "sync" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" +) + +type itemNode struct { + it interface{} + next *itemNode +} + +type itemList struct { + head *itemNode + tail *itemNode +} + +func (il *itemList) enqueue(i interface{}) { + n := &itemNode{it: i} + if il.tail == nil { + il.head, il.tail = n, n + return + } + il.tail.next = n + il.tail = n +} + +// peek returns the first item in the list without removing it from the +// list. +func (il *itemList) peek() interface{} { + return il.head.it +} + +func (il *itemList) dequeue() interface{} { + if il.head == nil { + return nil + } + i := il.head.it + il.head = il.head.next + if il.head == nil { + il.tail = nil + } + return i +} + +func (il *itemList) dequeueAll() *itemNode { + h := il.head + il.head, il.tail = nil, nil + return h +} + +func (il *itemList) isEmpty() bool { + return il.head == nil +} + +// The following defines various control items which could flow through +// the control buffer of transport. They represent different aspects of +// control tasks, e.g., flow control, settings, streaming resetting, etc. + +type headerFrame struct { + streamID uint32 + hf []hpack.HeaderField + endStream bool // Valid on server side. + initStream func(uint32) (bool, error) // Used only on the client side. + onWrite func() + wq *writeQuota // write quota for the stream created. + cleanup *cleanupStream // Valid on the server side. + onOrphaned func(error) // Valid on client-side +} + +type cleanupStream struct { + streamID uint32 + idPtr *uint32 + rst bool + rstCode http2.ErrCode + onWrite func() +} + +type dataFrame struct { + streamID uint32 + endStream bool + h []byte + d []byte + // onEachWrite is called every time + // a part of d is written out. + onEachWrite func() +} + +type incomingWindowUpdate struct { + streamID uint32 + increment uint32 +} + +type outgoingWindowUpdate struct { + streamID uint32 + increment uint32 +} + +type incomingSettings struct { + ss []http2.Setting +} + +type outgoingSettings struct { + ss []http2.Setting +} + +type settingsAck struct { +} + +type incomingGoAway struct { +} + +type goAway struct { + code http2.ErrCode + debugData []byte + headsUp bool + closeConn bool +} + +type ping struct { + ack bool + data [8]byte +} + +type outStreamState int + +const ( + active outStreamState = iota + empty + waitingOnStreamQuota +) + +type outStream struct { + id uint32 + state outStreamState + itl *itemList + bytesOutStanding int + wq *writeQuota + + next *outStream + prev *outStream +} + +func (s *outStream) deleteSelf() { + if s.prev != nil { + s.prev.next = s.next + } + if s.next != nil { + s.next.prev = s.prev + } + s.next, s.prev = nil, nil +} + +type outStreamList struct { + // Following are sentinal objects that mark the + // beginning and end of the list. They do not + // contain any item lists. All valid objects are + // inserted in between them. + // This is needed so that an outStream object can + // deleteSelf() in O(1) time without knowing which + // list it belongs to. + head *outStream + tail *outStream +} + +func newOutStreamList() *outStreamList { + head, tail := new(outStream), new(outStream) + head.next = tail + tail.prev = head + return &outStreamList{ + head: head, + tail: tail, + } +} + +func (l *outStreamList) enqueue(s *outStream) { + e := l.tail.prev + e.next = s + s.prev = e + s.next = l.tail + l.tail.prev = s +} + +// remove from the beginning of the list. 
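Note on outStreamList above: the head/tail sentinels exist so that deleteSelf can unlink a stream in O(1) without a reference to the owning list, because interior nodes always have non-nil neighbors. A stripped-down sketch of the same trick (node and newList are illustrative, not the patch's types):

```go
// node sits between two sentinels, so prev and next are never nil for
// a linked element and unlinking needs no pointer to the list itself.
type node struct {
	prev, next *node
}

func newList() (head, tail *node) {
	head, tail = new(node), new(node)
	head.next, tail.prev = tail, head
	return head, tail
}

// deleteSelf unlinks n in O(1) regardless of which list owns it.
func (n *node) deleteSelf() {
	n.prev.next, n.next.prev = n.next, n.prev
	n.prev, n.next = nil, nil
}
```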
+func (l *outStreamList) dequeue() *outStream { + b := l.head.next + if b == l.tail { + return nil + } + b.deleteSelf() + return b +} + +type controlBuffer struct { + ch chan struct{} + done <-chan struct{} + mu sync.Mutex + consumerWaiting bool + list *itemList + err error +} + +func newControlBuffer(done <-chan struct{}) *controlBuffer { + return &controlBuffer{ + ch: make(chan struct{}, 1), + list: &itemList{}, + done: done, + } +} + +func (c *controlBuffer) put(it interface{}) error { + _, err := c.executeAndPut(nil, it) + return err +} + +func (c *controlBuffer) executeAndPut(f func(it interface{}) bool, it interface{}) (bool, error) { + var wakeUp bool + c.mu.Lock() + if c.err != nil { + c.mu.Unlock() + return false, c.err + } + if f != nil { + if !f(it) { // f wasn't successful + c.mu.Unlock() + return false, nil + } + } + if c.consumerWaiting { + wakeUp = true + c.consumerWaiting = false + } + c.list.enqueue(it) + c.mu.Unlock() + if wakeUp { + select { + case c.ch <- struct{}{}: + default: + } + } + return true, nil +} + +func (c *controlBuffer) get(block bool) (interface{}, error) { + for { + c.mu.Lock() + if c.err != nil { + c.mu.Unlock() + return nil, c.err + } + if !c.list.isEmpty() { + h := c.list.dequeue() + c.mu.Unlock() + return h, nil + } + if !block { + c.mu.Unlock() + return nil, nil + } + c.consumerWaiting = true + c.mu.Unlock() + select { + case <-c.ch: + case <-c.done: + c.finish() + return nil, ErrConnClosing + } + } +} + +func (c *controlBuffer) finish() { + c.mu.Lock() + if c.err != nil { + c.mu.Unlock() + return + } + c.err = ErrConnClosing + // There may be headers for streams in the control buffer. + // These streams need to be cleaned out since the transport + // is still not aware of these yet. + for head := c.list.dequeueAll(); head != nil; head = head.next { + hdr, ok := head.it.(*headerFrame) + if !ok { + continue + } + if hdr.onOrphaned != nil { // It will be nil on the server-side. + hdr.onOrphaned(ErrConnClosing) + } + } + c.mu.Unlock() +} + +type side int + +const ( + clientSide side = iota + serverSide +) + +type loopyWriter struct { + side side + cbuf *controlBuffer + sendQuota uint32 + oiws uint32 // outbound initial window size. + estdStreams map[uint32]*outStream // Established streams. + activeStreams *outStreamList // Streams that are sending data. + framer *framer + hBuf *bytes.Buffer // The buffer for HPACK encoding. + hEnc *hpack.Encoder // HPACK encoder. + bdpEst *bdpEstimator + draining bool + + // Side-specific handlers + ssGoAwayHandler func(*goAway) (bool, error) +} + +func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator) *loopyWriter { + var buf bytes.Buffer + l := &loopyWriter{ + side: s, + cbuf: cbuf, + sendQuota: defaultWindowSize, + oiws: defaultWindowSize, + estdStreams: make(map[uint32]*outStream), + activeStreams: newOutStreamList(), + framer: fr, + hBuf: &buf, + hEnc: hpack.NewEncoder(&buf), + bdpEst: bdpEst, + } + return l +} + +const minBatchSize = 1000 + +// run should be run in a separate goroutine. +func (l *loopyWriter) run() { + var ( + it interface{} + err error + isEmpty bool + ) + defer func() { + errorf("transport: loopyWriter.run returning. 
Err: %v", err) + }() + for { + it, err = l.cbuf.get(true) + if err != nil { + return + } + if err = l.handle(it); err != nil { + return + } + if _, err = l.processData(); err != nil { + return + } + gosched := true + hasdata: + for { + it, err = l.cbuf.get(false) + if err != nil { + return + } + if it != nil { + if err = l.handle(it); err != nil { + return + } + if _, err = l.processData(); err != nil { + return + } + continue hasdata + } + if isEmpty, err = l.processData(); err != nil { + return + } + if !isEmpty { + continue hasdata + } + if gosched { + gosched = false + if l.framer.writer.offset < minBatchSize { + runtime.Gosched() + continue hasdata + } + } + l.framer.writer.Flush() + break hasdata + + } + } +} + +func (l *loopyWriter) outgoingWindowUpdateHandler(w *outgoingWindowUpdate) error { + return l.framer.fr.WriteWindowUpdate(w.streamID, w.increment) +} + +func (l *loopyWriter) incomingWindowUpdateHandler(w *incomingWindowUpdate) error { + // Otherwise update the quota. + if w.streamID == 0 { + l.sendQuota += w.increment + return nil + } + // Find the stream and update it. + if str, ok := l.estdStreams[w.streamID]; ok { + str.bytesOutStanding -= int(w.increment) + if strQuota := int(l.oiws) - str.bytesOutStanding; strQuota > 0 && str.state == waitingOnStreamQuota { + str.state = active + l.activeStreams.enqueue(str) + return nil + } + } + return nil +} + +func (l *loopyWriter) outgoingSettingsHandler(s *outgoingSettings) error { + return l.framer.fr.WriteSettings(s.ss...) +} + +func (l *loopyWriter) incomingSettingsHandler(s *incomingSettings) error { + if err := l.applySettings(s.ss); err != nil { + return err + } + return l.framer.fr.WriteSettingsAck() +} + +func (l *loopyWriter) headerHandler(h *headerFrame) error { + if l.side == serverSide { + if h.endStream { // Case 1.A: Server wants to close stream. + // Make sure it's not a trailers only response. + if str, ok := l.estdStreams[h.streamID]; ok { + if str.state != empty { // either active or waiting on stream quota. + // add it str's list of items. + str.itl.enqueue(h) + return nil + } + } + if err := l.writeHeader(h.streamID, h.endStream, h.hf, h.onWrite); err != nil { + return err + } + return l.cleanupStreamHandler(h.cleanup) + } + // Case 1.B: Server is responding back with headers. + str := &outStream{ + state: empty, + itl: &itemList{}, + wq: h.wq, + } + l.estdStreams[h.streamID] = str + return l.writeHeader(h.streamID, h.endStream, h.hf, h.onWrite) + } + // Case 2: Client wants to originate stream. + str := &outStream{ + id: h.streamID, + state: empty, + itl: &itemList{}, + wq: h.wq, + } + str.itl.enqueue(h) + return l.originateStream(str) +} + +func (l *loopyWriter) originateStream(str *outStream) error { + hdr := str.itl.dequeue().(*headerFrame) + sendPing, err := hdr.initStream(str.id) + if err != nil { + if err == ErrConnClosing { + return err + } + // Other errors(errStreamDrain) need not close transport. 
+ return nil + } + if err = l.writeHeader(str.id, hdr.endStream, hdr.hf, hdr.onWrite); err != nil { + return err + } + l.estdStreams[str.id] = str + if sendPing { + return l.pingHandler(&ping{data: [8]byte{}}) + } + return nil +} + +func (l *loopyWriter) writeHeader(streamID uint32, endStream bool, hf []hpack.HeaderField, onWrite func()) error { + if onWrite != nil { + onWrite() + } + l.hBuf.Reset() + for _, f := range hf { + if err := l.hEnc.WriteField(f); err != nil { + warningf("transport: loopyWriter.writeHeader encountered error while encoding headers:", err) + } + } + var ( + err error + endHeaders, first bool + ) + first = true + for !endHeaders { + size := l.hBuf.Len() + if size > http2MaxFrameLen { + size = http2MaxFrameLen + } else { + endHeaders = true + } + if first { + first = false + err = l.framer.fr.WriteHeaders(http2.HeadersFrameParam{ + StreamID: streamID, + BlockFragment: l.hBuf.Next(size), + EndStream: endStream, + EndHeaders: endHeaders, + }) + } else { + err = l.framer.fr.WriteContinuation( + streamID, + endHeaders, + l.hBuf.Next(size), + ) + } + if err != nil { + return err + } + } + return nil +} + +func (l *loopyWriter) preprocessData(df *dataFrame) error { + str, ok := l.estdStreams[df.streamID] + if !ok { + return nil + } + // If we got data for a stream it means that + // stream was originated and the headers were sent out. + str.itl.enqueue(df) + if str.state == empty { + str.state = active + l.activeStreams.enqueue(str) + } + return nil +} + +func (l *loopyWriter) pingHandler(p *ping) error { + if !p.ack { + l.bdpEst.timesnap(p.data) + } + return l.framer.fr.WritePing(p.ack, p.data) + +} + +func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error { + c.onWrite() + if str, ok := l.estdStreams[c.streamID]; ok { + // On the server side it could be a trailers-only response or + // a RST_STREAM before stream initialization thus the stream might + // not be established yet. + delete(l.estdStreams, c.streamID) + str.deleteSelf() + } + if c.rst { // If RST_STREAM needs to be sent. + if err := l.framer.fr.WriteRSTStream(c.streamID, c.rstCode); err != nil { + return err + } + } + if l.side == clientSide && l.draining && len(l.estdStreams) == 0 { + return ErrConnClosing + } + return nil +} + +func (l *loopyWriter) incomingGoAwayHandler(*incomingGoAway) error { + if l.side == clientSide { + l.draining = true + if len(l.estdStreams) == 0 { + return ErrConnClosing + } + } + return nil +} + +func (l *loopyWriter) goAwayHandler(g *goAway) error { + // Handling of outgoing GoAway is very specific to side. 
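Note on writeHeader above: an encoded HPACK block larger than http2MaxFrameLen is split across one HEADERS frame plus trailing CONTINUATION frames, with END_HEADERS set only on the last fragment. A worked sketch of just the splitting:

```go
// splitHeaderBlock cuts an encoded header block into wire-sized
// fragments: the first goes in a HEADERS frame, the rest in
// CONTINUATION frames; only the final fragment sets END_HEADERS.
func splitHeaderBlock(block []byte, maxFrame int) [][]byte {
	var frags [][]byte
	for len(block) > maxFrame {
		frags = append(frags, block[:maxFrame])
		block = block[maxFrame:]
	}
	return append(frags, block)
}
```

For example, a 40 KB block under a 16 KB frame limit yields one HEADERS fragment plus two CONTINUATIONs.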
+ if l.ssGoAwayHandler != nil { + draining, err := l.ssGoAwayHandler(g) + if err != nil { + return err + } + l.draining = draining + } + return nil +} + +func (l *loopyWriter) handle(i interface{}) error { + switch i := i.(type) { + case *incomingWindowUpdate: + return l.incomingWindowUpdateHandler(i) + case *outgoingWindowUpdate: + return l.outgoingWindowUpdateHandler(i) + case *incomingSettings: + return l.incomingSettingsHandler(i) + case *outgoingSettings: + return l.outgoingSettingsHandler(i) + case *headerFrame: + return l.headerHandler(i) + case *cleanupStream: + return l.cleanupStreamHandler(i) + case *incomingGoAway: + return l.incomingGoAwayHandler(i) + case *dataFrame: + return l.preprocessData(i) + case *ping: + return l.pingHandler(i) + case *goAway: + return l.goAwayHandler(i) + default: + return fmt.Errorf("transport: unknown control message type %T", i) + } +} + +func (l *loopyWriter) applySettings(ss []http2.Setting) error { + for _, s := range ss { + switch s.ID { + case http2.SettingInitialWindowSize: + o := l.oiws + l.oiws = s.Val + if o < l.oiws { + // If the new limit is greater make all depleted streams active. + for _, stream := range l.estdStreams { + if stream.state == waitingOnStreamQuota { + stream.state = active + l.activeStreams.enqueue(stream) + } + } + } + } + } + return nil +} + +func (l *loopyWriter) processData() (bool, error) { + if l.sendQuota == 0 { + return true, nil + } + str := l.activeStreams.dequeue() + if str == nil { + return true, nil + } + dataItem := str.itl.peek().(*dataFrame) + if len(dataItem.h) == 0 && len(dataItem.d) == 0 { + // Client sends out empty data frame with endStream = true + if err := l.framer.fr.WriteData(dataItem.streamID, dataItem.endStream, nil); err != nil { + return false, err + } + str.itl.dequeue() + if str.itl.isEmpty() { + str.state = empty + } else if trailer, ok := str.itl.peek().(*headerFrame); ok { // the next item is trailers. + if err := l.writeHeader(trailer.streamID, trailer.endStream, trailer.hf, trailer.onWrite); err != nil { + return false, err + } + if err := l.cleanupStreamHandler(trailer.cleanup); err != nil { + return false, nil + } + } else { + l.activeStreams.enqueue(str) + } + return false, nil + } + var ( + idx int + buf []byte + ) + if len(dataItem.h) != 0 { // data header has not been written out yet. + buf = dataItem.h + } else { + idx = 1 + buf = dataItem.d + } + size := http2MaxFrameLen + if len(buf) < size { + size = len(buf) + } + if strQuota := int(l.oiws) - str.bytesOutStanding; strQuota <= 0 { + str.state = waitingOnStreamQuota + return false, nil + } else if strQuota < size { + size = strQuota + } + + if l.sendQuota < uint32(size) { + size = int(l.sendQuota) + } + // Now that outgoing flow controls are checked we can replenish str's write quota + str.wq.replenish(size) + var endStream bool + // This last data message on this stream and all + // of it can be written in this go. + if dataItem.endStream && size == len(buf) { + // buf contains either data or it contains header but data is empty. + if idx == 1 || len(dataItem.d) == 0 { + endStream = true + } + } + if dataItem.onEachWrite != nil { + dataItem.onEachWrite() + } + if err := l.framer.fr.WriteData(dataItem.streamID, endStream, buf[:size]); err != nil { + return false, err + } + buf = buf[size:] + str.bytesOutStanding += size + l.sendQuota -= uint32(size) + if idx == 0 { + dataItem.h = buf + } else { + dataItem.d = buf + } + + if len(dataItem.h) == 0 && len(dataItem.d) == 0 { // All the data from that message was written out. 
+ str.itl.dequeue() + } + if str.itl.isEmpty() { + str.state = empty + } else if trailer, ok := str.itl.peek().(*headerFrame); ok { // The next item is trailers. + if err := l.writeHeader(trailer.streamID, trailer.endStream, trailer.hf, trailer.onWrite); err != nil { + return false, err + } + if err := l.cleanupStreamHandler(trailer.cleanup); err != nil { + return false, err + } + } else if int(l.oiws)-str.bytesOutStanding <= 0 { // Ran out of stream quota. + str.state = waitingOnStreamQuota + } else { // Otherwise add it back to the list of active streams. + l.activeStreams.enqueue(str) + } + return false, nil +} diff --git a/transport/control.go b/transport/flowcontrol.go similarity index 55% rename from transport/control.go rename to transport/flowcontrol.go index 0474b09074ba..cfe0d78dacdc 100644 --- a/transport/control.go +++ b/transport/flowcontrol.go @@ -20,13 +20,10 @@ package transport import ( "fmt" - "io" "math" "sync" + "sync/atomic" "time" - - "golang.org/x/net/http2" - "golang.org/x/net/http2/hpack" ) const ( @@ -46,192 +43,86 @@ const ( defaultKeepalivePolicyMinTime = time.Duration(5 * time.Minute) // max window limit set by HTTP2 Specs. maxWindowSize = math.MaxInt32 - // defaultLocalSendQuota sets is default value for number of data + // defaultWriteQuota is the default value for number of data // bytes that each stream can schedule before some of it being // flushed out. - defaultLocalSendQuota = 128 * 1024 + defaultWriteQuota = 64 * 1024 ) -// The following defines various control items which could flow through -// the control buffer of transport. They represent different aspects of -// control tasks, e.g., flow control, settings, streaming resetting, etc. - -type headerFrame struct { - streamID uint32 - hf []hpack.HeaderField - endStream bool -} - -func (*headerFrame) item() {} - -type continuationFrame struct { - streamID uint32 - endHeaders bool - headerBlockFragment []byte -} - -type dataFrame struct { - streamID uint32 - endStream bool - d []byte - f func() -} - -func (*dataFrame) item() {} - -func (*continuationFrame) item() {} - -type windowUpdate struct { - streamID uint32 - increment uint32 -} - -func (*windowUpdate) item() {} - -type settings struct { - ss []http2.Setting -} - -func (*settings) item() {} - -type settingsAck struct { -} - -func (*settingsAck) item() {} - -type resetStream struct { - streamID uint32 - code http2.ErrCode -} - -func (*resetStream) item() {} - -type goAway struct { - code http2.ErrCode - debugData []byte - headsUp bool - closeConn bool -} - -func (*goAway) item() {} - -type flushIO struct { - closeTr bool -} - -func (*flushIO) item() {} - -type ping struct { - ack bool - data [8]byte -} - -func (*ping) item() {} - -// quotaPool is a pool which accumulates the quota and sends it to acquire() -// when it is available. -type quotaPool struct { - mu sync.Mutex - c chan struct{} - version uint32 - quota int -} - -// newQuotaPool creates a quotaPool which has quota q available to consume. -func newQuotaPool(q int) *quotaPool { - qb := "aPool{ - quota: q, - c: make(chan struct{}, 1), +// writeQuota is a soft limit on the amount of data a stream can +// schedule before some of it is written out. +type writeQuota struct { + quota int32 + // get waits on read from when quota goes less than or equal to zero. + // replenish writes on it when quota goes positive again. + ch chan struct{} + // done is triggered in error case. 
+ done <-chan struct{} +} + +func newWriteQuota(sz int32, done <-chan struct{}) *writeQuota { + return &writeQuota{ + quota: sz, + ch: make(chan struct{}, 1), + done: done, } - return qb } -// add cancels the pending quota sent on acquired, incremented by v and sends -// it back on acquire. -func (qb *quotaPool) add(v int) { - qb.mu.Lock() - defer qb.mu.Unlock() - qb.lockedAdd(v) +func (w *writeQuota) get(sz int32) error { + for { + if atomic.LoadInt32(&w.quota) > 0 { + atomic.AddInt32(&w.quota, -sz) + return nil + } + select { + case <-w.ch: + continue + case <-w.done: + return errStreamDone + } + } } -func (qb *quotaPool) lockedAdd(v int) { - var wakeUp bool - if qb.quota <= 0 { - wakeUp = true // Wake up potential waiters. - } - qb.quota += v - if wakeUp && qb.quota > 0 { +func (w *writeQuota) replenish(n int) { + sz := int32(n) + a := atomic.AddInt32(&w.quota, sz) + b := a - sz + if b <= 0 && a > 0 { select { - case qb.c <- struct{}{}: + case w.ch <- struct{}{}: default: } } } -func (qb *quotaPool) addAndUpdate(v int) { - qb.mu.Lock() - qb.lockedAdd(v) - qb.version++ - qb.mu.Unlock() +type trInFlow struct { + limit uint32 + unacked uint32 } -func (qb *quotaPool) get(v int, wc waiters) (int, uint32, error) { - qb.mu.Lock() - if qb.quota > 0 { - if v > qb.quota { - v = qb.quota - } - qb.quota -= v - ver := qb.version - qb.mu.Unlock() - return v, ver, nil - } - qb.mu.Unlock() - for { - select { - case <-wc.ctx.Done(): - return 0, 0, ContextErr(wc.ctx.Err()) - case <-wc.tctx.Done(): - return 0, 0, ErrConnClosing - case <-wc.done: - return 0, 0, io.EOF - case <-wc.goAway: - return 0, 0, errStreamDrain - case <-qb.c: - qb.mu.Lock() - if qb.quota > 0 { - if v > qb.quota { - v = qb.quota - } - qb.quota -= v - ver := qb.version - if qb.quota > 0 { - select { - case qb.c <- struct{}{}: - default: - } - } - qb.mu.Unlock() - return v, ver, nil +func (f *trInFlow) newLimit(n uint32) uint32 { + d := n - f.limit + f.limit = n + return d +} - } - qb.mu.Unlock() - } +func (f *trInFlow) onData(n uint32) uint32 { + f.unacked += n + if f.unacked >= f.limit/4 { + w := f.unacked + f.unacked = 0 + return w } + return 0 } -func (qb *quotaPool) compareAndExecute(version uint32, success, failure func()) bool { - qb.mu.Lock() - if version == qb.version { - success() - qb.mu.Unlock() - return true - } - failure() - qb.mu.Unlock() - return false +func (f *trInFlow) reset() uint32 { + w := f.unacked + f.unacked = 0 + return w } +// TODO(mmukhi): Simplify this code. // inFlow deals with inbound flow control type inFlow struct { mu sync.Mutex @@ -252,9 +143,9 @@ type inFlow struct { // It assumes that n is always greater than the old limit. func (f *inFlow) newLimit(n uint32) uint32 { f.mu.Lock() - defer f.mu.Unlock() d := n - f.limit f.limit = n + f.mu.Unlock() return d } @@ -263,7 +154,6 @@ func (f *inFlow) maybeAdjust(n uint32) uint32 { n = uint32(math.MaxInt32) } f.mu.Lock() - defer f.mu.Unlock() // estSenderQuota is the receiver's view of the maximum number of bytes the sender // can send without a window update. estSenderQuota := int32(f.limit - (f.pendingData + f.pendingUpdate)) @@ -275,7 +165,7 @@ func (f *inFlow) maybeAdjust(n uint32) uint32 { // for this message. Therefore we must send an update over the limit since there's an active read // request from the application. if estUntransmittedData > estSenderQuota { - // Sender's window shouldn't go more than 2^31 - 1 as speecified in the HTTP spec. + // Sender's window shouldn't go more than 2^31 - 1 as specified in the HTTP spec. 
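Note on writeQuota above: it is a soft limit. get only requires the quota to be positive before subtracting, so one large message may drive it negative, and replenish wakes a single waiter on the non-positive-to-positive transition. A sketch of how the pieces defined in this patch compose on the client write path:

```go
// enqueueData reserves stream-level write quota, then hands the frame
// to loopy via the control buffer (types are from this patch).
func enqueueData(wq *writeQuota, cbuf *controlBuffer, df *dataFrame) error {
	if err := wq.get(int32(len(df.h) + len(df.d))); err != nil {
		return err // stream finished while blocked on quota
	}
	return cbuf.put(df) // loopy replenishes wq as bytes go out
}
```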
if f.limit+n > maxWindowSize { f.delta = maxWindowSize - f.limit } else { @@ -284,19 +174,24 @@ func (f *inFlow) maybeAdjust(n uint32) uint32 { // is padded; We will fallback on the current available window(at least a 1/4th of the limit). f.delta = n } + f.mu.Unlock() return f.delta } + f.mu.Unlock() return 0 } // onData is invoked when some data frame is received. It updates pendingData. func (f *inFlow) onData(n uint32) error { f.mu.Lock() - defer f.mu.Unlock() f.pendingData += n if f.pendingData+f.pendingUpdate > f.limit+f.delta { - return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", f.pendingData+f.pendingUpdate, f.limit) + limit := f.limit + rcvd := f.pendingData + f.pendingUpdate + f.mu.Unlock() + return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", rcvd, limit) } + f.mu.Unlock() return nil } @@ -304,8 +199,8 @@ func (f *inFlow) onData(n uint32) error { // to be sent to the peer. func (f *inFlow) onRead(n uint32) uint32 { f.mu.Lock() - defer f.mu.Unlock() if f.pendingData == 0 { + f.mu.Unlock() return 0 } f.pendingData -= n @@ -320,15 +215,9 @@ func (f *inFlow) onRead(n uint32) uint32 { if f.pendingUpdate >= f.limit/4 { wu := f.pendingUpdate f.pendingUpdate = 0 + f.mu.Unlock() return wu } + f.mu.Unlock() return 0 } - -func (f *inFlow) resetPendingUpdate() uint32 { - f.mu.Lock() - defer f.mu.Unlock() - n := f.pendingUpdate - f.pendingUpdate = 0 - return n -} diff --git a/transport/handler_server.go b/transport/handler_server.go index 1a5e96c5a17b..9d0c88c73387 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -365,7 +365,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace ht.stats.HandleRPC(s.ctx, inHeader) } s.trReader = &transportReader{ - reader: &recvBufferReader{ctx: s.ctx, recv: s.buf}, + reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf}, windowHandler: func(int) {}, } diff --git a/transport/http2_client.go b/transport/http2_client.go index 8b5be0d6d51f..560dc3970755 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -19,8 +19,6 @@ package transport import ( - "bytes" - "fmt" "io" "math" "net" @@ -45,14 +43,17 @@ import ( type http2Client struct { ctx context.Context cancel context.CancelFunc + ctxDone <-chan struct{} // Cache the ctx.Done() chan. userAgent string md interface{} conn net.Conn // underlying communication channel + loopy *loopyWriter remoteAddr net.Addr localAddr net.Addr authInfo credentials.AuthInfo // auth info about the connection - nextID uint32 // the next stream ID to be used + readerDone chan struct{} // sync point to enable testing. + writerDone chan struct{} // sync point to enable testing. // goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor) // that the server sent GoAway on this transport. goAway chan struct{} @@ -60,21 +61,10 @@ type http2Client struct { awakenKeepalive chan struct{} framer *framer - hBuf *bytes.Buffer // the buffer for HPACK encoding - hEnc *hpack.Encoder // HPACK encoder - // controlBuf delivers all the control related tasks (e.g., window // updates, reset streams, and various settings) to the controller. controlBuf *controlBuffer - fc *inFlow - // sendQuotaPool provides flow control to outbound message. - sendQuotaPool *quotaPool - // localSendQuota limits the amount of data that can be scheduled - // for writing before it is actually written out. - localSendQuota *quotaPool - // streamsQuota limits the max number of concurrent streams. 
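Note on trInFlow/inFlow above: both coalesce acknowledgements. Consumed bytes accumulate in unacked/pendingUpdate, and a WINDOW_UPDATE is only reported once at least a quarter of the window is pending, cutting frame overhead at the cost of letting the sender run a little closer to stall. A toy version of the threshold logic:

```go
// coalescer batches flow-control acks: report a window update only
// when limit/4 or more bytes are waiting to be acknowledged.
type coalescer struct{ limit, unacked uint32 }

func (c *coalescer) onRead(n uint32) uint32 {
	c.unacked += n
	if c.unacked >= c.limit/4 {
		w := c.unacked
		c.unacked = 0
		return w // caller sends WINDOW_UPDATE for w bytes
	}
	return 0 // keep batching
}
```

With a 64 KB limit, two 10000-byte reads return 0 and then 20000: one update on the wire instead of two.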
- streamsQuota *quotaPool - + fc *trInFlow // The scheme used: https if TLS is on, http otherwise. scheme string @@ -91,21 +81,21 @@ type http2Client struct { initialWindowSize int32 - bdpEst *bdpEstimator - outQuotaVersion uint32 - + bdpEst *bdpEstimator // onSuccess is a callback that client transport calls upon // receiving server preface to signal that a succefull HTTP2 // connection was established. onSuccess func() - mu sync.Mutex // guard the following variables - state transportState // the state of underlying connection + maxConcurrentStreams uint32 + streamQuota int64 + streamsQuotaAvailable chan struct{} + waitingStreams uint32 + nextID uint32 + + mu sync.Mutex // guard the following variables + state transportState activeStreams map[uint32]*Stream - // The max number of concurrent streams - maxStreams int - // the per-stream outbound flow control window size set by the peer. - streamSendQuota uint32 // prevGoAway ID records the Last-Stream-ID in the previous GOAway frame. prevGoAwayID uint32 // goAwayReason records the http2.ErrCode and debug data received with the @@ -187,7 +177,6 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne icwz = opts.InitialConnWindowSize dynamicWindow = false } - var buf bytes.Buffer writeBufSize := defaultWriteBufSize if opts.WriteBufferSize > 0 { writeBufSize = opts.WriteBufferSize @@ -197,38 +186,35 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne readBufSize = opts.ReadBufferSize } t := &http2Client{ - ctx: ctx, - cancel: cancel, - userAgent: opts.UserAgent, - md: addr.Metadata, - conn: conn, - remoteAddr: conn.RemoteAddr(), - localAddr: conn.LocalAddr(), - authInfo: authInfo, - // The client initiated stream id is odd starting from 1. - nextID: 1, - goAway: make(chan struct{}), - awakenKeepalive: make(chan struct{}, 1), - hBuf: &buf, - hEnc: hpack.NewEncoder(&buf), - framer: newFramer(conn, writeBufSize, readBufSize), - controlBuf: newControlBuffer(), - fc: &inFlow{limit: uint32(icwz)}, - sendQuotaPool: newQuotaPool(defaultWindowSize), - localSendQuota: newQuotaPool(defaultLocalSendQuota), - scheme: scheme, - state: reachable, - activeStreams: make(map[uint32]*Stream), - isSecure: isSecure, - creds: opts.PerRPCCredentials, - maxStreams: defaultMaxStreamsClient, - streamsQuota: newQuotaPool(defaultMaxStreamsClient), - streamSendQuota: defaultWindowSize, - kp: kp, - statsHandler: opts.StatsHandler, - initialWindowSize: initialWindowSize, - onSuccess: onSuccess, - } + ctx: ctx, + ctxDone: ctx.Done(), // Cache Done chan. 
+ cancel: cancel, + userAgent: opts.UserAgent, + md: addr.Metadata, + conn: conn, + remoteAddr: conn.RemoteAddr(), + localAddr: conn.LocalAddr(), + authInfo: authInfo, + readerDone: make(chan struct{}), + writerDone: make(chan struct{}), + goAway: make(chan struct{}), + awakenKeepalive: make(chan struct{}, 1), + framer: newFramer(conn, writeBufSize, readBufSize), + fc: &trInFlow{limit: uint32(icwz)}, + scheme: scheme, + activeStreams: make(map[uint32]*Stream), + isSecure: isSecure, + creds: opts.PerRPCCredentials, + kp: kp, + statsHandler: opts.StatsHandler, + initialWindowSize: initialWindowSize, + onSuccess: onSuccess, + nextID: 1, + maxConcurrentStreams: defaultMaxStreamsClient, + streamQuota: defaultMaxStreamsClient, + streamsQuotaAvailable: make(chan struct{}, 1), + } + t.controlBuf = newControlBuffer(t.ctxDone) if opts.InitialWindowSize >= defaultWindowSize { t.initialWindowSize = opts.InitialWindowSize dynamicWindow = false @@ -287,8 +273,10 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne } t.framer.writer.Flush() go func() { - loopyWriter(t.ctx, t.controlBuf, t.itemHandler) + t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst) + t.loopy.run() t.conn.Close() + close(t.writerDone) }() if t.kp.Time != infinity { go t.keepalive() @@ -299,18 +287,14 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { // TODO(zhaoq): Handle uint32 overflow of Stream.id. s := &Stream{ - id: t.nextID, done: make(chan struct{}), - goAway: make(chan struct{}), method: callHdr.Method, sendCompress: callHdr.SendCompress, buf: newRecvBuffer(), - fc: &inFlow{limit: uint32(t.initialWindowSize)}, - sendQuotaPool: newQuotaPool(int(t.streamSendQuota)), headerChan: make(chan struct{}), contentSubtype: callHdr.ContentSubtype, } - t.nextID += 2 + s.wq = newWriteQuota(defaultWriteQuota, s.done) s.requestRead = func(n int) { t.adjustWindow(s, uint32(n)) } @@ -320,26 +304,18 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { s.ctx = ctx s.trReader = &transportReader{ reader: &recvBufferReader{ - ctx: s.ctx, - goAway: s.goAway, - recv: s.buf, + ctx: s.ctx, + ctxDone: s.ctx.Done(), + recv: s.buf, }, windowHandler: func(n int) { t.updateWindow(s, uint32(n)) }, } - s.waiters = waiters{ - ctx: s.ctx, - tctx: t.ctx, - done: s.done, - goAway: s.goAway, - } return s } -// NewStream creates a stream and registers it into the transport as "active" -// streams. -func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) { +func (t *http2Client) getPeer() *peer.Peer { pr := &peer.Peer{ Addr: t.remoteAddr, } @@ -347,71 +323,17 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea if t.authInfo != nil { pr.AuthInfo = t.authInfo } - ctx = peer.NewContext(ctx, pr) - var ( - authData = make(map[string]string) - audience string - ) - // Create an audience string only if needed. - if len(t.creds) > 0 || callHdr.Creds != nil { - // Construct URI required to get auth request metadata. - // Omit port if it is the default one. 
- host := strings.TrimSuffix(callHdr.Host, ":443") - pos := strings.LastIndex(callHdr.Method, "/") - if pos == -1 { - pos = len(callHdr.Method) - } - audience = "https://" + host + callHdr.Method[:pos] - } - for _, c := range t.creds { - data, err := c.GetRequestMetadata(ctx, audience) - if err != nil { - if _, ok := status.FromError(err); ok { - return nil, err - } + return pr +} - return nil, streamErrorf(codes.Unauthenticated, "transport: %v", err) - } - for k, v := range data { - // Capital header names are illegal in HTTP/2. - k = strings.ToLower(k) - authData[k] = v - } - } - callAuthData := map[string]string{} - // Check if credentials.PerRPCCredentials were provided via call options. - // Note: if these credentials are provided both via dial options and call - // options, then both sets of credentials will be applied. - if callCreds := callHdr.Creds; callCreds != nil { - if !t.isSecure && callCreds.RequireTransportSecurity() { - return nil, streamErrorf(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection") - } - data, err := callCreds.GetRequestMetadata(ctx, audience) - if err != nil { - return nil, streamErrorf(codes.Internal, "transport: %v", err) - } - for k, v := range data { - // Capital header names are illegal in HTTP/2 - k = strings.ToLower(k) - callAuthData[k] = v - } - } - t.mu.Lock() - if t.activeStreams == nil { - t.mu.Unlock() - return nil, ErrConnClosing - } - if t.state == draining { - t.mu.Unlock() - return nil, errStreamDrain - } - if t.state != reachable { - t.mu.Unlock() - return nil, ErrConnClosing +func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) ([]hpack.HeaderField, error) { + aud := t.createAudience(callHdr) + authData, err := t.getTrAuthData(ctx, aud) + if err != nil { + return nil, err } - t.mu.Unlock() - // Get a quota of 1 from streamsQuota. - if _, _, err := t.streamsQuota.get(1, waiters{ctx: ctx, tctx: t.ctx}); err != nil { + callAuthData, err := t.getCallAuthData(ctx, aud, callHdr) + if err != nil { return nil, err } // TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields @@ -485,38 +407,172 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } } } - t.mu.Lock() - if t.state == draining { - t.mu.Unlock() - t.streamsQuota.add(1) - return nil, errStreamDrain + return headerFields, nil +} + +func (t *http2Client) createAudience(callHdr *CallHdr) string { + // Create an audience string only if needed. + if len(t.creds) == 0 && callHdr.Creds == nil { + return "" } - if t.state != reachable { - t.mu.Unlock() - return nil, ErrConnClosing + // Construct URI required to get auth request metadata. + // Omit port if it is the default one. + host := strings.TrimSuffix(callHdr.Host, ":443") + pos := strings.LastIndex(callHdr.Method, "/") + if pos == -1 { + pos = len(callHdr.Method) + } + return "https://" + host + callHdr.Method[:pos] +} + +func (t *http2Client) getTrAuthData(ctx context.Context, audience string) (map[string]string, error) { + authData := map[string]string{} + for _, c := range t.creds { + data, err := c.GetRequestMetadata(ctx, audience) + if err != nil { + if _, ok := status.FromError(err); ok { + return nil, err + } + + return nil, streamErrorf(codes.Unauthenticated, "transport: %v", err) + } + for k, v := range data { + // Capital header names are illegal in HTTP/2. 
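Note on createAudience (added above): the audience for per-RPC credentials is derived from the call target by dropping the default HTTPS port and the method's final path segment. A self-contained worked example of the same derivation (assumes "strings" is imported):

```go
// audience mirrors createAudience's derivation: host "example.com:443"
// with method "/pkg.Service/Call" yields "https://example.com/pkg.Service".
func audience(host, method string) string {
	host = strings.TrimSuffix(host, ":443") // omit the default port
	pos := strings.LastIndex(method, "/")
	if pos == -1 {
		pos = len(method)
	}
	return "https://" + host + method[:pos]
}
```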
+ k = strings.ToLower(k) + authData[k] = v + } + } + return authData, nil +} + +func (t *http2Client) getCallAuthData(ctx context.Context, audience string, callHdr *CallHdr) (map[string]string, error) { + callAuthData := map[string]string{} + // Check if credentials.PerRPCCredentials were provided via call options. + // Note: if these credentials are provided both via dial options and call + // options, then both sets of credentials will be applied. + if callCreds := callHdr.Creds; callCreds != nil { + if !t.isSecure && callCreds.RequireTransportSecurity() { + return nil, streamErrorf(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection") + } + data, err := callCreds.GetRequestMetadata(ctx, audience) + if err != nil { + return nil, streamErrorf(codes.Internal, "transport: %v", err) + } + for k, v := range data { + // Capital header names are illegal in HTTP/2 + k = strings.ToLower(k) + callAuthData[k] = v + } + } + return callAuthData, nil +} + +// NewStream creates a stream and registers it into the transport as "active" +// streams. +func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) { + ctx = peer.NewContext(ctx, t.getPeer()) + headerFields, err := t.createHeaderFields(ctx, callHdr) + if err != nil { + return nil, err } s := t.newStream(ctx, callHdr) - t.activeStreams[s.id] = s - // If the number of active streams change from 0 to 1, then check if keepalive - // has gone dormant. If so, wake it up. - if len(t.activeStreams) == 1 { - select { - case t.awakenKeepalive <- struct{}{}: - t.controlBuf.put(&ping{data: [8]byte{}}) - // Fill the awakenKeepalive channel again as this channel must be - // kept non-writable except at the point that the keepalive() - // goroutine is waiting either to be awaken or shutdown. - t.awakenKeepalive <- struct{}{} - default: + cleanup := func(err error) { + if s.swapState(streamDone) == streamDone { + // If it was already done, return. + return } + // The stream was unprocessed by the server. + atomic.StoreUint32(&s.unprocessed, 1) + s.write(recvMsg{err: err}) + close(s.done) + // If headerChan isn't closed, then close it. + if atomic.SwapUint32(&s.headerDone, 1) == 0 { + close(s.headerChan) + } + } - t.controlBuf.put(&headerFrame{ - streamID: s.id, + hdr := &headerFrame{ hf: headerFields, endStream: false, - }) - t.mu.Unlock() - + initStream: func(id uint32) (bool, error) { + t.mu.Lock() + if state := t.state; state != reachable { + t.mu.Unlock() + // Do a quick cleanup. + err := error(errStreamDrain) + if state == closing { + err = ErrConnClosing + } + cleanup(err) + return false, err + } + t.activeStreams[id] = s + var sendPing bool + // If the number of active streams change from 0 to 1, then check if keepalive + // has gone dormant. If so, wake it up. + if len(t.activeStreams) == 1 { + select { + case t.awakenKeepalive <- struct{}{}: + sendPing = true + // Fill the awakenKeepalive channel again as this channel must be + // kept non-writable except at the point that the keepalive() + // goroutine is waiting either to be awaken or shutdown. + t.awakenKeepalive <- struct{}{} + default: + } + } + t.mu.Unlock() + return sendPing, nil + }, + onOrphaned: cleanup, + wq: s.wq, + } + firstTry := true + var ch chan struct{} + checkForStreamQuota := func(it interface{}) bool { + if t.streamQuota <= 0 { // Can go negative if server decreases it. 
+ if firstTry { + t.waitingStreams++ + } + ch = t.streamsQuotaAvailable + return false + } + if !firstTry { + t.waitingStreams-- + } + t.streamQuota-- + h := it.(*headerFrame) + h.streamID = t.nextID + t.nextID += 2 + s.id = h.streamID + s.fc = &inFlow{limit: uint32(t.initialWindowSize)} + if t.streamQuota > 0 && t.waitingStreams > 0 { + select { + case t.streamsQuotaAvailable <- struct{}{}: + default: + } + } + return true + } + for { + success, err := t.controlBuf.executeAndPut(checkForStreamQuota, hdr) + if err != nil { + return nil, err + } + if success { + break + } + firstTry = false + select { + case <-ch: + case <-s.ctx.Done(): + return nil, ContextErr(s.ctx.Err()) + case <-t.goAway: + return nil, errStreamDrain + case <-t.ctx.Done(): + return nil, ErrConnClosing + } + } if t.statsHandler != nil { outHeader := &stats.OutHeader{ Client: true, @@ -533,58 +589,63 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea // CloseStream clears the footprint of a stream when the stream is not needed any more. // This must not be executed in reader's goroutine. func (t *http2Client) CloseStream(s *Stream, err error) { - t.mu.Lock() - if t.activeStreams == nil { - t.mu.Unlock() - return - } + var ( + rst bool + rstCode http2.ErrCode + ) if err != nil { - // notify in-flight streams, before the deletion - s.write(recvMsg{err: err}) + rst = true + rstCode = http2.ErrCodeCancel } - delete(t.activeStreams, s.id) - if t.state == draining && len(t.activeStreams) == 0 { - // The transport is draining and s is the last live stream on t. - t.mu.Unlock() - t.Close() + t.closeStream(s, err, rst, rstCode, nil, nil) +} + +func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.ErrCode, st *status.Status, mdata map[string][]string) { + // Set stream status to done. + if s.swapState(streamDone) == streamDone { + // If it was already done, return. return } - t.mu.Unlock() - // rstStream is true in case the stream is being closed at the client-side - // and the server needs to be intimated about it by sending a RST_STREAM - // frame. - // To make sure this frame is written to the wire before the headers of the - // next stream waiting for streamsQuota, we add to streamsQuota pool only - // after having acquired the writableChan to send RST_STREAM out (look at - // the controller() routine). - var rstStream bool - var rstError http2.ErrCode - defer func() { - // In case, the client doesn't have to send RST_STREAM to server - // we can safely add back to streamsQuota pool now. - if !rstStream { - t.streamsQuota.add(1) - return - } - t.controlBuf.put(&resetStream{s.id, rstError}) - }() - s.mu.Lock() - rstStream = s.rstStream - rstError = s.rstError - if s.state == streamDone { - s.mu.Unlock() - return + // status and trailers can be updated here without any synchronization because the stream goroutine will + // only read it after it sees an io.EOF error from read or write and we'll write those errors + // only after updating this. + s.status = st + if len(mdata) > 0 { + s.trailer = mdata } - if !s.headerDone { + if err != nil { + // This will unblock reads eventually. + s.write(recvMsg{err: err}) + } + // This will unblock write. + close(s.done) + // If headerChan isn't closed, then close it. 
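Note on the headerDone handling here and in NewStream's cleanup above: atomic.SwapUint32(&s.headerDone, 1) == 0 is a close-once guard. The first goroutine to swap sees 0 and closes headerChan; any later caller sees 1 and skips, so the channel can never be closed twice. The idiom in isolation:

```go
// closeOnce closes ch exactly once even under races: only the first
// swap observes 0; every later swap observes 1 and does nothing.
func closeOnce(flag *uint32, ch chan struct{}) {
	if atomic.SwapUint32(flag, 1) == 0 {
		close(ch)
	}
}
```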
+ if atomic.SwapUint32(&s.headerDone, 1) == 0 { close(s.headerChan) - s.headerDone = true } - s.state = streamDone - s.mu.Unlock() - if err != nil && !rstStream { - rstStream = true - rstError = http2.ErrCodeCancel + cleanup := &cleanupStream{ + streamID: s.id, + onWrite: func() { + t.mu.Lock() + if t.activeStreams != nil { + delete(t.activeStreams, s.id) + } + t.mu.Unlock() + }, + rst: rst, + rstCode: rstCode, + } + addBackStreamQuota := func(interface{}) bool { + t.streamQuota++ + if t.streamQuota > 0 && t.waitingStreams > 0 { + select { + case t.streamsQuotaAvailable <- struct{}{}: + default: + } + } + return true } + t.controlBuf.executeAndPut(addBackStreamQuota, cleanup) } // Close kicks off the shutdown process of the transport. This should be called @@ -592,27 +653,21 @@ func (t *http2Client) CloseStream(s *Stream, err error) { // accessed any more. func (t *http2Client) Close() error { t.mu.Lock() + // Make sure we only Close once. if t.state == closing { t.mu.Unlock() return nil } t.state = closing - t.mu.Unlock() - t.cancel() - err := t.conn.Close() - t.mu.Lock() streams := t.activeStreams t.activeStreams = nil t.mu.Unlock() + t.controlBuf.finish() + t.cancel() + err := t.conn.Close() // Notify all active streams. for _, s := range streams { - s.mu.Lock() - if !s.headerDone { - close(s.headerChan) - s.headerDone = true - } - s.mu.Unlock() - s.write(recvMsg{err: ErrConnClosing}) + t.closeStream(s, ErrConnClosing, false, http2.ErrCodeNo, nil, nil) } if t.statsHandler != nil { connEnd := &stats.ConnEnd{ @@ -630,8 +685,8 @@ func (t *http2Client) Close() error { // closing. func (t *http2Client) GracefulClose() error { t.mu.Lock() - switch t.state { - case closing, draining: + // Make sure we move to draining only from active. + if t.state == draining || t.state == closing { t.mu.Unlock() return nil } @@ -647,110 +702,34 @@ func (t *http2Client) GracefulClose() error { // Write formats the data into HTTP2 data frame(s) and sends it out. The caller // should proceed only if Write returns nil. func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { - select { - case <-s.ctx.Done(): - return ContextErr(s.ctx.Err()) - case <-s.done: - return io.EOF - case <-t.ctx.Done(): - return ErrConnClosing - default: - } - - if hdr == nil && data == nil && opts.Last { - // stream.CloseSend uses this to send an empty frame with endStream=True - t.controlBuf.put(&dataFrame{streamID: s.id, endStream: true, f: func() {}}) - return nil - } - // Add data to header frame so that we can equally distribute data across frames. - emptyLen := http2MaxFrameLen - len(hdr) - if emptyLen > len(data) { - emptyLen = len(data) - } - hdr = append(hdr, data[:emptyLen]...) - data = data[emptyLen:] - var ( - streamQuota int - streamQuotaVer uint32 - err error - ) - for idx, r := range [][]byte{hdr, data} { - for len(r) > 0 { - size := http2MaxFrameLen - if size > len(r) { - size = len(r) - } - if streamQuota == 0 { // Used up all the locally cached stream quota. - // Get all the stream quota there is. - streamQuota, streamQuotaVer, err = s.sendQuotaPool.get(math.MaxInt32, s.waiters) - if err != nil { - return err - } - } - if size > streamQuota { - size = streamQuota - } - - // Get size worth quota from transport. - tq, _, err := t.sendQuotaPool.get(size, s.waiters) - if err != nil { - return err - } - if tq < size { - size = tq - } - ltq, _, err := t.localSendQuota.get(size, s.waiters) - if err != nil { - // Add the acquired quota back to transport. 
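Note on closeStream above: the deletion from activeStreams happens inside cleanupStream.onWrite, so the bookkeeping runs on the loopy goroutine when the cleanup item is processed, keeping map state ordered with what actually reaches (or would reach) the wire. The callback-carrying-item shape, in isolation (illustrative types, not the patch's):

```go
// cleanupItem bundles an action with a control item so it executes at
// write-processing time on the single writer goroutine, not at enqueue.
type cleanupItem struct {
	streamID uint32
	onWrite  func() // e.g. delete the stream from the active map
}

func handleCleanup(c *cleanupItem) {
	c.onWrite() // serialized with all other frame writes
	// ... then write RST_STREAM for c.streamID if requested
}
```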
- t.sendQuotaPool.add(tq) - return err - } - // even if ltq is smaller than size we don't adjust size since - // ltq is only a soft limit. - streamQuota -= size - p := r[:size] - var endStream bool - // See if this is the last frame to be written. - if opts.Last { - if len(r)-size == 0 { // No more data in r after this iteration. - if idx == 0 { // We're writing data header. - if len(data) == 0 { // There's no data to follow. - endStream = true - } - } else { // We're writing data. - endStream = true - } - } - } - success := func() { - ltq := ltq - t.controlBuf.put(&dataFrame{streamID: s.id, endStream: endStream, d: p, f: func() { t.localSendQuota.add(ltq) }}) - r = r[size:] - } - failure := func() { // The stream quota version must have changed. - // Our streamQuota cache is invalidated now, so give it back. - s.sendQuotaPool.lockedAdd(streamQuota + size) - } - if !s.sendQuotaPool.compareAndExecute(streamQuotaVer, success, failure) { - // Couldn't send this chunk out. - t.sendQuotaPool.add(size) - t.localSendQuota.add(ltq) - streamQuota = 0 - } + if opts.Last { + // If it's the last message, update stream state. + if !s.compareAndSwapState(streamActive, streamWriteDone) { + return errStreamDone } + } else if s.getState() != streamActive { + return errStreamDone } - if streamQuota > 0 { // Add the left over quota back to stream. - s.sendQuotaPool.add(streamQuota) - } - if !opts.Last { - return nil - } - s.mu.Lock() - if s.state != streamDone { - s.state = streamWriteDone + df := &dataFrame{ + streamID: s.id, + endStream: opts.Last, + } + if hdr != nil || data != nil { // If it's not an empty data frame. + // Add some data to grpc message header so that we can equally + // distribute bytes across frames. + emptyLen := http2MaxFrameLen - len(hdr) + if emptyLen > len(data) { + emptyLen = len(data) + } + hdr = append(hdr, data[:emptyLen]...) + data = data[emptyLen:] + df.h, df.d = hdr, data + // TODO(mmukhi): The above logic in this if can be moved to loopyWriter's data handler. + if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { + return err + } } - s.mu.Unlock() - return nil + return t.controlBuf.put(df) } func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) { @@ -764,34 +743,17 @@ func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) { // of stream if the application is requesting data larger in size than // the window. func (t *http2Client) adjustWindow(s *Stream, n uint32) { - s.mu.Lock() - defer s.mu.Unlock() - if s.state == streamDone { - return - } if w := s.fc.maybeAdjust(n); w > 0 { - // Piggyback connection's window update along. - if cw := t.fc.resetPendingUpdate(); cw > 0 { - t.controlBuf.put(&windowUpdate{0, cw}) - } - t.controlBuf.put(&windowUpdate{s.id, w}) + t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w}) } } -// updateWindow adjusts the inbound quota for the stream and the transport. -// Window updates will deliver to the controller for sending when -// the cumulative quota exceeds the corresponding threshold. +// updateWindow adjusts the inbound quota for the stream. +// Window updates will be sent out when the cumulative quota +// exceeds the corresponding threshold. 
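Note on the rewritten Write above: up to http2MaxFrameLen-len(hdr) bytes of the message are appended onto the 5-byte gRPC header slice so the first DATA frame is as full as later ones. A worked example with assumed sizes:

```go
// frameFill shows the emptyLen logic: a 5-byte hdr, a 20000-byte
// message, and a 16384-byte frame limit leave hdr at 16384 bytes and
// 3621 bytes in data, i.e. two well-filled DATA frames.
func frameFill(hdr, data []byte, maxFrame int) (h, d []byte) {
	fill := maxFrame - len(hdr)
	if fill > len(data) {
		fill = len(data)
	}
	return append(hdr, data[:fill]...), data[fill:]
}
```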
func (t *http2Client) updateWindow(s *Stream, n uint32) { - s.mu.Lock() - defer s.mu.Unlock() - if s.state == streamDone { - return - } if w := s.fc.onRead(n); w > 0 { - if cw := t.fc.resetPendingUpdate(); cw > 0 { - t.controlBuf.put(&windowUpdate{0, cw}) - } - t.controlBuf.put(&windowUpdate{s.id, w}) + t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w}) } } @@ -803,10 +765,13 @@ func (t *http2Client) updateFlowControl(n uint32) { for _, s := range t.activeStreams { s.fc.newLimit(n) } - t.initialWindowSize = int32(n) t.mu.Unlock() - t.controlBuf.put(&windowUpdate{0, t.fc.newLimit(n)}) - t.controlBuf.put(&settings{ + updateIWS := func(interface{}) bool { + t.initialWindowSize = int32(n) + return true + } + t.controlBuf.executeAndPut(updateIWS, &outgoingWindowUpdate{streamID: 0, increment: t.fc.newLimit(n)}) + t.controlBuf.put(&outgoingSettings{ ss: []http2.Setting{ { ID: http2.SettingInitialWindowSize, @@ -831,21 +796,24 @@ func (t *http2Client) handleData(f *http2.DataFrame) { // active(fast) streams from starving in presence of slow or // inactive streams. // - // Furthermore, if a bdpPing is being sent out we can piggyback - // connection's window update for the bytes we just received. + if w := t.fc.onData(uint32(size)); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: w, + }) + } if sendBDPPing { - if size != 0 { // Could've been an empty data frame. - t.controlBuf.put(&windowUpdate{0, uint32(size)}) + // Avoid excessive ping detection (e.g. in an L7 proxy) + // by sending a window update prior to the BDP ping. + + if w := t.fc.reset(); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: w, + }) } + t.controlBuf.put(bdpPing) - } else { - if err := t.fc.onData(uint32(size)); err != nil { - t.Close() - return - } - if w := t.fc.onRead(uint32(size)); w > 0 { - t.controlBuf.put(&windowUpdate{0, w}) - } } // Select the right stream to dispatch. s, ok := t.getStream(f) @@ -853,25 +821,15 @@ func (t *http2Client) handleData(f *http2.DataFrame) { return } if size > 0 { - s.mu.Lock() - if s.state == streamDone { - s.mu.Unlock() - return - } if err := s.fc.onData(uint32(size)); err != nil { - s.rstStream = true - s.rstError = http2.ErrCodeFlowControl - s.finish(status.New(codes.Internal, err.Error())) - s.mu.Unlock() - s.write(recvMsg{err: io.EOF}) + t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil) return } if f.Header().Flags.Has(http2.FlagDataPadded) { if w := s.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 { - t.controlBuf.put(&windowUpdate{s.id, w}) + t.controlBuf.put(&outgoingWindowUpdate{s.id, w}) } } - s.mu.Unlock() // TODO(bradfitz, zhaoq): A copy is required here because there is no // guarantee f.Data() is consumed before the arrival of next frame. // Can this copy be eliminated? @@ -884,14 +842,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) { // The server has closed the stream without sending trailers. Record that // the read direction is closed, and set the status appropriately. 
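Note on handleData above: the pending connection window is flushed as a WINDOW_UPDATE before the BDP ping goes out, so an L7 proxy does not see pings piling up without progress. The bdpEstimator itself lives outside this diff; roughly, it treats bytes that arrive between a ping and its ack as a bandwidth-delay-product sample and widens the window when the sample crowds it. A heavily hedged sketch of that idea (not the real bdpEstimator API):

```go
// bdpSampler: an illustrative approximation of BDP-based window tuning.
type bdpSampler struct {
	window uint32 // current receive window
	sample uint32 // bytes seen since the outstanding ping
	inPing bool   // a BDP ping is in flight
}

func (b *bdpSampler) onData(n uint32) (sendPing bool) {
	if b.inPing {
		b.sample += n
		return false
	}
	b.inPing, b.sample = true, n
	return true // caller emits a BDP ping
}

func (b *bdpSampler) onAck() {
	if b.sample >= b.window*3/4 {
		b.window *= 2 // the pipe looks bigger than the window; grow it
	}
	b.inPing = false
}
```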
if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) { - s.mu.Lock() - if s.state == streamDone { - s.mu.Unlock() - return - } - s.finish(status.New(codes.Internal, "server closed the stream without sending trailers")) - s.mu.Unlock() - s.write(recvMsg{err: io.EOF}) + t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil) } } @@ -900,73 +851,56 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) { if !ok { return } - s.mu.Lock() - if s.state == streamDone { - s.mu.Unlock() - return - } - if !s.headerDone { - close(s.headerChan) - s.headerDone = true - } - code := http2.ErrCode(f.ErrCode) if code == http2.ErrCodeRefusedStream { // The stream was unprocessed by the server. - s.unprocessed = true + atomic.StoreUint32(&s.unprocessed, 1) } statusCode, ok := http2ErrConvTab[code] if !ok { warningf("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error %v", f.ErrCode) statusCode = codes.Unknown } - s.finish(status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode)) - s.mu.Unlock() - s.write(recvMsg{err: io.EOF}) + t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode), nil) } func (t *http2Client) handleSettings(f *http2.SettingsFrame, isFirst bool) { if f.IsAck() { return } - var rs []http2.Setting - var ps []http2.Setting - isMaxConcurrentStreamsMissing := true + var maxStreams *uint32 + var ss []http2.Setting f.ForeachSetting(func(s http2.Setting) error { if s.ID == http2.SettingMaxConcurrentStreams { - isMaxConcurrentStreamsMissing = false - } - if t.isRestrictive(s) { - rs = append(rs, s) - } else { - ps = append(ps, s) + maxStreams = new(uint32) + *maxStreams = s.Val + return nil } + ss = append(ss, s) return nil }) - if isFirst && isMaxConcurrentStreamsMissing { - // This means server is imposing no limits on - // maximum number of concurrent streams initiated by client. - // So we must remove our self-imposed limit. - ps = append(ps, http2.Setting{ - ID: http2.SettingMaxConcurrentStreams, - Val: math.MaxUint32, - }) + if isFirst && maxStreams == nil { + maxStreams = new(uint32) + *maxStreams = math.MaxUint32 } - t.applySettings(rs) - t.controlBuf.put(&settingsAck{}) - t.applySettings(ps) -} - -func (t *http2Client) isRestrictive(s http2.Setting) bool { - switch s.ID { - case http2.SettingMaxConcurrentStreams: - return int(s.Val) < t.maxStreams - case http2.SettingInitialWindowSize: - // Note: we don't acquire a lock here to read streamSendQuota - // because the same goroutine updates it later. - return s.Val < t.streamSendQuota - } - return false + sf := &incomingSettings{ + ss: ss, + } + if maxStreams == nil { + t.controlBuf.put(sf) + return + } + updateStreamQuota := func(interface{}) bool { + delta := int64(*maxStreams) - int64(t.maxConcurrentStreams) + t.maxConcurrentStreams = *maxStreams + t.streamQuota += delta + if delta > 0 && t.waitingStreams > 0 { + close(t.streamsQuotaAvailable) // wake all of them up. 
+ t.streamsQuotaAvailable = make(chan struct{}, 1) + } + return true + } + t.controlBuf.executeAndPut(updateStreamQuota, sf) } func (t *http2Client) handlePing(f *http2.PingFrame) { @@ -984,7 +918,7 @@ func (t *http2Client) handlePing(f *http2.PingFrame) { func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { t.mu.Lock() - if t.state != reachable && t.state != draining { + if t.state == closing { t.mu.Unlock() return } @@ -1019,6 +953,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { t.setGoAwayReason(f) close(t.goAway) t.state = draining + t.controlBuf.put(&incomingGoAway{}) } // All streams with IDs greater than the GoAwayId // and smaller than the previous GoAway ID should be killed. @@ -1029,11 +964,8 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { for streamID, stream := range t.activeStreams { if streamID > id && streamID <= upperLimit { // The stream was unprocessed by the server. - stream.mu.Lock() - stream.unprocessed = true - stream.finish(statusGoAway) - stream.mu.Unlock() - close(stream.goAway) + atomic.StoreUint32(&stream.unprocessed, 1) + t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil) } } t.prevGoAwayID = id @@ -1065,15 +997,10 @@ func (t *http2Client) GetGoAwayReason() GoAwayReason { } func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) { - id := f.Header().StreamID - incr := f.Increment - if id == 0 { - t.sendQuotaPool.add(int(incr)) - return - } - if s, ok := t.getStream(f); ok { - s.sendQuotaPool.add(int(incr)) - } + t.controlBuf.put(&incomingWindowUpdate{ + streamID: f.Header().StreamID, + increment: f.Increment, + }) } // operateHeaders takes action on the decoded headers. @@ -1082,18 +1009,11 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { if !ok { return } - s.mu.Lock() - s.bytesReceived = true - s.mu.Unlock() + atomic.StoreUint32(&s.bytesReceived, 1) var state decodeState if err := state.decodeResponseHeader(frame); err != nil { - s.mu.Lock() - if !s.headerDone { - close(s.headerChan) - s.headerDone = true - } - s.mu.Unlock() - s.write(recvMsg{err: err}) + // TODO(mmukhi, dfawley): Perhaps send a reset stream. + t.closeStream(s, err, false, http2.ErrCodeNo, nil, nil) // Something wrong. Stops reading even when there is remaining. return } @@ -1117,40 +1037,25 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { } } }() - - s.mu.Lock() - if !s.headerDone { + // If headers haven't been received yet. + if atomic.SwapUint32(&s.headerDone, 1) == 0 { if !endStream { // Headers frame is not actually a trailers-only frame. isHeader = true + // These values can be set without any synchronization because + // stream goroutine will read it only after seeing a closed + // headerChan which we'll close after setting this. 
s.recvCompress = state.encoding if len(state.mdata) > 0 { s.header = state.mdata } } close(s.headerChan) - s.headerDone = true } - if !endStream || s.state == streamDone { - s.mu.Unlock() + if !endStream { return } - if len(state.mdata) > 0 { - s.trailer = state.mdata - } - s.finish(state.status()) - s.mu.Unlock() - s.write(recvMsg{err: io.EOF}) -} - -func handleMalformedHTTP2(s *Stream, err error) { - s.mu.Lock() - if !s.headerDone { - close(s.headerChan) - s.headerDone = true - } - s.mu.Unlock() - s.write(recvMsg{err: err}) + t.closeStream(s, io.EOF, false, http2.ErrCodeNo, state.status(), state.mdata) } // reader runs as a separate goroutine in charge of reading data from network @@ -1160,6 +1065,7 @@ func handleMalformedHTTP2(s *Stream, err error) { // optimal. // TODO(zhaoq): Check the validity of the incoming frame sequence. func (t *http2Client) reader() { + defer close(t.readerDone) // Check the validity of server preface. frame, err := t.framer.fr.ReadFrame() if err != nil { @@ -1189,7 +1095,8 @@ func (t *http2Client) reader() { t.mu.Unlock() if s != nil { // use error detail to provide better err message - handleMalformedHTTP2(s, streamErrorf(http2ErrConvTab[se.Code], "%v", t.framer.fr.ErrorDetail())) + // TODO(mmukhi, dfawley): Perhaps send a RST_STREAM to the server. + t.closeStream(s, streamErrorf(http2ErrConvTab[se.Code], "%v", t.framer.fr.ErrorDetail()), false, http2.ErrCodeNo, nil, nil) } continue } else { @@ -1219,109 +1126,6 @@ func (t *http2Client) reader() { } } -func (t *http2Client) applySettings(ss []http2.Setting) { - for _, s := range ss { - switch s.ID { - case http2.SettingMaxConcurrentStreams: - // TODO(zhaoq): This is a hack to avoid significant refactoring of the - // code to deal with the unrealistic int32 overflow. Probably will try - // to find a better way to handle this later. - if s.Val > math.MaxInt32 { - s.Val = math.MaxInt32 - } - ms := t.maxStreams - t.maxStreams = int(s.Val) - t.streamsQuota.add(int(s.Val) - ms) - case http2.SettingInitialWindowSize: - t.mu.Lock() - for _, stream := range t.activeStreams { - // Adjust the sending quota for each stream. - stream.sendQuotaPool.addAndUpdate(int(s.Val) - int(t.streamSendQuota)) - } - t.streamSendQuota = s.Val - t.mu.Unlock() - } - } -} - -// TODO(mmukhi): A lot of this code(and code in other places in the tranpsort layer) -// is duplicated between the client and the server. -// The transport layer needs to be refactored to take care of this. -func (t *http2Client) itemHandler(i item) (err error) { - defer func() { - if err != nil { - errorf(" error in itemHandler: %v", err) - } - }() - switch i := i.(type) { - case *dataFrame: - if err := t.framer.fr.WriteData(i.streamID, i.endStream, i.d); err != nil { - return err - } - i.f() - return nil - case *headerFrame: - t.hBuf.Reset() - for _, f := range i.hf { - t.hEnc.WriteField(f) - } - endHeaders := false - first := true - for !endHeaders { - size := t.hBuf.Len() - if size > http2MaxFrameLen { - size = http2MaxFrameLen - } else { - endHeaders = true - } - if first { - first = false - err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{ - StreamID: i.streamID, - BlockFragment: t.hBuf.Next(size), - EndStream: i.endStream, - EndHeaders: endHeaders, - }) - } else { - err = t.framer.fr.WriteContinuation( - i.streamID, - endHeaders, - t.hBuf.Next(size), - ) - } - if err != nil { - return err - } - } - return nil - case *windowUpdate: - return t.framer.fr.WriteWindowUpdate(i.streamID, i.increment) - case *settings: - return t.framer.fr.WriteSettings(i.ss...) 
- case *settingsAck: - return t.framer.fr.WriteSettingsAck() - case *resetStream: - // If the server needs to be to intimated about stream closing, - // then we need to make sure the RST_STREAM frame is written to - // the wire before the headers of the next stream waiting on - // streamQuota. We ensure this by adding to the streamsQuota pool - // only after having acquired the writableChan to send RST_STREAM. - err := t.framer.fr.WriteRSTStream(i.streamID, i.code) - t.streamsQuota.add(1) - return err - case *flushIO: - return t.framer.writer.Flush() - case *ping: - if !i.ack { - t.bdpEst.timesnap(i.data) - } - return t.framer.fr.WritePing(i.ack, i.data) - default: - errorf("transport: http2Client.controller got unexpected item type %v", i) - return fmt.Errorf("transport: http2Client.controller got unexpected item type %v", i) - } -} - // keepalive running in a separate goroutine makes sure the connection is alive by sending pings. func (t *http2Client) keepalive() { p := &ping{data: [8]byte{}} diff --git a/transport/http2_server.go b/transport/http2_server.go index 97b214c640ef..cd9f25a617e1 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -52,28 +52,25 @@ var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHe // http2Server implements the ServerTransport interface with HTTP2. type http2Server struct { ctx context.Context + ctxDone <-chan struct{} // Cache the context.Done() chan cancel context.CancelFunc conn net.Conn + loopy *loopyWriter + readerDone chan struct{} // sync point to enable testing. + writerDone chan struct{} // sync point to enable testing. remoteAddr net.Addr localAddr net.Addr maxStreamID uint32 // max stream ID ever seen authInfo credentials.AuthInfo // auth info about the connection inTapHandle tap.ServerInHandle framer *framer - hBuf *bytes.Buffer // the buffer for HPACK encoding - hEnc *hpack.Encoder // HPACK encoder // The max number of concurrent streams. maxStreams uint32 // controlBuf delivers all the control related tasks (e.g., window // updates, reset streams, and various settings) to the controller. controlBuf *controlBuffer - fc *inFlow - // sendQuotaPool provides flow control to outbound message. - sendQuotaPool *quotaPool - // localSendQuota limits the amount of data that can be scheduled - // for writing before it is actually written out. - localSendQuota *quotaPool - stats stats.Handler + fc *trInFlow + stats stats.Handler // Flag to keep track of reading activity on transport. // 1 is true and 0 is false. activity uint32 // Accessed atomically. @@ -104,8 +101,6 @@ type http2Server struct { drainChan chan struct{} state transportState activeStreams map[uint32]*Stream - // the per-stream outbound flow control window size set by the peer. - streamSendQuota uint32 // idle is the time instant when the connection went idle. // This is either the beginning of the connection or when the number of // RPCs go down to 0.
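Everything the transports send (window updates, settings, resets, headers, data, goaways) now funnels through controlBuf, and a single loopy writer drains it; that is what lets the per-item locks and quota pools above disappear from the struct. Below is a minimal, compilable sketch of that producer/consumer shape. itemBuffer and every name in it are illustrative only, not the controlBuffer this patch adds in transport/controlbuf.go.

package main

import (
	"errors"
	"fmt"
	"sync"
)

// itemBuffer is an unbounded FIFO of control items with a one-slot wakeup
// channel, so producers never block and a single consumer can park cheaply.
type itemBuffer struct {
	mu     sync.Mutex
	items  []interface{}
	wakeup chan struct{} // capacity 1; signaled when the buffer turns non-empty
	done   bool
}

func newItemBuffer() *itemBuffer {
	return &itemBuffer{wakeup: make(chan struct{}, 1)}
}

// executeAndPut runs f under the buffer's lock and enqueues the item only if
// f returns true (or is nil), so a state update and the frame it implies are
// atomic with respect to other producers.
func (b *itemBuffer) executeAndPut(f func() bool, it interface{}) error {
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.done {
		return errors.New("buffer closed")
	}
	if f != nil && !f() {
		return nil
	}
	b.items = append(b.items, it)
	select {
	case b.wakeup <- struct{}{}: // wake the writer if it is parked
	default: // already signaled
	}
	return nil
}

func (b *itemBuffer) put(it interface{}) error { return b.executeAndPut(nil, it) }

// get blocks until an item is available and re-arms the wakeup channel if
// more items remain. A real implementation also needs a finish/close path
// that releases parked callers; that is elided here.
func (b *itemBuffer) get() interface{} {
	for {
		<-b.wakeup
		b.mu.Lock()
		if len(b.items) == 0 {
			b.mu.Unlock()
			continue
		}
		it := b.items[0]
		b.items = b.items[1:]
		if len(b.items) > 0 {
			select {
			case b.wakeup <- struct{}{}:
			default:
			}
		}
		b.mu.Unlock()
		return it
	}
}

func main() {
	b := newItemBuffer()
	b.put("outgoingWindowUpdate")
	quota := int64(0)
	b.executeAndPut(func() bool { quota += 100; return true }, "incomingSettings")
	fmt.Println(b.get(), b.get(), "quota:", quota)
}

The executeAndPut form mirrors how handleSettings above updates the client's stream quota and enqueues the settings frame under one lock, so no other frame can slip in between the state change and the control item it implies.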
@@ -185,33 +180,30 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err if kep.MinTime == 0 { kep.MinTime = defaultKeepalivePolicyMinTime } - var buf bytes.Buffer ctx, cancel := context.WithCancel(context.Background()) t := &http2Server{ ctx: ctx, cancel: cancel, + ctxDone: ctx.Done(), conn: conn, remoteAddr: conn.RemoteAddr(), localAddr: conn.LocalAddr(), authInfo: config.AuthInfo, framer: framer, - hBuf: &buf, - hEnc: hpack.NewEncoder(&buf), + readerDone: make(chan struct{}), + writerDone: make(chan struct{}), maxStreams: maxStreams, inTapHandle: config.InTapHandle, - controlBuf: newControlBuffer(), - fc: &inFlow{limit: uint32(icwz)}, - sendQuotaPool: newQuotaPool(defaultWindowSize), - localSendQuota: newQuotaPool(defaultLocalSendQuota), + fc: &trInFlow{limit: uint32(icwz)}, state: reachable, activeStreams: make(map[uint32]*Stream), - streamSendQuota: defaultWindowSize, stats: config.StatsHandler, kp: kp, idle: time.Now(), kep: kep, initialWindowSize: iwz, } + t.controlBuf = newControlBuffer(t.ctxDone) if dynamicWindow { t.bdpEst = &bdpEstimator{ bdp: initialWindowSize, @@ -258,8 +250,11 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err t.handleSettings(sf) go func() { - loopyWriter(t.ctx, t.controlBuf, t.itemHandler) + t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst) + t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler + t.loopy.run() t.conn.Close() + close(t.writerDone) }() go t.keepalive() return t, nil @@ -268,12 +263,16 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err // operateHeader takes action on the decoded headers. func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (close bool) { streamID := frame.Header().StreamID - var state decodeState for _, hf := range frame.Fields { if err := state.processHeaderField(hf); err != nil { if se, ok := err.(StreamError); ok { - t.controlBuf.put(&resetStream{streamID, statusCodeConvTab[se.Code]}) + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: statusCodeConvTab[se.Code], + onWrite: func() {}, + }) } return } @@ -325,7 +324,12 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( s.ctx, err = t.inTapHandle(s.ctx, info) if err != nil { warningf("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err) - t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) + t.controlBuf.put(&cleanupStream{ + streamID: s.id, + rst: true, + rstCode: http2.ErrCodeRefusedStream, + onWrite: func() {}, + }) return } } @@ -336,7 +340,12 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( } if uint32(len(t.activeStreams)) >= t.maxStreams { t.mu.Unlock() - t.controlBuf.put(&resetStream{streamID, http2.ErrCodeRefusedStream}) + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: http2.ErrCodeRefusedStream, + onWrite: func() {}, + }) return } if streamID%2 != 1 || streamID <= t.maxStreamID { @@ -346,7 +355,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( return true } t.maxStreamID = streamID - s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota)) t.activeStreams[streamID] = s if len(t.activeStreams) == 1 { t.idle = time.Time{} @@ -367,19 +375,18 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( } t.stats.HandleRPC(s.ctx, inHeader) } + 
s.ctxDone = s.ctx.Done() + s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone) s.trReader = &transportReader{ reader: &recvBufferReader{ - ctx: s.ctx, - recv: s.buf, + ctx: s.ctx, + ctxDone: s.ctxDone, + recv: s.buf, }, windowHandler: func(n int) { t.updateWindow(s, uint32(n)) }, } - s.waiters = waiters{ - ctx: s.ctx, - tctx: t.ctx, - } handle(s) return } @@ -388,18 +395,26 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( // typically run in a separate goroutine. // traceCtx attaches trace to ctx and returns the new context. func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) { + defer close(t.readerDone) for { frame, err := t.framer.fr.ReadFrame() atomic.StoreUint32(&t.activity, 1) if err != nil { if se, ok := err.(http2.StreamError); ok { + warningf("transport: http2Server.HandleStreams encountered http2.StreamError: %v", se) t.mu.Lock() s := t.activeStreams[se.StreamID] t.mu.Unlock() if s != nil { - t.closeStream(s) + t.closeStream(s, true, se.Code, nil) + } else { + t.controlBuf.put(&cleanupStream{ + streamID: se.StreamID, + rst: true, + rstCode: se.Code, + onWrite: func() {}, + }) } - t.controlBuf.put(&resetStream{se.StreamID, se.Code}) continue } if err == io.EOF || err == io.ErrUnexpectedEOF { @@ -453,33 +468,20 @@ func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) { // of stream if the application is requesting data larger in size than // the window. func (t *http2Server) adjustWindow(s *Stream, n uint32) { - s.mu.Lock() - defer s.mu.Unlock() - if s.state == streamDone { - return - } if w := s.fc.maybeAdjust(n); w > 0 { - if cw := t.fc.resetPendingUpdate(); cw > 0 { - t.controlBuf.put(&windowUpdate{0, cw}) - } - t.controlBuf.put(&windowUpdate{s.id, w}) + t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w}) } + } // updateWindow adjusts the inbound quota for the stream and the transport. // Window updates will deliver to the controller for sending when // the cumulative quota exceeds the corresponding threshold. func (t *http2Server) updateWindow(s *Stream, n uint32) { - s.mu.Lock() - defer s.mu.Unlock() - if s.state == streamDone { - return - } if w := s.fc.onRead(n); w > 0 { - if cw := t.fc.resetPendingUpdate(); cw > 0 { - t.controlBuf.put(&windowUpdate{0, cw}) - } - t.controlBuf.put(&windowUpdate{s.id, w}) + t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, + increment: w, + }) } } @@ -493,8 +495,11 @@ func (t *http2Server) updateFlowControl(n uint32) { } t.initialWindowSize = int32(n) t.mu.Unlock() - t.controlBuf.put(&windowUpdate{0, t.fc.newLimit(n)}) - t.controlBuf.put(&settings{ + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: t.fc.newLimit(n), + }) + t.controlBuf.put(&outgoingSettings{ ss: []http2.Setting{ { ID: http2.SettingInitialWindowSize, @@ -519,23 +524,22 @@ func (t *http2Server) handleData(f *http2.DataFrame) { // Decoupling the connection flow control will prevent other // active(fast) streams from starving in presence of slow or // inactive streams. - // - // Furthermore, if a bdpPing is being sent out we can piggyback - // connection's window update for the bytes we just received. + if w := t.fc.onData(uint32(size)); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: w, + }) + } if sendBDPPing { - if size != 0 { // Could be an empty frame. - t.controlBuf.put(&windowUpdate{0, uint32(size)}) + // Avoid excessive ping detection (e.g. 
in an L7 proxy) + // by sending a window update prior to the BDP ping. + if w := t.fc.reset(); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: w, + }) } t.controlBuf.put(bdpPing) - } else { - if err := t.fc.onData(uint32(size)); err != nil { - errorf("transport: http2Server %v", err) - t.Close() - return - } - if w := t.fc.onRead(uint32(size)); w > 0 { - t.controlBuf.put(&windowUpdate{0, w}) - } } // Select the right stream to dispatch. s, ok := t.getStream(f) @@ -543,23 +547,15 @@ return } if size > 0 { - s.mu.Lock() - if s.state == streamDone { - s.mu.Unlock() - return - } if err := s.fc.onData(uint32(size)); err != nil { - s.mu.Unlock() - t.closeStream(s) - t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl}) + t.closeStream(s, true, http2.ErrCodeFlowControl, nil) return } if f.Header().Flags.Has(http2.FlagDataPadded) { if w := s.fc.onRead(uint32(size) - uint32(len(f.Data()))); w > 0 { - t.controlBuf.put(&windowUpdate{s.id, w}) + t.controlBuf.put(&outgoingWindowUpdate{s.id, w}) } } - s.mu.Unlock() // TODO(bradfitz, zhaoq): A copy is required here because there is no // guarantee f.Data() is consumed before the arrival of next frame. // Can this copy be eliminated? @@ -571,11 +567,7 @@ } if f.Header().Flags.Has(http2.FlagDataEndStream) { // Received the end of stream from the client. - s.mu.Lock() - if s.state != streamDone { - s.state = streamReadDone - } - s.mu.Unlock() + s.compareAndSwapState(streamActive, streamReadDone) s.write(recvMsg{err: io.EOF}) } } @@ -585,50 +577,21 @@ func (t *http2Server) handleRSTStream(f *http2.RSTStreamFrame) { if !ok { return } - t.closeStream(s) + t.closeStream(s, false, 0, nil) } func (t *http2Server) handleSettings(f *http2.SettingsFrame) { if f.IsAck() { return } - var rs []http2.Setting - var ps []http2.Setting + var ss []http2.Setting f.ForeachSetting(func(s http2.Setting) error { - if t.isRestrictive(s) { - rs = append(rs, s) - } else { - ps = append(ps, s) - } + ss = append(ss, s) return nil }) - t.applySettings(rs) - t.controlBuf.put(&settingsAck{}) - t.applySettings(ps) -} - -func (t *http2Server) isRestrictive(s http2.Setting) bool { - switch s.ID { - case http2.SettingInitialWindowSize: - // Note: we don't acquire a lock here to read streamSendQuota - // because the same goroutine updates it later. - return s.Val < t.streamSendQuota - } - return false -} - -func (t *http2Server) applySettings(ss []http2.Setting) { - for _, s := range ss { - if s.ID == http2.SettingInitialWindowSize { - t.mu.Lock() - for _, stream := range t.activeStreams { - stream.sendQuotaPool.addAndUpdate(int(s.Val) - int(t.streamSendQuota)) - } - t.streamSendQuota = s.Val - t.mu.Unlock() - } - - } + t.controlBuf.put(&incomingSettings{ + ss: ss, + }) } const ( @@ -687,30 +650,15 @@ func (t *http2Server) handlePing(f *http2.PingFrame) { } func (t *http2Server) handleWindowUpdate(f *http2.WindowUpdateFrame) { - id := f.Header().StreamID - incr := f.Increment - if id == 0 { - t.sendQuotaPool.add(int(incr)) - return - } - if s, ok := t.getStream(f); ok { - s.sendQuotaPool.add(int(incr)) - } + t.controlBuf.put(&incomingWindowUpdate{ + streamID: f.Header().StreamID, + increment: f.Increment, + }) } // WriteHeader sends the header metadata md back to the client.
func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { - select { - case <-s.ctx.Done(): - return ContextErr(s.ctx.Err()) - case <-t.ctx.Done(): - return ErrConnClosing - default: - } - - s.mu.Lock() - if s.headerOk || s.state == streamDone { - s.mu.Unlock() + if s.headerOk || s.getState() == streamDone { return ErrIllegalHeaderWrite } s.headerOk = true @@ -722,7 +670,6 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { } } md = s.header - s.mu.Unlock() // TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields // first and create a slice of that exact size. headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else. @@ -744,6 +691,10 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { streamID: s.id, hf: headerFields, endStream: false, + onWrite: func() { + atomic.StoreUint32(&t.resetPingStrikes, 1) + }, + wq: s.wq, }) if t.stats != nil { // Note: WireLength is not set in outHeader. @@ -759,35 +710,19 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { // TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early // OK is adopted. func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { - select { - case <-t.ctx.Done(): - return ErrConnClosing - default: - } - - var headersSent, hasHeader bool - s.mu.Lock() - if s.state == streamDone { - s.mu.Unlock() - return nil - } - if s.headerOk { - headersSent = true - } - if s.header.Len() > 0 { - hasHeader = true - } - s.mu.Unlock() - - if !headersSent && hasHeader { - t.WriteHeader(s, nil) - headersSent = true + if !s.headerOk && s.header.Len() > 0 { + if err := t.WriteHeader(s, nil); err != nil { + return err + } + } else { + if s.getState() == streamDone { + return nil + } } - // TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields // first and create a slice of that exact size. headerFields := make([]hpack.HeaderField, 0, 2) // grpc-status and grpc-message will be there if none else. - if !headersSent { + if !s.headerOk { headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)}) } @@ -814,108 +749,66 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) } } - t.controlBuf.put(&headerFrame{ + trailer := &headerFrame{ streamID: s.id, hf: headerFields, endStream: true, - }) + onWrite: func() { + atomic.StoreUint32(&t.resetPingStrikes, 1) + }, + } + t.closeStream(s, false, 0, trailer) if t.stats != nil { t.stats.HandleRPC(s.Context(), &stats.OutTrailer{}) } - t.closeStream(s) return nil } // Write converts the data into HTTP2 data frame and sends it out. Non-nil error // is returns if it fails (e.g., framing error, transport error). func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { - select { - case <-s.ctx.Done(): - return ContextErr(s.ctx.Err()) - case <-t.ctx.Done(): - return ErrConnClosing - default: - } - - var writeHeaderFrame bool - s.mu.Lock() - if !s.headerOk { - writeHeaderFrame = true - } - s.mu.Unlock() - if writeHeaderFrame { - t.WriteHeader(s, nil) + if !s.headerOk { // Headers haven't been written yet. 
+ if err := t.WriteHeader(s, nil); err != nil { + // TODO(mmukhi, dfawley): Make sure this is the right code to return. + return streamErrorf(codes.Internal, "transport: %v", err) + } + } else { + // Writing headers checks for this condition. + if s.getState() == streamDone { + // TODO(mmukhi, dfawley): Should the server write also return io.EOF? + s.cancel() + select { + case <-t.ctx.Done(): + return ErrConnClosing + default: + } + return ContextErr(s.ctx.Err()) + } } - // Add data to header frame so that we can equally distribute data across frames. + // Add some data to header frame so that we can equally distribute bytes across frames. emptyLen := http2MaxFrameLen - len(hdr) if emptyLen > len(data) { emptyLen = len(data) } hdr = append(hdr, data[:emptyLen]...) data = data[emptyLen:] - var ( - streamQuota int - streamQuotaVer uint32 - err error - ) - for _, r := range [][]byte{hdr, data} { - for len(r) > 0 { - size := http2MaxFrameLen - if size > len(r) { - size = len(r) - } - if streamQuota == 0 { // Used up all the locally cached stream quota. - // Get all the stream quota there is. - streamQuota, streamQuotaVer, err = s.sendQuotaPool.get(math.MaxInt32, s.waiters) - if err != nil { - return err - } - } - if size > streamQuota { - size = streamQuota - } - // Get size worth quota from transport. - tq, _, err := t.sendQuotaPool.get(size, s.waiters) - if err != nil { - return err - } - if tq < size { - size = tq - } - ltq, _, err := t.localSendQuota.get(size, s.waiters) - if err != nil { - // Add the acquired quota back to transport. - t.sendQuotaPool.add(tq) - return err - } - // even if ltq is smaller than size we don't adjust size since, - // ltq is only a soft limit. - streamQuota -= size - p := r[:size] - success := func() { - ltq := ltq - t.controlBuf.put(&dataFrame{streamID: s.id, endStream: false, d: p, f: func() { - t.localSendQuota.add(ltq) - }}) - r = r[size:] - } - failure := func() { // The stream quota version must have changed. - // Our streamQuota cache is invalidated now, so give it back. - s.sendQuotaPool.lockedAdd(streamQuota + size) - } - if !s.sendQuotaPool.compareAndExecute(streamQuotaVer, success, failure) { - // Couldn't send this chunk out. - t.sendQuotaPool.add(size) - t.localSendQuota.add(ltq) - streamQuota = 0 - } - } + df := &dataFrame{ + streamID: s.id, + h: hdr, + d: data, + onEachWrite: func() { + atomic.StoreUint32(&t.resetPingStrikes, 1) + }, } - if streamQuota > 0 { - // ADd the left over quota back to stream. - s.sendQuotaPool.add(streamQuota) + if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { + select { + case <-t.ctx.Done(): + return ErrConnClosing + default: + } + return ContextErr(s.ctx.Err()) } - return nil + return t.controlBuf.put(df) } // keepalive running in a separate goroutine does the following: @@ -998,136 +891,6 @@ func (t *http2Server) keepalive() { } } -var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}} - -// TODO(mmukhi): A lot of this code(and code in other places in the tranpsort layer) -// is duplicated between the client and the server. -// The transport layer needs to be refactored to take care of this. -func (t *http2Server) itemHandler(i item) error { - switch i := i.(type) { - case *dataFrame: - // Reset ping strikes when sending data since this might cause - // the peer to send ping. 
- atomic.StoreUint32(&t.resetPingStrikes, 1) - if err := t.framer.fr.WriteData(i.streamID, i.endStream, i.d); err != nil { - return err - } - i.f() - return nil - case *headerFrame: - t.hBuf.Reset() - for _, f := range i.hf { - t.hEnc.WriteField(f) - } - first := true - endHeaders := false - for !endHeaders { - size := t.hBuf.Len() - if size > http2MaxFrameLen { - size = http2MaxFrameLen - } else { - endHeaders = true - } - var err error - if first { - first = false - err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{ - StreamID: i.streamID, - BlockFragment: t.hBuf.Next(size), - EndStream: i.endStream, - EndHeaders: endHeaders, - }) - } else { - err = t.framer.fr.WriteContinuation( - i.streamID, - endHeaders, - t.hBuf.Next(size), - ) - } - if err != nil { - return err - } - } - atomic.StoreUint32(&t.resetPingStrikes, 1) - return nil - case *windowUpdate: - return t.framer.fr.WriteWindowUpdate(i.streamID, i.increment) - case *settings: - return t.framer.fr.WriteSettings(i.ss...) - case *settingsAck: - return t.framer.fr.WriteSettingsAck() - case *resetStream: - return t.framer.fr.WriteRSTStream(i.streamID, i.code) - case *goAway: - t.mu.Lock() - if t.state == closing { - t.mu.Unlock() - // The transport is closing. - return fmt.Errorf("transport: Connection closing") - } - sid := t.maxStreamID - if !i.headsUp { - // Stop accepting more streams now. - t.state = draining - if len(t.activeStreams) == 0 { - i.closeConn = true - } - t.mu.Unlock() - if err := t.framer.fr.WriteGoAway(sid, i.code, i.debugData); err != nil { - return err - } - if i.closeConn { - // Abruptly close the connection following the GoAway (via - // loopywriter). But flush out what's inside the buffer first. - t.controlBuf.put(&flushIO{closeTr: true}) - } - return nil - } - t.mu.Unlock() - // For a graceful close, send out a GoAway with stream ID of MaxUInt32, - // Follow that with a ping and wait for the ack to come back or a timer - // to expire. During this time accept new streams since they might have - // originated before the GoAway reaches the client. - // After getting the ack or timer expiration send out another GoAway this - // time with an ID of the max stream server intends to process. - if err := t.framer.fr.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}); err != nil { - return err - } - if err := t.framer.fr.WritePing(false, goAwayPing.data); err != nil { - return err - } - go func() { - timer := time.NewTimer(time.Minute) - defer timer.Stop() - select { - case <-t.drainChan: - case <-timer.C: - case <-t.ctx.Done(): - return - } - t.controlBuf.put(&goAway{code: i.code, debugData: i.debugData}) - }() - return nil - case *flushIO: - if err := t.framer.writer.Flush(); err != nil { - return err - } - if i.closeTr { - return ErrConnClosing - } - return nil - case *ping: - if !i.ack { - t.bdpEst.timesnap(i.data) - } - return t.framer.fr.WritePing(i.ack, i.data) - default: - err := status.Errorf(codes.Internal, "transport: http2Server.controller got unexpected item type %t", i) - errorf("%v", err) - return err - } -} - // Close starts shutting down the http2Server transport. // TODO(zhaoq): Now the destruction is not blocked on any pending streams. This // could cause some resource issue. Revisit this later. @@ -1141,6 +904,7 @@ func (t *http2Server) Close() error { streams := t.activeStreams t.activeStreams = nil t.mu.Unlock() + t.controlBuf.finish() t.cancel() err := t.conn.Close() // Cancel all active streams. 
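Both Write paths above stage a dataFrame only after s.wq.get has admitted the byte count, so a stream cannot buffer unbounded data ahead of the loopy writer. Here is a rough, runnable sketch of such a replenishable write quota, deliberately simpler than the writeQuota this patch introduces; all names and the demo in main are illustrative.

package main

import (
	"errors"
	"fmt"
	"sync/atomic"
	"time"
)

// writeQuota caps how many bytes a stream may stage into the control buffer
// before the writer has drained them. get blocks until quota is positive or
// the stream ends; replenish is called as frames are actually written out.
type writeQuota struct {
	quota int32
	ch    chan struct{} // signaled on replenish
	done  <-chan struct{}
}

func newWriteQuota(sz int32, done <-chan struct{}) *writeQuota {
	return &writeQuota{quota: sz, ch: make(chan struct{}, 1), done: done}
}

func (w *writeQuota) get(sz int32) error {
	for {
		if atomic.LoadInt32(&w.quota) > 0 {
			// The balance may go negative: this is a soft limit, so one
			// oversized message can always make progress.
			atomic.AddInt32(&w.quota, -sz)
			return nil
		}
		select {
		case <-w.ch:
		case <-w.done:
			return errors.New("stream done")
		}
	}
}

func (w *writeQuota) replenish(n int32) {
	if atomic.AddInt32(&w.quota, n) > 0 {
		select {
		case w.ch <- struct{}{}: // wake a blocked writer
		default:
		}
	}
}

func main() {
	streamDone := make(chan struct{})
	wq := newWriteQuota(8, streamDone)
	go func() {
		// Pretend the loopy writer drains a frame a moment later.
		time.Sleep(10 * time.Millisecond)
		wq.replenish(16)
	}()
	if err := wq.get(8); err != nil { // admitted against the initial quota
		fmt.Println(err)
		return
	}
	if err := wq.get(8); err != nil { // blocks until the replenish above fires
		fmt.Println(err)
		return
	}
	fmt.Println("both writes admitted")
}

In the real transport the replenish side is driven as frames drain, and done would be the stream's ctxDone, which matches how Write above maps a failed wq.get to ContextErr or ErrConnClosing.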
@@ -1156,27 +920,36 @@ func (t *http2Server) Close() error { // closeStream clears the footprint of a stream when the stream is not needed // any more. -func (t *http2Server) closeStream(s *Stream) { - t.mu.Lock() - delete(t.activeStreams, s.id) - if len(t.activeStreams) == 0 { - t.idle = time.Now() - } - if t.state == draining && len(t.activeStreams) == 0 { - defer t.controlBuf.put(&flushIO{closeTr: true}) +func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, hdr *headerFrame) { + if s.swapState(streamDone) == streamDone { + // If the stream was already done, return. + return } - t.mu.Unlock() // In case stream sending and receiving are invoked in separate // goroutines (e.g., bi-directional streaming), cancel needs to be // called to interrupt the potential blocking on other goroutines. s.cancel() - s.mu.Lock() - if s.state == streamDone { - s.mu.Unlock() - return + cleanup := &cleanupStream{ + streamID: s.id, + rst: rst, + rstCode: rstCode, + onWrite: func() { + t.mu.Lock() + if t.activeStreams != nil { + delete(t.activeStreams, s.id) + if len(t.activeStreams) == 0 { + t.idle = time.Now() + } + } + t.mu.Unlock() + }, + } + if hdr != nil { + hdr.cleanup = cleanup + t.controlBuf.put(hdr) + } else { + t.controlBuf.put(cleanup) } - s.state = streamDone - s.mu.Unlock() } func (t *http2Server) RemoteAddr() net.Addr { @@ -1197,6 +970,63 @@ func (t *http2Server) drain(code http2.ErrCode, debugData []byte) { t.controlBuf.put(&goAway{code: code, debugData: debugData, headsUp: true}) } +var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}} + +// Handles outgoing GoAway and returns true if loopy needs to put itself +// in draining mode. +func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) { + t.mu.Lock() + if t.state == closing { // TODO(mmukhi): This seems unnecessary. + t.mu.Unlock() + // The transport is closing. + return false, ErrConnClosing + } + sid := t.maxStreamID + if !g.headsUp { + // Stop accepting more streams now. + t.state = draining + if len(t.activeStreams) == 0 { + g.closeConn = true + } + t.mu.Unlock() + if err := t.framer.fr.WriteGoAway(sid, g.code, g.debugData); err != nil { + return false, err + } + if g.closeConn { + // Abruptly close the connection following the GoAway (via + // loopywriter). But flush out what's inside the buffer first. + t.framer.writer.Flush() + return false, fmt.Errorf("transport: Connection closing") + } + return true, nil + } + t.mu.Unlock() + // For a graceful close, send out a GoAway with stream ID of MaxUInt32, + // Follow that with a ping and wait for the ack to come back or a timer + // to expire. During this time accept new streams since they might have + // originated before the GoAway reaches the client. + // After getting the ack or timer expiration send out another GoAway this + // time with an ID of the max stream server intends to process. 
+ if err := t.framer.fr.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}); err != nil { + return false, err + } + if err := t.framer.fr.WritePing(false, goAwayPing.data); err != nil { + return false, err + } + go func() { + timer := time.NewTimer(time.Minute) + defer timer.Stop() + select { + case <-t.drainChan: + case <-timer.C: + case <-t.ctx.Done(): + return + } + t.controlBuf.put(&goAway{code: g.code, debugData: g.debugData}) + }() + return false, nil +} + var rgen = rand.New(rand.NewSource(time.Now().UnixNano())) func getJitter(v time.Duration) time.Duration { diff --git a/transport/http_util.go b/transport/http_util.go index de37e38ec9f4..b595d3d18e0f 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -23,7 +23,6 @@ import ( "bytes" "encoding/base64" "fmt" - "io" "net" "net/http" "strconv" @@ -509,19 +508,63 @@ func decodeGrpcMessageUnchecked(msg string) string { return buf.String() } +type bufWriter struct { + buf []byte + offset int + batchSize int + conn net.Conn + err error + + onFlush func() +} + +func newBufWriter(conn net.Conn, batchSize int) *bufWriter { + return &bufWriter{ + buf: make([]byte, batchSize*2), + batchSize: batchSize, + conn: conn, + } +} + +func (w *bufWriter) Write(b []byte) (n int, err error) { + if w.err != nil { + return 0, w.err + } + n = copy(w.buf[w.offset:], b) + w.offset += n + if w.offset >= w.batchSize { + err = w.Flush() + } + return n, err +} + +func (w *bufWriter) Flush() error { + if w.err != nil { + return w.err + } + if w.offset == 0 { + return nil + } + if w.onFlush != nil { + w.onFlush() + } + _, w.err = w.conn.Write(w.buf[:w.offset]) + w.offset = 0 + return w.err +} + type framer struct { - numWriters int32 - reader io.Reader - writer *bufio.Writer - fr *http2.Framer + writer *bufWriter + fr *http2.Framer } func newFramer(conn net.Conn, writeBufferSize, readBufferSize int) *framer { + r := bufio.NewReaderSize(conn, readBufferSize) + w := newBufWriter(conn, writeBufferSize) f := &framer{ - reader: bufio.NewReaderSize(conn, readBufferSize), - writer: bufio.NewWriterSize(conn, writeBufferSize), + writer: w, + fr: http2.NewFramer(w, r), } - f.fr = http2.NewFramer(f.writer, f.reader) // Opt-in to Frame reuse API on framer to reduce garbage. // Frames aren't safe to read from after a subsequent call to ReadFrame. f.fr.SetReuseFrames() diff --git a/transport/transport.go b/transport/transport.go index e0c1e343e7a1..fc7658dba61f 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -19,16 +19,17 @@ // Package transport defines and implements message oriented communication // channel to complete various transactions (e.g., an RPC). It is meant for // grpc-internal usage and is not intended to be imported directly by users. -package transport // import "google.golang.org/grpc/transport" +package transport // externally used as import "google.golang.org/grpc/transport" import ( + "errors" "fmt" "io" "net" "sync" + "sync/atomic" "golang.org/x/net/context" - "golang.org/x/net/http2" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" @@ -57,6 +58,7 @@ type recvBuffer struct { c chan recvMsg mu sync.Mutex backlog []recvMsg + err error } func newRecvBuffer() *recvBuffer { @@ -68,6 +70,13 @@ func newRecvBuffer() *recvBuffer { func (b *recvBuffer) put(r recvMsg) { b.mu.Lock() + if b.err != nil { + b.mu.Unlock() + // An error had occurred earlier, don't accept more + // data or errors. 
+ return + } + b.err = r.err if len(b.backlog) == 0 { select { case b.c <- r: @@ -101,14 +110,15 @@ func (b *recvBuffer) get() <-chan recvMsg { return b.c } +// // recvBufferReader implements io.Reader interface to read the data from // recvBuffer. type recvBufferReader struct { - ctx context.Context - goAway chan struct{} - recv *recvBuffer - last []byte // Stores the remaining data in the previous calls. - err error + ctx context.Context + ctxDone <-chan struct{} // cache of ctx.Done() (for performance). + recv *recvBuffer + last []byte // Stores the remaining data in the previous calls. + err error } // Read reads the next len(p) bytes from last. If last is drained, it tries to @@ -130,10 +140,8 @@ func (r *recvBufferReader) read(p []byte) (n int, err error) { return copied, nil } select { - case <-r.ctx.Done(): + case <-r.ctxDone: return 0, ContextErr(r.ctx.Err()) - case <-r.goAway: - return 0, errStreamDrain case m := <-r.recv.get(): r.recv.load() if m.err != nil { @@ -145,61 +153,7 @@ func (r *recvBufferReader) read(p []byte) (n int, err error) { } } -// All items in an out of a controlBuffer should be the same type. -type item interface { - item() -} - -// controlBuffer is an unbounded channel of item. -type controlBuffer struct { - c chan item - mu sync.Mutex - backlog []item -} - -func newControlBuffer() *controlBuffer { - b := &controlBuffer{ - c: make(chan item, 1), - } - return b -} - -func (b *controlBuffer) put(r item) { - b.mu.Lock() - if len(b.backlog) == 0 { - select { - case b.c <- r: - b.mu.Unlock() - return - default: - } - } - b.backlog = append(b.backlog, r) - b.mu.Unlock() -} - -func (b *controlBuffer) load() { - b.mu.Lock() - if len(b.backlog) > 0 { - select { - case b.c <- b.backlog[0]: - b.backlog[0] = nil - b.backlog = b.backlog[1:] - default: - } - } - b.mu.Unlock() -} - -// get returns the channel that receives an item in the buffer. -// -// Upon receipt of an item, the caller should call load to send another -// item onto the channel if there is any. -func (b *controlBuffer) get() <-chan item { - return b.c -} - -type streamState uint8 +type streamState uint32 const ( streamActive streamState = iota @@ -214,8 +168,8 @@ type Stream struct { st ServerTransport // nil for client side Stream ctx context.Context // the associated context of the stream cancel context.CancelFunc // always nil for client side Stream - done chan struct{} // closed when the final status arrives - goAway chan struct{} // closed when a GOAWAY control message is received + done chan struct{} // closed at the end of stream to unblock writers. On the client side. + ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance) method string // the associated RPC method of the stream recvCompress string sendCompress string @@ -223,47 +177,51 @@ type Stream struct { trReader io.Reader fc *inFlow recvQuota uint32 - waiters waiters + wq *writeQuota // Callback to state application's intentions to read data. This // is used to adjust flow control, if needed. requestRead func(int) - sendQuotaPool *quotaPool - headerChan chan struct{} // closed to indicate the end of header metadata. - headerDone bool // set when headerChan is closed. Used to avoid closing headerChan multiple times. - header metadata.MD // the received header metadata. - trailer metadata.MD // the key-value map of trailer metadata. + headerChan chan struct{} // closed to indicate the end of header metadata. + headerDone uint32 // set when headerChan is closed. 
Used to avoid closing headerChan multiple times. + header metadata.MD // the received header metadata. + trailer metadata.MD // the key-value map of trailer metadata. - mu sync.RWMutex // guard the following - headerOk bool // becomes true from the first header is about to send + headerOk bool // becomes true when the first header is about to be sent state streamState status *status.Status // the status error received from the server - rstStream bool // indicates whether a RST_STREAM frame needs to be sent - rstError http2.ErrCode // the error that needs to be sent along with the RST_STREAM frame - - bytesReceived bool // indicates whether any bytes have been received on this stream - unprocessed bool // set if the server sends a refused stream or GOAWAY including this stream + bytesReceived uint32 // indicates whether any bytes have been received on this stream + unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream // contentSubtype is the content-subtype for requests. // this must be lowercase or the behavior is undefined. contentSubtype string } +func (s *Stream) swapState(st streamState) streamState { + return streamState(atomic.SwapUint32((*uint32)(&s.state), uint32(st))) +} + +func (s *Stream) compareAndSwapState(oldState, newState streamState) bool { + return atomic.CompareAndSwapUint32((*uint32)(&s.state), uint32(oldState), uint32(newState)) +} + +func (s *Stream) getState() streamState { + return streamState(atomic.LoadUint32((*uint32)(&s.state))) +} + func (s *Stream) waitOnHeader() error { if s.headerChan == nil { // On the server headerChan is always nil since a stream originates // only after having received headers. return nil } - wc := s.waiters select { - case <-wc.ctx.Done(): - return ContextErr(wc.ctx.Err()) - case <-wc.goAway: - return errStreamDrain + case <-s.ctx.Done(): + return ContextErr(s.ctx.Err()) case <-s.headerChan: return nil } @@ -289,12 +247,6 @@ func (s *Stream) Done() <-chan struct{} { return s.done } -// GoAway returns a channel which is closed when the server sent GoAways signal -// before this stream was initiated. -func (s *Stream) GoAway() <-chan struct{} { - return s.goAway -} - // Header acquires the key-value pairs of header metadata once it // is available. It blocks until i) the metadata is ready or ii) there is no // header metadata or iii) the stream is canceled/expired. @@ -303,6 +255,9 @@ func (s *Stream) Header() (metadata.MD, error) { // Even if the stream is closed, header is returned if available. select { case <-s.headerChan: + if s.header == nil { + return nil, nil + } return s.header.Copy(), nil default: } @@ -312,10 +267,10 @@ // Trailer returns the cached trailer metadata. Note that if it is not called // after the entire stream is done, it could return an empty MD. Client // side only. +// It can be safely read only after the stream has ended, that is, after +// either read or write has returned io.EOF. func (s *Stream) Trailer() metadata.MD { - s.mu.RLock() c := s.trailer.Copy() - s.mu.RUnlock() return c } @@ -345,24 +300,23 @@ func (s *Stream) Method() string { } // Status returns the status received from the server. +// Status can be read safely only after the stream has ended, +// that is, read or write has returned io.EOF. func (s *Stream) Status() *status.Status { return s.status } // SetHeader sets the header metadata. This can be called multiple times. // Server side only. +// This should not be called in parallel to other data writes.
func (s *Stream) SetHeader(md metadata.MD) error { - s.mu.Lock() - if s.headerOk || s.state == streamDone { - s.mu.Unlock() - return ErrIllegalHeaderWrite - } if md.Len() == 0 { - s.mu.Unlock() return nil } + if s.headerOk || atomic.LoadUint32((*uint32)(&s.state)) == uint32(streamDone) { + return ErrIllegalHeaderWrite + } s.header = metadata.Join(s.header, md) - s.mu.Unlock() return nil } @@ -376,13 +330,12 @@ func (s *Stream) SendHeader(md metadata.MD) error { // SetTrailer sets the trailer metadata which will be sent with the RPC status // by the server. This can be called multiple times. Server side only. +// This should not be called in parallel to other data writes. func (s *Stream) SetTrailer(md metadata.MD) error { if md.Len() == 0 { return nil } - s.mu.Lock() s.trailer = metadata.Join(s.trailer, md) - s.mu.Unlock() return nil } @@ -422,29 +375,15 @@ func (t *transportReader) Read(p []byte) (n int, err error) { return } -// finish sets the stream's state and status, and closes the done channel. -// s.mu must be held by the caller. st must always be non-nil. -func (s *Stream) finish(st *status.Status) { - s.status = st - s.state = streamDone - close(s.done) -} - // BytesReceived indicates whether any bytes have been received on this stream. func (s *Stream) BytesReceived() bool { - s.mu.Lock() - br := s.bytesReceived - s.mu.Unlock() - return br + return atomic.LoadUint32(&s.bytesReceived) == 1 } // Unprocessed indicates whether the server did not process this stream -- // i.e. it sent a refused stream or GOAWAY including this stream ID. func (s *Stream) Unprocessed() bool { - s.mu.Lock() - br := s.unprocessed - s.mu.Unlock() - return br + return atomic.LoadUint32(&s.unprocessed) == 1 } // GoString is implemented by Stream so context.String() won't @@ -694,6 +633,9 @@ var ( // connection is draining. This could be caused by goaway or balancer // removing the address. errStreamDrain = streamErrorf(codes.Unavailable, "the connection is draining") + // errStreamDone is returned from write at the client side to indicate + // to the application layer that the stream is done. + errStreamDone = errors.New("the stream is done") // StatusGoAway indicates that the server sent a GOAWAY that included this // stream's ID in unprocessed RPCs. statusGoAway = status.New(codes.Unavailable, "the stream is rejected because server is draining the connection") @@ -711,15 +653,6 @@ func (e StreamError) Error() string { return fmt.Sprintf("stream error: code = %s desc = %q", e.Code, e.Desc) } -// waiters are passed to quotaPool get methods to -// wait on in addition to waiting on quota. -type waiters struct { - ctx context.Context - tctx context.Context - done chan struct{} - goAway chan struct{} -} - // GoAwayReason contains the reason for the GoAway frame received. type GoAwayReason uint8 @@ -733,39 +666,3 @@ const ( // "too_many_pings". GoAwayTooManyPings GoAwayReason = 2 ) - -// loopyWriter is run in a separate go routine. It is the single code path that will -// write data on wire. -func loopyWriter(ctx context.Context, cbuf *controlBuffer, handler func(item) error) { - for { - select { - case i := <-cbuf.get(): - cbuf.load() - if err := handler(i); err != nil { - errorf("transport: Error while handling item. Err: %v", err) - return - } - case <-ctx.Done(): - return - } - hasData: - for { - select { - case i := <-cbuf.get(): - cbuf.load() - if err := handler(i); err != nil { - errorf("transport: Error while handling item.
Err: %v", err) - return - } - case <-ctx.Done(): - return - default: - if err := handler(&flushIO{}); err != nil { - errorf("transport: Error while flushing. Err: %v", err) - return - } - break hasData - } - } - } -} diff --git a/transport/transport_test.go b/transport/transport_test.go index 42261df9939b..1eb1ef9e680b 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -29,6 +29,7 @@ import ( "net" "net/http" "reflect" + "runtime" "strconv" "strings" "sync" @@ -156,7 +157,12 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) { p = make([]byte, n+1) } } - conn.controlBuf.put(&dataFrame{s.id, false, p, func() {}}) + conn.controlBuf.put(&dataFrame{ + streamID: s.id, + h: nil, + d: p, + onEachWrite: func() {}, + }) sent += len(p) } } @@ -190,17 +196,24 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) { time.Sleep(2 * time.Second) _, err := s.Read(p) if err != nil { - t.Fatalf("s.Read(_) = _, %v, want _, ", err) + t.Errorf("s.Read(_) = _, %v, want _, ", err) return } if !bytes.Equal(p, req) { - t.Fatalf("handleStream got %v, want %v", p, req) + t.Errorf("handleStream got %v, want %v", p, req) + return } // send a response back to the client. - h.t.Write(s, nil, resp, &Options{}) + if err := h.t.Write(s, nil, resp, &Options{}); err != nil { + t.Errorf("server Write got %v, want ", err) + return + } // send the trailer to end the stream. - h.t.WriteStatus(s, status.New(codes.OK, "")) + if err := h.t.WriteStatus(s, status.New(codes.OK, "")); err != nil { + t.Errorf("server WriteStatus got %v, want ", err) + return + } } func (h *testStreamHandler) handleStreamDelayWrite(t *testing.T, s *Stream) { @@ -213,19 +226,26 @@ func (h *testStreamHandler) handleStreamDelayWrite(t *testing.T, s *Stream) { p := make([]byte, len(req)) _, err := s.Read(p) if err != nil { - t.Fatalf("s.Read(_) = _, %v, want _, ", err) + t.Errorf("s.Read(_) = _, %v, want _, ", err) return } if !bytes.Equal(p, req) { - t.Fatalf("handleStream got %v, want %v", p, req) + t.Errorf("handleStream got %v, want %v", p, req) + return } // Wait before sending. Give time to client to start reading // before server starts sending. time.Sleep(2 * time.Second) - h.t.Write(s, nil, resp, &Options{}) + if err := h.t.Write(s, nil, resp, &Options{}); err != nil { + t.Errorf("server Write got %v, want ", err) + return + } // send the trailer to end the stream. - h.t.WriteStatus(s, status.New(codes.OK, "")) + if err := h.t.WriteStatus(s, status.New(codes.OK, "")); err != nil { + t.Errorf("server WriteStatus got %v, want ", err) + return + } } // start starts server. Other goroutines should block on s.readyChan for further operations. 
@@ -345,14 +365,19 @@ func (s *server) stop() { s.mu.Unlock() } -func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, ClientTransport) { - return setUpWithOptions(t, port, &ServerConfig{MaxStreams: maxStreams}, ht, ConnectOptions{}) -} - -func setUpWithOptions(t *testing.T, port int, serverConfig *ServerConfig, ht hType, copts ConnectOptions) (*server, ClientTransport) { +func setUpServerOnly(t *testing.T, port int, serverConfig *ServerConfig, ht hType) *server { server := &server{startedErr: make(chan error, 1)} go server.start(t, port, serverConfig, ht) server.wait(t, 2*time.Second) + return server +} + +func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, ClientTransport) { + return setUpWithOptions(t, port, &ServerConfig{MaxStreams: maxStreams}, ht, ConnectOptions{}, func() {}) +} + +func setUpWithOptions(t *testing.T, port int, serverConfig *ServerConfig, ht hType, copts ConnectOptions, onHandshake func()) (*server, ClientTransport) { + server := setUpServerOnly(t, port, serverConfig, ht) addr := "localhost:" + server.port var ( ct ClientTransport @@ -362,7 +387,7 @@ func setUpWithOptions(t *testing.T, port int, serverConfig *ServerConfig, ht hTy Addr: addr, } connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) - ct, connErr = NewClientTransport(connectCtx, context.Background(), target, copts, func() {}) + ct, connErr = NewClientTransport(connectCtx, context.Background(), target, copts, onHandshake) if connErr != nil { cancel() // Do not cancel in success path. t.Fatalf("failed to create transport: %v", connErr) @@ -404,7 +429,7 @@ func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, done chan net.Con // sends StreamError to concurrent stream reader. 
func TestInflightStreamClosing(t *testing.T) { serverConfig := &ServerConfig{} - server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) + server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}, func() {}) defer server.stop() defer client.Close() @@ -446,17 +471,14 @@ func TestMaxConnectionIdle(t *testing.T) { MaxConnectionIdle: 2 * time.Second, }, } - server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) + server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}, func() {}) defer server.stop() defer client.Close() stream, err := client.NewStream(context.Background(), &CallHdr{Flush: true}) if err != nil { t.Fatalf("Client failed to create RPC request: %v", err) } - stream.mu.Lock() - stream.rstStream = true - stream.mu.Unlock() - client.CloseStream(stream, nil) + client.(*http2Client).closeStream(stream, io.EOF, true, http2.ErrCodeCancel, nil, nil) // wait for the server to see the closed stream and for the max-connection-idle logic to send a goaway after no new RPCs are made timeout := time.NewTimer(time.Second * 4) select { @@ -476,7 +498,7 @@ func TestMaxConnectionIdleNegative(t *testing.T) { MaxConnectionIdle: 2 * time.Second, }, } - server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) + server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}, func() {}) defer server.stop() defer client.Close() _, err := client.NewStream(context.Background(), &CallHdr{Flush: true}) @@ -502,7 +524,7 @@ func TestMaxConnectionAge(t *testing.T) { MaxConnectionAge: 2 * time.Second, }, } - server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) + server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}, func() {}) defer server.stop() defer client.Close() _, err := client.NewStream(context.Background(), &CallHdr{}) @@ -529,7 +551,7 @@ func TestKeepaliveServer(t *testing.T) { Timeout: 1 * time.Second, }, } - server, c := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) + server, c := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}, func() {}) defer server.stop() defer c.Close() client, err := net.Dial("tcp", server.lis.Addr().String()) @@ -572,7 +594,7 @@ func TestKeepaliveServerNegative(t *testing.T) { Timeout: 1 * time.Second, }, } - server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) + server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}, func() {}) defer server.stop() defer client.Close() // Give keepalive logic some time by sleeping. @@ -666,7 +688,7 @@ func TestKeepaliveClientStaysHealthyWithResponsiveServer(t *testing.T) { Time: 2 * time.Second, // Keepalive time = 2 sec. Timeout: 1 * time.Second, // Keepalive timeout = 1 sec. PermitWithoutStream: true, // Run keepalive even with no RPCs. - }}) + }}, func() {}) defer s.stop() defer tr.Close() // Give keep alive some time.
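TestMaxConnectionIdle above now tears its stream down by calling closeStream directly rather than flipping rstStream under the old stream mutex. That is safe to race with the transport's own teardown because Stream.state is now a single atomic word (see the swapState helpers earlier in this diff): whichever goroutine swaps in streamDone first runs the cleanup, and every later caller sees streamDone and backs off. A pared-down, compilable sketch of that idempotent close; the types here are illustrative, not this patch's Stream.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type streamState uint32

const (
	streamActive streamState = iota
	streamWriteDone
	streamReadDone
	streamDone
)

type stream struct{ state streamState }

// swapState atomically installs st and reports what was there before,
// mirroring the helper this diff adds to Stream in transport.go.
func (s *stream) swapState(st streamState) streamState {
	return streamState(atomic.SwapUint32((*uint32)(&s.state), uint32(st)))
}

func main() {
	s := &stream{state: streamActive}
	var wg sync.WaitGroup
	var cleanups int32
	for i := 0; i < 4; i++ { // four goroutines race to close the stream
		wg.Add(1)
		go func() {
			defer wg.Done()
			if s.swapState(streamDone) == streamDone {
				return // someone else won the race; nothing left to do
			}
			atomic.AddInt32(&cleanups, 1) // exactly one closer gets here
		}()
	}
	wg.Wait()
	fmt.Println("cleanup ran", cleanups, "time(s)") // cleanup ran 1 time(s)
}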
@@ -693,7 +715,7 @@ func TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) { PermitWithoutStream: true, }, } - server, client := setUpWithOptions(t, 0, serverConfig, normal, clientOptions) + server, client := setUpWithOptions(t, 0, serverConfig, normal, clientOptions, func() {}) defer server.stop() defer client.Close() @@ -727,7 +749,7 @@ func TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) { Timeout: 1 * time.Second, }, } - server, client := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions) + server, client := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions, func() {}) defer server.stop() defer client.Close() @@ -766,7 +788,7 @@ func TestKeepaliveServerEnforcementWithObeyingClientNoRPC(t *testing.T) { PermitWithoutStream: true, }, } - server, client := setUpWithOptions(t, 0, serverConfig, normal, clientOptions) + server, client := setUpWithOptions(t, 0, serverConfig, normal, clientOptions, func() {}) defer server.stop() defer client.Close() @@ -793,7 +815,7 @@ func TestKeepaliveServerEnforcementWithObeyingClientWithRPC(t *testing.T) { Timeout: 1 * time.Second, }, } - server, client := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions) + server, client := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions, func() {}) defer server.stop() defer client.Close() @@ -945,12 +967,16 @@ func TestLargeMessageWithDelayRead(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - s, err := ct.NewStream(context.Background(), callHdr) + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10)) + defer cancel() + s, err := ct.NewStream(ctx, callHdr) if err != nil { t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) + return } - if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { + if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil { t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err) + return } p := make([]byte, len(expectedResponseLarge)) @@ -958,6 +984,7 @@ func TestLargeMessageWithDelayRead(t *testing.T) { time.Sleep(2 * time.Second) if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { t.Errorf("s.Read(_) = _, %v, want _, <nil>", err) + return } if _, err = s.Read(p); err != io.EOF { t.Errorf("Failed to complete the stream %v; want <EOF>", err) @@ -980,19 +1007,24 @@ func TestLargeMessageDelayWrite(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - s, err := ct.NewStream(context.Background(), callHdr) + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10)) + defer cancel() + s, err := ct.NewStream(ctx, callHdr) if err != nil { t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) + return } // Give the server time to start reading before the client starts sending.
time.Sleep(2 * time.Second) - if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { + if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil { t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err) + return } p := make([]byte, len(expectedResponseLarge)) if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { t.Errorf("io.ReadFull(%v) = _, %v, want %v, <nil>", err, p, expectedResponse) + return } if _, err = s.Read(p); err != io.EOF { t.Errorf("Failed to complete the stream %v; want <EOF>", err) @@ -1005,17 +1037,33 @@ func TestLargeMessageDelayWrite(t *testing.T) { } func TestGracefulClose(t *testing.T) { - server, ct := setUp(t, 0, math.MaxUint32, normal) - callHdr := &CallHdr{ - Host: "localhost", - Method: "foo.Small", - } - s, err := ct.NewStream(context.Background(), callHdr) + server, ct := setUp(t, 0, math.MaxUint32, pingpong) + defer server.stop() + defer ct.Close() + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10)) + defer cancel() + s, err := ct.NewStream(ctx, &CallHdr{}) if err != nil { - t.Fatalf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) + t.Fatalf("NewStream(_, _) = _, %v, want _, <nil>", err) + } + msg := make([]byte, 1024) + outgoingHeader := make([]byte, 5) + outgoingHeader[0] = byte(0) + binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(len(msg))) + incomingHeader := make([]byte, 5) + if err := ct.Write(s, outgoingHeader, msg, &Options{}); err != nil { + t.Fatalf("Error while writing: %v", err) + } + if _, err := s.Read(incomingHeader); err != nil { + t.Fatalf("Error while reading: %v", err) + } + sz := binary.BigEndian.Uint32(incomingHeader[1:]) + recvMsg := make([]byte, int(sz)) + if _, err := s.Read(recvMsg); err != nil { + t.Fatalf("Error while reading: %v", err) } if err = ct.GracefulClose(); err != nil { - t.Fatalf("%v.GracefulClose() = %v, want <nil>", ct, err) + t.Fatalf("GracefulClose() = %v, want <nil>", err) } var wg sync.WaitGroup // Expect the failure for all the follow-up streams because ct has been closed gracefully. @@ -1023,29 +1071,22 @@ func TestGracefulClose(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - if _, err := ct.NewStream(context.Background(), callHdr); err != errStreamDrain { - t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", ct, err, errStreamDrain) + str, err := ct.NewStream(context.Background(), &CallHdr{}) + if err == errStreamDrain { + return + } + ct.Write(str, nil, nil, &Options{Last: true}) + if _, err := str.Read(make([]byte, 8)); err != errStreamDrain { + t.Errorf("_.Read(_) = _, %v, want _, %v", err, errStreamDrain) } }() } - opts := Options{ - Last: true, - Delay: false, + ct.Write(s, nil, nil, &Options{Last: true}) + if _, err := s.Read(incomingHeader); err != io.EOF { + t.Fatalf("Client expected EOF from the server. Got: %v", err) } // The stream which was created before graceful close can still proceed.
- if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != io.EOF { - t.Fatalf("%v.Write(_, _, _) = %v, want <nil>", ct, err) - } - p := make([]byte, len(expectedResponse)) - if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponse) { - t.Fatalf("s.Read(%v) = _, %v, want %v, <nil>", err, p, expectedResponse) - } - if _, err = s.Read(p); err != io.EOF { - t.Fatalf("Failed to complete the stream %v; want <EOF>", err) - } wg.Wait() - ct.Close() - server.stop() } func TestLargeMessageSuspension(t *testing.T) { @@ -1061,81 +1102,96 @@ if err != nil { t.Fatalf("failed to open stream: %v", err) } + // Launch a goroutine similar to the stream monitoring goroutine in + // stream.go to keep track of context timeout and call CloseStream. + go func() { + <-ctx.Done() + ct.CloseStream(s, ContextErr(ctx.Err())) + }() // Write should not be done successfully due to flow control. msg := make([]byte, initialWindowSize*8) - err = ct.Write(s, nil, msg, &Options{Last: true, Delay: false}) + ct.Write(s, nil, msg, &Options{}) + err = ct.Write(s, nil, msg, &Options{Last: true}) + if err != errStreamDone { + t.Fatalf("Write got %v, want %v", err, errStreamDone) + } expectedErr := streamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded) - if err != expectedErr { - t.Fatalf("Write got %v, want %v", err, expectedErr) + if _, err := s.Read(make([]byte, 8)); err != expectedErr { + t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr) } ct.Close() server.stop() } func TestMaxStreams(t *testing.T) { - server, ct := setUp(t, 0, 1, suspended) + serverConfig := &ServerConfig{ + MaxStreams: 1, + } + server, ct := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}, func() {}) + defer ct.Close() + defer server.stop() callHdr := &CallHdr{ Host: "localhost", Method: "foo.Large", } - // Have a pending stream which takes all streams quota. s, err := ct.NewStream(context.Background(), callHdr) if err != nil { t.Fatalf("Failed to open stream: %v", err) } - cc, ok := ct.(*http2Client) - if !ok { - t.Fatalf("Failed to convert %v to *http2Client", ct) - } - done := make(chan struct{}) - ch := make(chan int) - ready := make(chan struct{}) - go func() { - for { - select { - case <-time.After(5 * time.Millisecond): - select { - case ch <- 0: - case <-ready: - return - } - case <-time.After(5 * time.Second): - close(done) - return - case <-ready: - return - } - } - }() - // Test these conditions until they pass or - // we reach the deadline (failure case). + // Keep creating streams until one fails with deadline exceeded, which indicates that + // the server's MAX_CONCURRENT_STREAMS setting has been applied on the client. + slist := []*Stream{} + pctx, cancel := context.WithCancel(context.Background()) + defer cancel() + timer := time.NewTimer(time.Second * 10) + expectedErr := streamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded) for { select { - case <-ch: - case <-done: - t.Fatalf("streamsQuota.quota shouldn't be non-zero.") + case <-timer.C: + t.Fatalf("Test timeout: client didn't receive server settings.") default: } - cc.streamsQuota.mu.Lock() - sq := cc.streamsQuota.quota - cc.streamsQuota.mu.Unlock() - if sq == 0 { - break + ctx, cancel := context.WithDeadline(pctx, time.Now().Add(time.Second)) + // This is only to get rid of govet. All these contexts are based on a base + // context which is canceled at the end of the test.
+ defer cancel() + if str, err := ct.NewStream(ctx, callHdr); err == nil { + slist = append(slist, str) + continue + } else if err != expectedErr { + t.Fatalf("ct.NewStream(_,_) = _, %v, want _, %v", err, expectedErr) } + timer.Stop() + break } - close(ready) - // Close the pending stream so that the streams quota becomes available for the next new stream. - ct.CloseStream(s, nil) - cc.streamsQuota.mu.Lock() - i := cc.streamsQuota.quota - cc.streamsQuota.mu.Unlock() - if i != 1 { - t.Fatalf("streamsQuota is %d, want 1.", i) + done := make(chan struct{}) + // Try to create a new stream. + go func() { + defer close(done) + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10)) + defer cancel() + if _, err := ct.NewStream(ctx, callHdr); err != nil { + t.Errorf("Failed to open stream: %v", err) + } + }() + // Close all the extra streams created and make sure the new stream is not created. + for _, str := range slist { + ct.CloseStream(str, nil) } - if _, err := ct.NewStream(context.Background(), callHdr); err != nil { - t.Fatalf("Failed to open stream: %v", err) + select { + case <-done: + t.Fatalf("Test failed: didn't expect the new stream to be created just yet.") + default: } + // Close the first stream created so that the new stream can finally be created. + ct.CloseStream(s, nil) + <-done ct.Close() - server.stop() + cc := ct.(*http2Client) + <-cc.writerDone + if cc.maxConcurrentStreams != 1 { + t.Fatalf("cc.maxConcurrentStreams: %d, want 1", cc.maxConcurrentStreams) + } } func TestServerContextCanceledOnClosedConnection(t *testing.T) { @@ -1171,7 +1227,13 @@ func TestServerContextCanceledOnClosedConnection(t *testing.T) { if err != nil { t.Fatalf("Failed to open stream: %v", err) } - cc.controlBuf.put(&dataFrame{s.id, false, make([]byte, http2MaxFrameLen), func() {}}) + cc.controlBuf.put(&dataFrame{ + streamID: s.id, + endStream: false, + h: nil, + d: make([]byte, http2MaxFrameLen), + onEachWrite: func() {}, + }) // Loop until the server side stream is created. var ss *Stream for { @@ -1202,7 +1264,7 @@ func TestClientConnDecoupledFromApplicationRead(t *testing.T) { InitialWindowSize: defaultWindowSize, InitialConnWindowSize: defaultWindowSize, } - server, client := setUpWithOptions(t, 0, &ServerConfig{}, notifyCall, connectOptions) + server, client := setUpWithOptions(t, 0, &ServerConfig{}, notifyCall, connectOptions, func() {}) defer server.stop() defer client.Close() @@ -1288,7 +1350,7 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) { InitialWindowSize: defaultWindowSize, InitialConnWindowSize: defaultWindowSize, } - server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) + server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}, func() {}) defer server.stop() defer client.Close() waitWhileTrue(t, func() (bool, error) { @@ -1340,18 +1402,6 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) { } } st.mu.Unlock() - // Trying to write more on a max-ed out stream should result in a RST_STREAM from the server. - ct := client.(*http2Client) - ct.controlBuf.put(&dataFrame{cstream2.id, true, make([]byte, 1), func() {}}) - code := http2ErrConvTab[http2.ErrCodeFlowControl] - waitWhileTrue(t, func() (bool, error) { - cstream2.mu.Lock() - defer cstream2.mu.Unlock() - if cstream2.status.Code() != code { - return true, fmt.Errorf("want code = %v, got %v", code, cstream2.status.Code()) - } - return false, nil - }) // Reading from the stream on the server should succeed.
if _, err := sstream1.Read(make([]byte, defaultWindowSize)); err != nil { t.Fatalf("_.Read(_) = %v, want <nil>", err) @@ -1364,136 +1414,190 @@ } func TestServerWithMisbehavedClient(t *testing.T) { - serverConfig := &ServerConfig{ - InitialWindowSize: defaultWindowSize, - InitialConnWindowSize: defaultWindowSize, + server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) + defer server.stop() + // Create a client that can override server stream quota. + mconn, err := net.Dial("tcp", server.lis.Addr().String()) + if err != nil { + t.Fatalf("Client failed to dial: %v", err) } - connectOptions := ConnectOptions{ - InitialWindowSize: defaultWindowSize, - InitialConnWindowSize: defaultWindowSize, + defer mconn.Close() + if err := mconn.SetWriteDeadline(time.Now().Add(time.Second * 10)); err != nil { + t.Fatalf("Failed to set write deadline: %v", err) } - server, ct := setUpWithOptions(t, 0, serverConfig, suspended, connectOptions) - callHdr := &CallHdr{ - Host: "localhost", - Method: "foo", + if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { + t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface)) } - var sc *http2Server - // Wait until the server transport is setup. - for { - server.mu.Lock() - if len(server.conns) == 0 { - server.mu.Unlock() - time.Sleep(time.Millisecond) - continue - } - for k := range server.conns { - var ok bool - sc, ok = k.(*http2Server) - if !ok { - t.Fatalf("Failed to convert %v to *http2Server", k) + // The success chan indicates that the reader received a RST_STREAM from the server. + success := make(chan struct{}) + var mu sync.Mutex + framer := http2.NewFramer(mconn, mconn) + if err := framer.WriteSettings(); err != nil { + t.Fatalf("Error while writing settings: %v", err) + } + go func() { // Launch a reader for this misbehaving client. + for { + frame, err := framer.ReadFrame() + if err != nil { + return } + switch frame := frame.(type) { + case *http2.PingFrame: + // Write the ping ack back so that the server's BDP estimation works right. + mu.Lock() + framer.WritePing(true, frame.Data) + mu.Unlock() + case *http2.RSTStreamFrame: + if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeFlowControl { + t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode)) + } + close(success) + return + default: + // Do nothing. + } + } - server.mu.Unlock() - break + }() + // Create a stream. + var buf bytes.Buffer + henc := hpack.NewEncoder(&buf) + // TODO(mmukhi): Remove unnecessary fields. + if err := henc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"}); err != nil { + t.Fatalf("Error while encoding header: %v", err) } - cc, ok := ct.(*http2Client) - if !ok { - t.Fatalf("Failed to convert %v to *http2Client", ct) + if err := henc.WriteField(hpack.HeaderField{Name: ":path", Value: "foo"}); err != nil { + t.Fatalf("Error while encoding header: %v", err) } - // Test server behavior for violation of stream flow control window size restriction.
- s, err := ct.NewStream(context.Background(), callHdr) - if err != nil { - t.Fatalf("Failed to open stream: %v", err) + if err := henc.WriteField(hpack.HeaderField{Name: ":authority", Value: "localhost"}); err != nil { + t.Fatalf("Error while encoding header: %v", err) } - var sent int - // Drain the stream flow control window - cc.controlBuf.put(&dataFrame{s.id, false, make([]byte, http2MaxFrameLen), func() {}}) - sent += http2MaxFrameLen - // Wait until the server creates the corresponding stream and receive some data. - var ss *Stream + if err := henc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}); err != nil { + t.Fatalf("Error while encoding header: %v", err) + } + mu.Lock() + if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { + mu.Unlock() + t.Fatalf("Error while writing headers: %v", err) + } + mu.Unlock() + + // Test server behavior for violation of stream flow control window size restriction. + timer := time.NewTimer(time.Second * 5) + dbuf := make([]byte, http2MaxFrameLen) for { - time.Sleep(time.Millisecond) - sc.mu.Lock() - if len(sc.activeStreams) == 0 { - sc.mu.Unlock() - continue + select { + case <-timer.C: + t.Fatalf("Test timed out.") + case <-success: + return + default: } - ss = sc.activeStreams[s.id] - sc.mu.Unlock() - ss.fc.mu.Lock() - if ss.fc.pendingData > 0 { - ss.fc.mu.Unlock() - break + mu.Lock() + if err := framer.WriteData(1, false, dbuf); err != nil { + mu.Unlock() + // An error here means the server could have closed the connection due to the flow control + // violation. Make sure that is the case by waiting for the success chan to be closed. + select { + case <-timer.C: + t.Fatalf("Error while writing data: %v", err) + case <-success: + return + } } - ss.fc.mu.Unlock() - } - if ss.fc.pendingData != http2MaxFrameLen || ss.fc.pendingUpdate != 0 || sc.fc.pendingData != 0 || sc.fc.pendingUpdate != 0 { - t.Fatalf("Server mistakenly updates inbound flow control params: got %d, %d, %d, %d; want %d, %d, %d, %d", ss.fc.pendingData, ss.fc.pendingUpdate, sc.fc.pendingData, sc.fc.pendingUpdate, http2MaxFrameLen, 0, 0, 0) - } - // Keep sending until the server inbound window is drained for that stream. - for sent <= initialWindowSize { - cc.controlBuf.put(&dataFrame{s.id, false, make([]byte, 1), func() {}}) - sent++ - } - // Server sent a resetStream for s already. - code := http2ErrConvTab[http2.ErrCodeFlowControl] - if _, err := s.Read(make([]byte, 1)); err != io.EOF { - t.Fatalf("%v got err %v want <EOF>", s, err) - } - if s.status.Code() != code { - t.Fatalf("%v got status %v; want Code=%v", s, s.status, code) + mu.Unlock() + // This for loop is capable of hogging the CPU and causing starvation + // in Go versions prior to 1.9, + // in a single-CPU environment. Explicitly relinquish the processor. + runtime.Gosched() } - - ct.CloseStream(s, nil) - ct.Close() - server.stop() } func TestClientWithMisbehavedServer(t *testing.T) { - // Turn off BDP estimation so that the server can - // violate stream window. - connectOptions := ConnectOptions{ - InitialWindowSize: initialWindowSize, - } - server, ct := setUpWithOptions(t, 0, &ServerConfig{}, misbehaved, connectOptions) - callHdr := &CallHdr{ - Host: "localhost", - Method: "foo.Stream", - } - conn, ok := ct.(*http2Client) - if !ok { - t.Fatalf("Failed to convert %v to *http2Client", ct) - } - // Test the logic for the violation of stream flow control window size restriction.
- s, err := ct.NewStream(context.Background(), callHdr) + // Create a misbehaving server. lis, err := net.Listen("tcp", "localhost:0") if err != nil { - t.Fatalf("Failed to open stream: %v", err) - } - d := make([]byte, 1) - if err := ct.Write(s, nil, d, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { - t.Fatalf("Failed to write: %v", err) - } - // Read without window update. - for { - p := make([]byte, http2MaxFrameLen) - if _, err = s.trReader.(*transportReader).reader.Read(p); err != nil { - break + t.Fatalf("Error while listening: %v", err) + } + defer lis.Close() + // The success chan indicates that the server received + // a RST_STREAM from the client. + success := make(chan struct{}) + go func() { // Launch the misbehaving server. + sconn, err := lis.Accept() + if err != nil { + t.Errorf("Error while accepting: %v", err) + return } + defer sconn.Close() + if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil { + t.Errorf("Error while reading client preface: %v", err) + return + } + sfr := http2.NewFramer(sconn, sconn) + if err := sfr.WriteSettingsAck(); err != nil { + t.Errorf("Error while writing settings: %v", err) + return + } + var mu sync.Mutex + for { + frame, err := sfr.ReadFrame() + if err != nil { + return + } + switch frame := frame.(type) { + case *http2.HeadersFrame: + // When the client creates a stream, violate the stream flow control. + go func() { + buf := make([]byte, http2MaxFrameLen) + for { + mu.Lock() + if err := sfr.WriteData(1, false, buf); err != nil { + mu.Unlock() + return + } + mu.Unlock() + // This for loop is capable of hogging the CPU and causing starvation + // in Go versions prior to 1.9, + // in a single-CPU environment. Explicitly relinquish the processor. + runtime.Gosched() + } + }() + case *http2.RSTStreamFrame: + if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeFlowControl { + t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode)) + } + close(success) + return + case *http2.PingFrame: + mu.Lock() + sfr.WritePing(true, frame.Data) + mu.Unlock() + default: + } + } + }() connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) + defer cancel() + ct, err := NewClientTransport(connectCtx, context.Background(), TargetInfo{Addr: lis.Addr().String()}, ConnectOptions{}, func() {}) + if err != nil { + t.Fatalf("Error while creating client transport: %v", err) } - if s.fc.pendingData <= initialWindowSize || s.fc.pendingUpdate != 0 || conn.fc.pendingData != 0 || conn.fc.pendingUpdate != 0 { - t.Fatalf("Client mistakenly updates inbound flow control params: got %d, %d, %d, %d; want >%d, %d, %d, >%d", s.fc.pendingData, s.fc.pendingUpdate, conn.fc.pendingData, conn.fc.pendingUpdate, initialWindowSize, 0, 0, 0) - } - - if err != io.EOF { - t.Fatalf("Got err %v, want <EOF>", err) + defer ct.Close() + str, err := ct.NewStream(context.Background(), &CallHdr{}) + if err != nil { + t.Fatalf("Error while creating stream: %v", err) } - if s.status.Code() != codes.Internal { - t.Fatalf("Got s.status %v, want s.status.Code()=Internal", s.status) + timer := time.NewTimer(time.Second * 5) + go func() { // This goroutine mimics the one in stream.go to call CloseStream.
+ <-str.Done() + ct.CloseStream(str, nil) + }() + select { + case <-timer.C: + t.Fatalf("Test timed-out.") + case <-success: } - - conn.CloseStream(s, err) - ct.Close() - server.stop() } var encodingTestStatus = status.New(codes.Internal, "\n") @@ -1512,7 +1616,7 @@ func TestEncodingRequiredStatus(t *testing.T) { Last: true, Delay: false, } - if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != io.EOF { + if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != errStreamDone { t.Fatalf("Failed to write the request: %v", err) } p := make([]byte, http2MaxFrameLen) @@ -1613,7 +1717,7 @@ func TestAccountCheckWindowSizeWithLargeWindow(t *testing.T) { clientStream: 6 * 1024 * 1024, clientConn: 8 * 1024 * 1024, } - testAccountCheckWindowSize(t, wc) + testFlowControlAccountCheck(t, 1024*1024, wc) } func TestAccountCheckWindowSizeWithSmallWindow(t *testing.T) { @@ -1624,135 +1728,27 @@ func TestAccountCheckWindowSizeWithSmallWindow(t *testing.T) { clientStream: defaultWindowSize, clientConn: defaultWindowSize, } - testAccountCheckWindowSize(t, wc) + testFlowControlAccountCheck(t, 1024*1024, wc) } -func testAccountCheckWindowSize(t *testing.T, wc windowSizeConfig) { - serverConfig := &ServerConfig{ +func TestAccountCheckDynamicWindowSmallMessage(t *testing.T) { + testFlowControlAccountCheck(t, 1024, windowSizeConfig{}) +} + +func TestAccountCheckDynamicWindowLargeMessage(t *testing.T) { + testFlowControlAccountCheck(t, 1024*1024, windowSizeConfig{}) +} + +func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) { + sc := &ServerConfig{ InitialWindowSize: wc.serverStream, InitialConnWindowSize: wc.serverConn, } - connectOptions := ConnectOptions{ + co := ConnectOptions{ InitialWindowSize: wc.clientStream, InitialConnWindowSize: wc.clientConn, } - server, client := setUpWithOptions(t, 0, serverConfig, suspended, connectOptions) - defer server.stop() - defer client.Close() - - // Wait for server conns to be populated with new server transport. - waitWhileTrue(t, func() (bool, error) { - server.mu.Lock() - defer server.mu.Unlock() - if len(server.conns) == 0 { - return true, fmt.Errorf("timed out waiting for server transport to be created") - } - return false, nil - }) - var st *http2Server - server.mu.Lock() - for k := range server.conns { - st = k.(*http2Server) - } - server.mu.Unlock() - ct := client.(*http2Client) - cstream, err := client.NewStream(context.Background(), &CallHdr{Flush: true}) - if err != nil { - t.Fatalf("Failed to create stream. Err: %v", err) - } - // Wait for server to receive headers. - waitWhileTrue(t, func() (bool, error) { - st.mu.Lock() - defer st.mu.Unlock() - if len(st.activeStreams) == 0 { - return true, fmt.Errorf("timed out waiting for server to receive headers") - } - return false, nil - }) - // Sleeping to make sure the settings are applied in case of negative test. - time.Sleep(time.Second) - - waitWhileTrue(t, func() (bool, error) { - st.fc.mu.Lock() - lim := st.fc.limit - st.fc.mu.Unlock() - if lim != uint32(serverConfig.InitialConnWindowSize) { - return true, fmt.Errorf("Server transport flow control window size: got %v, want %v", lim, serverConfig.InitialConnWindowSize) - } - return false, nil - }) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - serverSendQuota, _, err := st.sendQuotaPool.get(math.MaxInt32, waiters{ - ctx: ctx, - tctx: st.ctx, - done: nil, - goAway: nil, - }) - if err != nil { - t.Fatalf("Error while acquiring sendQuota on server. 
Err: %v", err) - } - cancel() - st.sendQuotaPool.add(serverSendQuota) - if serverSendQuota != int(connectOptions.InitialConnWindowSize) { - t.Fatalf("Server send quota(%v) not equal to client's window size(%v) on conn.", serverSendQuota, connectOptions.InitialConnWindowSize) - } - st.mu.Lock() - ssq := st.streamSendQuota - st.mu.Unlock() - if ssq != uint32(connectOptions.InitialWindowSize) { - t.Fatalf("Server stream send quota(%v) not equal to client's window size(%v) on stream.", ssq, connectOptions.InitialWindowSize) - } - ct.fc.mu.Lock() - limit := ct.fc.limit - ct.fc.mu.Unlock() - if limit != uint32(connectOptions.InitialConnWindowSize) { - t.Fatalf("Client transport flow control window size is %v, want %v", limit, connectOptions.InitialConnWindowSize) - } - ctx, cancel = context.WithTimeout(context.Background(), time.Second) - clientSendQuota, _, err := ct.sendQuotaPool.get(math.MaxInt32, waiters{ - ctx: ctx, - tctx: ct.ctx, - done: nil, - goAway: nil, - }) - if err != nil { - t.Fatalf("Error while acquiring sendQuota on client. Err: %v", err) - } - cancel() - ct.sendQuotaPool.add(clientSendQuota) - if clientSendQuota != int(serverConfig.InitialConnWindowSize) { - t.Fatalf("Client send quota(%v) not equal to server's window size(%v) on conn.", clientSendQuota, serverConfig.InitialConnWindowSize) - } - ct.mu.Lock() - ssq = ct.streamSendQuota - ct.mu.Unlock() - if ssq != uint32(serverConfig.InitialWindowSize) { - t.Fatalf("Client stream send quota(%v) not equal to server's window size(%v) on stream.", ssq, serverConfig.InitialWindowSize) - } - cstream.fc.mu.Lock() - limit = cstream.fc.limit - cstream.fc.mu.Unlock() - if limit != uint32(connectOptions.InitialWindowSize) { - t.Fatalf("Client stream flow control window size is %v, want %v", limit, connectOptions.InitialWindowSize) - } - var sstream *Stream - st.mu.Lock() - for _, v := range st.activeStreams { - sstream = v - } - st.mu.Unlock() - sstream.fc.mu.Lock() - limit = sstream.fc.limit - sstream.fc.mu.Unlock() - if limit != uint32(serverConfig.InitialWindowSize) { - t.Fatalf("Server stream flow control window size is %v, want %v", limit, serverConfig.InitialWindowSize) - } -} - -// Check accounting on both sides after sending and receiving large messages. -func TestAccountCheckExpandingWindow(t *testing.T) { - server, client := setUp(t, 0, 0, pingpong) + server, client := setUpWithOptions(t, 0, sc, pingpong, co, func() {}) defer server.stop() defer client.Close() waitWhileTrue(t, func() (bool, error) { @@ -1770,12 +1766,10 @@ func TestAccountCheckExpandingWindow(t *testing.T) { } server.mu.Unlock() ct := client.(*http2Client) - cstream, err := client.NewStream(context.Background(), &CallHdr{Flush: true}) + cstream, err := client.NewStream(context.Background(), &CallHdr{}) if err != nil { t.Fatalf("Failed to create stream. Err: %v", err) } - - msgSize := 65535 * 16 * 2 msg := make([]byte, msgSize) buf := make([]byte, msgSize+5) buf[0] = byte(0) @@ -1799,145 +1793,42 @@ func TestAccountCheckExpandingWindow(t *testing.T) { t.Fatalf("Length of message received by client: %v, want: %v", len(recvMsg), len(msg)) } } - defer func() { - ct.Write(cstream, nil, nil, &Options{Last: true}) // Close the stream. - if _, err := cstream.Read(header); err != io.EOF { - t.Fatalf("Client expected an EOF from the server. 
Got: %v", err) - } - }() var sstream *Stream st.mu.Lock() for _, v := range st.activeStreams { sstream = v } st.mu.Unlock() - - waitWhileTrue(t, func() (bool, error) { - // Check that pendingData and delta on flow control windows on both sides are 0. - cstream.fc.mu.Lock() - if cstream.fc.delta != 0 { - cstream.fc.mu.Unlock() - return true, fmt.Errorf("delta on flow control window of client stream is non-zero") - } - if cstream.fc.pendingData != 0 { - cstream.fc.mu.Unlock() - return true, fmt.Errorf("pendingData on flow control window of client stream is non-zero") - } - cstream.fc.mu.Unlock() - sstream.fc.mu.Lock() - if sstream.fc.delta != 0 { - sstream.fc.mu.Unlock() - return true, fmt.Errorf("delta on flow control window of server stream is non-zero") - } - if sstream.fc.pendingData != 0 { - sstream.fc.mu.Unlock() - return true, fmt.Errorf("pendingData on flow control window of sercer stream is non-zero") - } - sstream.fc.mu.Unlock() - ct.fc.mu.Lock() - if ct.fc.delta != 0 { - ct.fc.mu.Unlock() - return true, fmt.Errorf("delta on flow control window of client transport is non-zero") - } - if ct.fc.pendingData != 0 { - ct.fc.mu.Unlock() - return true, fmt.Errorf("pendingData on flow control window of client transport is non-zero") - } - ct.fc.mu.Unlock() - st.fc.mu.Lock() - if st.fc.delta != 0 { - st.fc.mu.Unlock() - return true, fmt.Errorf("delta on flow control window of server transport is non-zero") - } - if st.fc.pendingData != 0 { - st.fc.mu.Unlock() - return true, fmt.Errorf("pendingData on flow control window of server transport is non-zero") - } - st.fc.mu.Unlock() - - // Check flow conrtrol window on client stream is equal to out flow on server stream. - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - serverStreamSendQuota, _, err := sstream.sendQuotaPool.get(math.MaxInt32, waiters{ - ctx: ctx, - tctx: context.Background(), - done: nil, - goAway: nil, - }) - cancel() - if err != nil { - return true, fmt.Errorf("error while acquiring server stream send quota. Err: %v", err) - } - sstream.sendQuotaPool.add(serverStreamSendQuota) - cstream.fc.mu.Lock() - clientEst := cstream.fc.limit - cstream.fc.pendingUpdate - cstream.fc.mu.Unlock() - if uint32(serverStreamSendQuota) != clientEst { - return true, fmt.Errorf("server stream outflow: %v, estimated by client: %v", serverStreamSendQuota, clientEst) - } - - // Check flow control window on server stream is equal to out flow on client stream. - ctx, cancel = context.WithTimeout(context.Background(), time.Second) - clientStreamSendQuota, _, err := cstream.sendQuotaPool.get(math.MaxInt32, waiters{ - ctx: ctx, - tctx: context.Background(), - done: nil, - goAway: nil, - }) - cancel() - if err != nil { - return true, fmt.Errorf("error while acquiring client stream send quota. Err: %v", err) - } - cstream.sendQuotaPool.add(clientStreamSendQuota) - sstream.fc.mu.Lock() - serverEst := sstream.fc.limit - sstream.fc.pendingUpdate - sstream.fc.mu.Unlock() - if uint32(clientStreamSendQuota) != serverEst { - return true, fmt.Errorf("client stream outflow: %v. estimated by server: %v", clientStreamSendQuota, serverEst) - } - - // Check flow control window on client transport is equal to out flow of server transport. - ctx, cancel = context.WithTimeout(context.Background(), time.Second) - serverTrSendQuota, _, err := st.sendQuotaPool.get(math.MaxInt32, waiters{ - ctx: ctx, - tctx: st.ctx, - done: nil, - goAway: nil, - }) - cancel() - if err != nil { - return true, fmt.Errorf("error while acquring server transport send quota. 
Err: %v", err) - } - st.sendQuotaPool.add(serverTrSendQuota) - ct.fc.mu.Lock() - clientEst = ct.fc.limit - ct.fc.pendingUpdate - ct.fc.mu.Unlock() - if uint32(serverTrSendQuota) != clientEst { - return true, fmt.Errorf("server transport outflow: %v, estimated by client: %v", serverTrSendQuota, clientEst) - } - - // Check flow control window on server transport is equal to out flow of client transport. - ctx, cancel = context.WithTimeout(context.Background(), time.Second) - clientTrSendQuota, _, err := ct.sendQuotaPool.get(math.MaxInt32, waiters{ - ctx: ctx, - tctx: ct.ctx, - done: nil, - goAway: nil, - }) - cancel() - if err != nil { - return true, fmt.Errorf("error while acquiring client transport send quota. Err: %v", err) - } - ct.sendQuotaPool.add(clientTrSendQuota) - st.fc.mu.Lock() - serverEst = st.fc.limit - st.fc.pendingUpdate - st.fc.mu.Unlock() - if uint32(clientTrSendQuota) != serverEst { - return true, fmt.Errorf("client transport outflow: %v, estimated by client: %v", clientTrSendQuota, serverEst) - } - - return false, nil - }) - + loopyServerStream := st.loopy.estdStreams[sstream.id] + loopyClientStream := ct.loopy.estdStreams[cstream.id] + ct.Write(cstream, nil, nil, &Options{Last: true}) // Close the stream. + if _, err := cstream.Read(header); err != io.EOF { + t.Fatalf("Client expected an EOF from the server. Got: %v", err) + } + // Sleep for a little to make sure both sides flush out their buffers. + time.Sleep(time.Millisecond * 500) + // Close down both server and client so that their internals can be read without data + // races. + ct.Close() + st.Close() + <-st.readerDone + <-st.writerDone + <-ct.readerDone + <-ct.writerDone + // Check transport flow control. + if ct.fc.limit != ct.fc.unacked+st.loopy.sendQuota { + t.Fatalf("Account mismatch: client transport inflow(%d) != client unacked(%d) + server sendQuota(%d)", ct.fc.limit, ct.fc.unacked, st.loopy.sendQuota) + } + if st.fc.limit != st.fc.unacked+ct.loopy.sendQuota { + t.Fatalf("Account mismatch: server transport inflow(%d) != server unacked(%d) + client sendQuota(%d)", st.fc.limit, st.fc.unacked, ct.loopy.sendQuota) + } + // Check stream flow control. + if int(cstream.fc.limit+cstream.fc.delta-cstream.fc.pendingData-cstream.fc.pendingUpdate) != int(st.loopy.oiws)-loopyServerStream.bytesOutStanding { + t.Fatalf("Account mismatch: client stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != server outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", cstream.fc.limit, cstream.fc.delta, cstream.fc.pendingData, cstream.fc.pendingUpdate, st.loopy.oiws, loopyServerStream.bytesOutStanding) + } + if int(sstream.fc.limit+sstream.fc.delta-sstream.fc.pendingData-sstream.fc.pendingUpdate) != int(ct.loopy.oiws)-loopyClientStream.bytesOutStanding { + t.Fatalf("Account mismatch: server stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != client outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", sstream.fc.limit, sstream.fc.delta, sstream.fc.pendingData, sstream.fc.pendingUpdate, ct.loopy.oiws, loopyClientStream.bytesOutStanding) + } } func waitWhileTrue(t *testing.T, condition func() (bool, error)) { @@ -2022,7 +1913,6 @@ func (s *httpServer) start(t *testing.T, lis net.Listener) { t.Errorf("Error accepting connection: %v", err) return } - defer s.conn.Close() // Read preface sent by client. if _, err = io.ReadFull(s.conn, make([]byte, len(http2.ClientPreface))); err != nil { t.Errorf("Error at server side while reading preface from client.
Err: %v", err) @@ -2140,8 +2030,6 @@ func TestHTTPStatusOKAndMissingGRPCStatus(t *testing.T) { t.Fatalf("stream.Read(_) = _, %v, want _, io.EOF", err) } want := codes.Unknown - stream.mu.Lock() - defer stream.mu.Unlock() if stream.status.Code() != want { t.Fatalf("Status code of stream: %v, want: %v", stream.status.Code(), want) } @@ -2157,15 +2045,14 @@ func TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) { testRecvBuffer := newRecvBuffer() s := &Stream{ ctx: context.Background(), - goAway: make(chan struct{}), buf: testRecvBuffer, requestRead: func(int) {}, } s.trReader = &transportReader{ reader: &recvBufferReader{ - ctx: s.ctx, - goAway: s.goAway, - recv: s.buf, + ctx: s.ctx, + ctxDone: s.ctx.Done(), + recv: s.buf, }, windowHandler: func(int) {}, }
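Editor's note on the message framing used above: the rewritten TestGracefulClose and testFlowControlAccountCheck hand-roll the gRPC wire format when they build outgoingHeader and parse incomingHeader — a 5-byte prefix (one compression-flag byte, then the payload length as a big-endian uint32) followed by the payload. A minimal, self-contained sketch of that framing; frameMessage and parsePrefix are illustrative names, not helpers in the transport package:

package main

import (
	"encoding/binary"
	"fmt"
)

// frameMessage prepends the 5-byte gRPC message prefix the tests construct
// by hand: one compression flag byte (0 = uncompressed) followed by the
// payload length as a big-endian uint32.
func frameMessage(msg []byte) []byte {
	buf := make([]byte, 5+len(msg))
	buf[0] = 0 // not compressed
	binary.BigEndian.PutUint32(buf[1:5], uint32(len(msg)))
	copy(buf[5:], msg)
	return buf
}

// parsePrefix decodes a received 5-byte header back into its two fields.
func parsePrefix(header []byte) (compressed bool, length uint32) {
	return header[0] == 1, binary.BigEndian.Uint32(header[1:5])
}

func main() {
	framed := frameMessage(make([]byte, 1024))
	compressed, n := parsePrefix(framed[:5])
	fmt.Printf("compressed=%v length=%d total=%d\n", compressed, n, len(framed))
}

Because the prefix is a fixed five bytes, the tests can read exactly that many bytes first and then size the receive buffer from the decoded length, which is what the incomingHeader/recvMsg dance in TestGracefulClose does.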
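TestLargeMessageSuspension and TestClientWithMisbehavedServer both install a small watchdog goroutine that stands in for the stream-monitoring goroutine in stream.go: block until the stream or its context is done, then tear the stream down. The shape of that pattern, detached from the transport types (watchAndClose and the closeStream callback are illustrative stand-ins for ct.CloseStream):

package main

import (
	"context"
	"fmt"
	"time"
)

// watchAndClose mimics the monitoring goroutine the tests launch: once the
// context is done, invoke the supplied close function, much as the tests
// call ct.CloseStream(s, ContextErr(ctx.Err())).
func watchAndClose(ctx context.Context, closeStream func(err error)) {
	go func() {
		<-ctx.Done()
		closeStream(ctx.Err())
	}()
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	done := make(chan struct{})
	watchAndClose(ctx, func(err error) {
		fmt.Println("closing stream:", err) // context deadline exceeded
		close(done)
	})
	<-done
}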
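TestServerWithMisbehavedClient drives the server with a hand-built HTTP/2 client rather than a transport, so it can violate flow control deliberately. The handshake it performs is: send the client preface, write a SETTINGS frame, then open stream 1 with an hpack-encoded HEADERS frame. A condensed sketch of that sequence using the same golang.org/x/net/http2 APIs the test uses; the address and function name are placeholders, and error paths are trimmed to the essentials:

package main

import (
	"bytes"
	"io"
	"log"
	"net"

	"golang.org/x/net/http2"
	"golang.org/x/net/http2/hpack"
)

// rawHTTP2Stream dials addr and opens stream 1 with the same header fields
// the test sends, bypassing the gRPC transport entirely.
func rawHTTP2Stream(addr string) (net.Conn, *http2.Framer, error) {
	conn, err := net.Dial("tcp", addr)
	if err != nil {
		return nil, nil, err
	}
	// The connection preface must precede any frames.
	if _, err := io.WriteString(conn, http2.ClientPreface); err != nil {
		conn.Close()
		return nil, nil, err
	}
	framer := http2.NewFramer(conn, conn)
	if err := framer.WriteSettings(); err != nil {
		conn.Close()
		return nil, nil, err
	}
	// Encode the pseudo-headers and content-type with hpack, then open stream 1.
	var buf bytes.Buffer
	henc := hpack.NewEncoder(&buf)
	for _, hf := range []hpack.HeaderField{
		{Name: ":method", Value: "POST"},
		{Name: ":path", Value: "foo"},
		{Name: ":authority", Value: "localhost"},
		{Name: "content-type", Value: "application/grpc"},
	} {
		if err := henc.WriteField(hf); err != nil {
			conn.Close()
			return nil, nil, err
		}
	}
	err = framer.WriteHeaders(http2.HeadersFrameParam{
		StreamID:      1,
		BlockFragment: buf.Bytes(),
		EndHeaders:    true,
	})
	return conn, framer, err
}

func main() {
	conn, framer, err := rawHTTP2Stream("localhost:50051") // placeholder address
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
	_ = framer // a real caller would now flood DATA frames past the window
}

From here the test simply loops framer.WriteData(1, false, dbuf) until the server answers with RST_STREAM carrying http2.ErrCodeFlowControl, which is the behavior under test.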
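Finally, the rewritten testFlowControlAccountCheck asserts a conservation property once both loopy writers have drained: each side's inflow limit must equal the bytes it has received but not yet acked plus the peer's remaining send quota, and per stream the receiver's effective window must match the sender's view of what is left. Restated over plain integers (the function names are illustrative; the field names mirror the test's):

package main

import "fmt"

// transportAccountOK restates the connection-level identity the test asserts:
// inflow limit == bytes received but not yet acked via WINDOW_UPDATE
//                 + send quota the peer still believes it has.
func transportAccountOK(limit, unacked, peerSendQuota uint32) bool {
	return limit == unacked+peerSendQuota
}

// streamAccountOK restates the per-stream identity: the receiver's effective
// window (limit + delta - pendingData - pendingUpdate) must equal what the
// sender thinks remains (its outgoing initial window size - bytes outstanding).
func streamAccountOK(limit, delta, pendingData, pendingUpdate, oiws, bytesOutstanding int) bool {
	return limit+delta-pendingData-pendingUpdate == oiws-bytesOutstanding
}

func main() {
	// Example with the default 64KB-1 window and 1000 bytes in flight.
	fmt.Println(transportAccountOK(65535, 1000, 64535)) // true
	fmt.Println(streamAccountOK(65535, 0, 0, 0, 65535, 0)) // true
}

This is why the test closes both transports and waits on readerDone/writerDone before reading the counters: the identity only holds once no goroutine can still be moving bytes between the two sides of the ledger.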