From c24de9d546a52d43325a0e37cc7ab02d92c7fb6f Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 16 Dec 2015 17:29:56 +0000 Subject: [PATCH] http2: add Server support for reading trailers from clients Updates golang/go#13557 Change-Id: I95bbb15d9abbbbc4dc6c3a22cd965d8dcef53fb8 Reviewed-on: https://go-review.googlesource.com/17891 Reviewed-by: Blake Mizerany --- http2/headermap.go | 1 + http2/hpack/hpack.go | 7 ++ http2/pipe.go | 41 ++++++---- http2/server.go | 187 ++++++++++++++++++++++++++++++++++++------- http2/server_test.go | 65 ++++++++++++++- 5 files changed, 256 insertions(+), 45 deletions(-) diff --git a/http2/headermap.go b/http2/headermap.go index 014f789646..c2805f6ac4 100644 --- a/http2/headermap.go +++ b/http2/headermap.go @@ -57,6 +57,7 @@ func init() { "server", "set-cookie", "strict-transport-security", + "trailer", "transfer-encoding", "user-agent", "vary", diff --git a/http2/hpack/hpack.go b/http2/hpack/hpack.go index 8e9b2f2ebf..329a8d036d 100644 --- a/http2/hpack/hpack.go +++ b/http2/hpack/hpack.go @@ -102,6 +102,13 @@ func (d *Decoder) SetMaxStringLength(n int) { d.maxStrLen = n } +// SetEmitFunc changes the callback used when new header fields +// are decoded. +// It must be non-nil. It does not affect EmitEnabled. +func (d *Decoder) SetEmitFunc(emitFunc func(f HeaderField)) { + d.emit = emitFunc +} + // SetEmitEnabled controls whether the emitFunc provided to NewDecoder // should be called. The default is true. // diff --git a/http2/pipe.go b/http2/pipe.go index e30661cb23..208f3851d9 100644 --- a/http2/pipe.go +++ b/http2/pipe.go @@ -14,11 +14,12 @@ import ( // io.Pipe except there are no PipeReader/PipeWriter halves, and the // underlying buffer is an interface. (io.Pipe is always unbuffered) type pipe struct { - mu sync.Mutex - c sync.Cond // c.L must point to - b pipeBuffer - err error // read error once empty. non-nil means closed. - donec chan struct{} // closed on error + mu sync.Mutex + c sync.Cond // c.L must point to + b pipeBuffer + err error // read error once empty. non-nil means closed. + donec chan struct{} // closed on error + readFn func() // optional code to run in Read before error } type pipeBuffer interface { @@ -40,6 +41,10 @@ func (p *pipe) Read(d []byte) (n int, err error) { return p.b.Read(d) } if p.err != nil { + if p.readFn != nil { + p.readFn() // e.g. copy trailers + p.readFn = nil // not sticky like p.err + } return 0, p.err } p.c.Wait() @@ -63,13 +68,18 @@ func (p *pipe) Write(d []byte) (n int, err error) { return p.b.Write(d) } -// CloseWithError causes Reads to wake up and return the -// provided err after all data has been read. +// CloseWithError causes the next Read (waking up a current blocked +// Read if needed) to return the provided err after all data has been +// read. // // The error must be non-nil. -func (p *pipe) CloseWithError(err error) { +func (p *pipe) CloseWithError(err error) { p.closeWithErrorAndCode(err, nil) } + +// closeWithErrorAndCode is like CloseWithError but also sets some code to run +// in the caller's goroutine before returning the error. +func (p *pipe) closeWithErrorAndCode(err error, fn func()) { if err == nil { - panic("CloseWithError must be non-nil") + panic("CloseWithError err must be non-nil") } p.mu.Lock() defer p.mu.Unlock() @@ -77,11 +87,14 @@ func (p *pipe) CloseWithError(err error) { p.c.L = &p.mu } defer p.c.Signal() - if p.err == nil { - p.err = err - if p.donec != nil { - close(p.donec) - } + if p.err != nil { + // Already been done. + return + } + p.readFn = fn + p.err = err + if p.donec != nil { + close(p.donec) } } diff --git a/http2/server.go b/http2/server.go index 8d5f7cd471..238c186365 100644 --- a/http2/server.go +++ b/http2/server.go @@ -224,7 +224,7 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) { sc.flow.add(initialWindowSize) sc.inflow.add(initialWindowSize) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) - sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, sc.onNewHeaderField) + sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, nil) sc.hpackDecoder.SetMaxStringLength(sc.maxHeaderStringLen()) fr := NewFramer(sc.bw, c) @@ -411,20 +411,26 @@ type requestParam struct { // responseWriter's state field. type stream struct { // immutable: + sc *serverConn id uint32 body *pipe // non-nil if expecting DATA frames cw closeWaiter // closed wait stream transitions to closed state // owned by serverConn's serve loop: - bodyBytes int64 // body bytes seen so far - declBodyBytes int64 // or -1 if undeclared - flow flow // limits writing from Handler to client - inflow flow // what the client is allowed to POST/etc to us - parent *stream // or nil - weight uint8 - state streamState - sentReset bool // only true once detached from streams map - gotReset bool // only true once detacted from streams map + bodyBytes int64 // body bytes seen so far + declBodyBytes int64 // or -1 if undeclared + flow flow // limits writing from Handler to client + inflow flow // what the client is allowed to POST/etc to us + parent *stream // or nil + numTrailerValues int64 + weight uint8 + state streamState + sentReset bool // only true once detached from streams map + gotReset bool // only true once detacted from streams map + gotTrailerHeader bool // HEADER frame for trailers was seen + + trailer http.Header // accumulated trailers + reqTrailer http.Header // handler's Request.Trailer } func (sc *serverConn) Framer() *Framer { return sc.framer } @@ -537,6 +543,37 @@ func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) { } } +func (st *stream) onNewTrailerField(f hpack.HeaderField) { + sc := st.sc + sc.serveG.check() + sc.vlogf("got trailer field %+v", f) + switch { + case !validHeader(f.Name): + // TODO: change hpack signature so this can return + // errors? Or stash an error somewhere on st or sc + // for processHeaderBlockFragment etc to pick up and + // return after the hpack Write/Close. For now just + // ignore. + return + case strings.HasPrefix(f.Name, ":"): + // TODO: same TODO as above. + return + default: + key := sc.canonicalHeader(f.Name) + if st.trailer != nil { + vv := append(st.trailer[key], f.Value) + st.trailer[key] = vv + + // arbitrary; TODO: read spec about header list size limits wrt trailers + const tooBig = 1000 + if len(vv) >= tooBig { + sc.hpackDecoder.SetEmitEnabled(false) + } + + } + } +} + func (sc *serverConn) canonicalHeader(v string) string { sc.serveG.check() cv, ok := commonCanonHeader[v] @@ -1249,7 +1286,7 @@ func (sc *serverConn) processData(f *DataFrame) error { // with a stream error (Section 5.4.2) of type STREAM_CLOSED." id := f.Header().StreamID st, ok := sc.streams[id] - if !ok || st.state != stateOpen { + if !ok || st.state != stateOpen || st.gotTrailerHeader { // This includes sending a RST_STREAM if the stream is // in stateHalfClosedLocal (which currently means that // the http.Handler returned, so it's done reading & @@ -1283,17 +1320,38 @@ func (sc *serverConn) processData(f *DataFrame) error { st.bodyBytes += int64(len(data)) } if f.StreamEnded() { - if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes { - st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes", - st.declBodyBytes, st.bodyBytes)) - } else { - st.body.CloseWithError(io.EOF) - } - st.state = stateHalfClosedRemote + st.endStream() } return nil } +// endStream closes a Request.Body's pipe. It is called when a DATA +// frame says a request body is over (or after trailers). +func (st *stream) endStream() { + sc := st.sc + sc.serveG.check() + + if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes { + st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes", + st.declBodyBytes, st.bodyBytes)) + } else { + st.body.closeWithErrorAndCode(io.EOF, st.copyTrailersToHandlerRequest) + st.body.CloseWithError(io.EOF) + } + st.state = stateHalfClosedRemote +} + +// copyTrailersToHandlerRequest is run in the Handler's goroutine in +// its Request.Body.Read just before it gets io.EOF. +func (st *stream) copyTrailersToHandlerRequest() { + for k, vv := range st.trailer { + if _, ok := st.reqTrailer[k]; ok { + // Only copy it over it was pre-declared. + st.reqTrailer[k] = vv + } + } +} + func (sc *serverConn) processHeaders(f *HeadersFrame) error { sc.serveG.check() id := f.Header().StreamID @@ -1302,20 +1360,36 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error { return nil } // http://http2.github.io/http2-spec/#rfc.section.5.1.1 - if id%2 != 1 || id <= sc.maxStreamID || sc.req.stream != nil { - // Streams initiated by a client MUST use odd-numbered - // stream identifiers. [...] The identifier of a newly - // established stream MUST be numerically greater than all - // streams that the initiating endpoint has opened or - // reserved. [...] An endpoint that receives an unexpected - // stream identifier MUST respond with a connection error - // (Section 5.4.1) of type PROTOCOL_ERROR. + // Streams initiated by a client MUST use odd-numbered stream + // identifiers. [...] An endpoint that receives an unexpected + // stream identifier MUST respond with a connection error + // (Section 5.4.1) of type PROTOCOL_ERROR. + if id%2 != 1 { return ConnectionError(ErrCodeProtocol) } + // A HEADERS frame can be used to create a new stream or + // send a trailer for an open one. If we already have a stream + // open, let it process its own HEADERS frame (trailers at this + // point, if it's valid). + st := sc.streams[f.Header().StreamID] + if st != nil { + return st.processTrailerHeaders(f) + } + + // [...] The identifier of a newly established stream MUST be + // numerically greater than all streams that the initiating + // endpoint has opened or reserved. [...] An endpoint that + // receives an unexpected stream identifier MUST respond with + // a connection error (Section 5.4.1) of type PROTOCOL_ERROR. + if id <= sc.maxStreamID || sc.req.stream != nil { + return ConnectionError(ErrCodeProtocol) + } + if id > sc.maxStreamID { sc.maxStreamID = id } - st := &stream{ + st = &stream{ + sc: sc, id: id, state: stateOpen, } @@ -1341,16 +1415,30 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error { stream: st, header: make(http.Header), } + sc.hpackDecoder.SetEmitFunc(sc.onNewHeaderField) sc.hpackDecoder.SetEmitEnabled(true) return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded()) } +func (st *stream) processTrailerHeaders(f *HeadersFrame) error { + sc := st.sc + sc.serveG.check() + if st.gotTrailerHeader { + return ConnectionError(ErrCodeProtocol) + } + st.gotTrailerHeader = true + return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded()) +} + func (sc *serverConn) processContinuation(f *ContinuationFrame) error { sc.serveG.check() st := sc.streams[f.Header().StreamID] if st == nil || sc.curHeaderStreamID() != st.id { return ConnectionError(ErrCodeProtocol) } + if st.gotTrailerHeader { + return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded()) + } return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded()) } @@ -1389,6 +1477,10 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo if err != nil { return err } + st.reqTrailer = req.Trailer + if st.reqTrailer != nil { + st.trailer = make(http.Header) + } st.body = req.Body.(*requestBody).pipe // may be nil st.declBodyBytes = req.ContentLength @@ -1402,6 +1494,24 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo return nil } +func (st *stream) processTrailerHeaderBlockFragment(frag []byte, end bool) error { + sc := st.sc + sc.serveG.check() + sc.hpackDecoder.SetEmitFunc(st.onNewTrailerField) + if _, err := sc.hpackDecoder.Write(frag); err != nil { + return ConnectionError(ErrCodeCompression) + } + if !end { + return nil + } + err := sc.hpackDecoder.Close() + st.endStream() + if err != nil { + return ConnectionError(ErrCodeCompression) + } + return nil +} + func (sc *serverConn) processPriority(f *PriorityFrame) error { adjustStreamPriority(sc.streams, f.StreamID, f.PriorityParam) return nil @@ -1489,6 +1599,26 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err if cookies := rp.header["Cookie"]; len(cookies) > 1 { rp.header.Set("Cookie", strings.Join(cookies, "; ")) } + + // Setup Trailers + var trailer http.Header + for _, v := range rp.header["Trailer"] { + for _, key := range strings.Split(v, ",") { + key = http.CanonicalHeaderKey(strings.TrimSpace(key)) + switch key { + case "Transfer-Encoding", "Trailer", "Content-Length": + // Bogus. (copy of http1 rules) + // Ignore. + default: + if trailer == nil { + trailer = make(http.Header) + } + trailer[key] = nil + } + } + } + delete(rp.header, "Trailer") + body := &requestBody{ conn: sc, stream: rp.stream, @@ -1512,10 +1642,11 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err TLS: tlsState, Host: authority, Body: body, + Trailer: trailer, } if bodyOpen { body.pipe = &pipe{ - b: &fixedBuffer{buf: make([]byte, initialWindowSize)}, // TODO: share/remove XXX + b: &fixedBuffer{buf: make([]byte, initialWindowSize)}, // TODO: garbage } if vv, ok := rp.header["Content-Length"]; ok { diff --git a/http2/server_test.go b/http2/server_test.go index d8fbe88193..4835ba6d74 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -246,6 +246,21 @@ func (st *serverTester) encodeHeaderField(k, v string) { } } +// encodeHeaderRaw is the magic-free version of encodeHeader. +// It takes 0 or more (k, v) pairs and encodes them. +func (st *serverTester) encodeHeaderRaw(headers ...string) []byte { + if len(headers)%2 == 1 { + panic("odd number of kv args") + } + st.headerBuf.Reset() + for len(headers) > 0 { + k, v := headers[0], headers[1] + st.encodeHeaderField(k, v) + headers = headers[2:] + } + return st.headerBuf.Bytes() +} + // encodeHeader encodes headers and returns their HPACK bytes. headers // must contain an even number of key/value pairs. There may be // multiple pairs for keys (e.g. "cookie"). The :method, :path, and @@ -299,7 +314,6 @@ func (st *serverTester) encodeHeader(headers ...string) []byte { vals[k] = append(vals[k], v) } } - st.headerBuf.Reset() for _, k := range keys { for _, v := range vals[k] { st.encodeHeaderField(k, v) @@ -2451,8 +2465,53 @@ func TestCompressionErrorOnClose(t *testing.T) { // test that a server handler can read trailers from a client func TestServerReadsTrailers(t *testing.T) { - // TODO: use testBodyContents or testServerRequest - t.Skip("unimplemented") + const testBody = "some test body" + writeReq := func(st *serverTester) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: st.encodeHeader("trailer", "Foo, Bar", "trailer", "Baz"), + EndStream: false, + EndHeaders: true, + }) + st.writeData(1, false, []byte(testBody)) + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: st.encodeHeaderRaw( + "foo", "foov", + "bar", "barv", + "baz", "bazv", + "surprise", "wasn't declared; shouldn't show up", + ), + EndStream: true, + EndHeaders: true, + }) + } + checkReq := func(r *http.Request) { + wantTrailer := http.Header{ + "Foo": nil, + "Bar": nil, + "Baz": nil, + } + if !reflect.DeepEqual(r.Trailer, wantTrailer) { + t.Errorf("initial Trailer = %v; want %v", r.Trailer, wantTrailer) + } + slurp, err := ioutil.ReadAll(r.Body) + if string(slurp) != testBody { + t.Errorf("read body %q; want %q", slurp, testBody) + } + if err != nil { + t.Fatalf("Body slurp: %v", err) + } + wantTrailerAfter := http.Header{ + "Foo": {"foov"}, + "Bar": {"barv"}, + "Baz": {"bazv"}, + } + if !reflect.DeepEqual(r.Trailer, wantTrailerAfter) { + t.Errorf("final Trailer = %v; want %v", r.Trailer, wantTrailerAfter) + } + } + testServerRequest(t, writeReq, checkReq) } // test that a server handler can send trailers