Skip to content

Commit

Permalink
http3: cancel reading on request stream if request processing fails
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Apr 8, 2024
1 parent 7021b84 commit 41456df
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 62 deletions.
65 changes: 28 additions & 37 deletions http3/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,6 @@ type contextKey struct{ name string }

func (k *contextKey) String() string { return "quic-go/http3 context value " + k.name }

var errHijacked = errors.New("hijacked")

// ServerContextKey is a context key. It can be used in HTTP
// handlers with Context.Value to access the server that
// started the handler. The associated value will be of
Expand Down Expand Up @@ -461,29 +459,7 @@ func (s *Server) handleConn(conn quic.Connection) error {
}
return fmt.Errorf("accepting stream failed: %w", err)
}
go func() {
rerr := s.handleRequest(hconn, str, decoder, func(e ErrCode) {
conn.CloseWithError(quic.ApplicationErrorCode(e), "")
})
if rerr.err == errHijacked {
return
}
if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 {
s.logger.Debugf("Handling request failed: %s", err)
if rerr.streamErr != 0 {
str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
}
if rerr.connErr != 0 {
var reason string
if rerr.err != nil {
reason = rerr.err.Error()
}
conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
}
return
}
str.Close()
}()
go s.handleRequest(hconn, str, decoder)
}
}

Expand All @@ -494,30 +470,40 @@ func (s *Server) maxHeaderBytes() uint64 {
return uint64(s.MaxHeaderBytes)
}

func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack.Decoder, closeConnection func(ErrCode)) requestError {
func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack.Decoder) {
frame, err := parseNextFrame(str)
if err != nil {
return newStreamError(ErrCodeRequestIncomplete, err)
str.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
return
}
hf, ok := frame.(*headersFrame)
if !ok {
return newConnError(ErrCodeFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "expected first frame to be a HEADERS frame")
return
}
if hf.Length > s.maxHeaderBytes() {
return newStreamError(ErrCodeFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, s.maxHeaderBytes()))
str.CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
str.CancelWrite(quic.StreamErrorCode(ErrCodeFrameError))
return
}
headerBlock := make([]byte, hf.Length)
if _, err := io.ReadFull(str, headerBlock); err != nil {
return newStreamError(ErrCodeRequestIncomplete, err)
str.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
return
}
hfs, err := decoder.DecodeFull(headerBlock)
if err != nil {
// TODO: use the right error code
return newConnError(ErrCodeGeneralProtocolError, err)
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeGeneralProtocolError), "expected first frame to be a HEADERS frame")
return
}
req, err := requestFromHeaders(hfs)
if err != nil {
return newStreamError(ErrCodeMessageError, err)
str.CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
str.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
return
}

connState := conn.ConnectionState().TLS
Expand All @@ -527,6 +513,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack
// Check that the client doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
// See section 4.1.2 of RFC 9114.
var httpStr Stream
closeConnection := func(e ErrCode) { conn.CloseWithError(quic.ApplicationErrorCode(e), "") }
if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 {
httpStr = newLengthLimitedStream(newStream(str, closeConnection), req.ContentLength)
} else {
Expand Down Expand Up @@ -580,7 +567,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack
}()

if body.wasStreamHijacked() {
return requestError{err: errHijacked}
return
}

// only write response when there is no panic
Expand All @@ -593,14 +580,18 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack
}
r.Flush()
}
// If the EOF was read by the handler, CancelRead() is a no-op.
str.CancelRead(quic.StreamErrorCode(ErrCodeNoError))

// abort the stream when there is a panic
if panicked {
return newStreamError(ErrCodeInternalError, errPanicked)
str.CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
str.CancelWrite(quic.StreamErrorCode(ErrCodeInternalError))
return
}
return requestError{}

// If the EOF was read by the handler, CancelRead() is a no-op.
str.CancelRead(quic.StreamErrorCode(ErrCodeNoError))

str.Close()
}

// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
Expand Down
55 changes: 30 additions & 25 deletions http3/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,9 @@ var _ = Describe("Server", func() {
return len(p), nil
}).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
str.EXPECT().Close()

Expect(s.handleRequest(conn, str, qpackDecoder, nil)).To(Equal(requestError{}))
s.handleRequest(conn, str, qpackDecoder)
var req *http.Request
Eventually(requestChan).Should(Receive(&req))
Expect(req.Host).To(Equal("www.example.com"))
Expand All @@ -178,9 +179,9 @@ var _ = Describe("Server", func() {
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
str.EXPECT().Close()

serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
s.handleRequest(conn, str, qpackDecoder)
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
})
Expand All @@ -195,9 +196,9 @@ var _ = Describe("Server", func() {
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
str.EXPECT().Close()

serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
s.handleRequest(conn, str, qpackDecoder)
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
Expect(hfs).To(HaveKeyWithValue("content-length", []string{"6"}))
Expand All @@ -217,9 +218,9 @@ var _ = Describe("Server", func() {
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
str.EXPECT().Close()

serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
s.handleRequest(conn, str, qpackDecoder)
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
// status, date, content-type
Expand All @@ -238,8 +239,9 @@ var _ = Describe("Server", func() {
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
str.EXPECT().Close()

s.handleRequest(conn, str, qpackDecoder)
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
Expect(responseBuf.Bytes()).To(HaveLen(0))
Expand All @@ -257,15 +259,16 @@ var _ = Describe("Server", func() {
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
str.EXPECT().Close()

s.handleRequest(conn, str, qpackDecoder)
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
Expect(hfs).To(HaveKeyWithValue("content-length", []string{"13"}))
Expect(hfs).To(HaveKeyWithValue("content-type", []string{"text/html; charset=utf-8"}))
})

It("handles a aborting handler", func() {
It("handles an aborting handler", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic(http.ErrAbortHandler)
})
Expand All @@ -274,10 +277,10 @@ var _ = Describe("Server", func() {
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError))

serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).To(MatchError(errPanicked))
s.handleRequest(conn, str, qpackDecoder)
Expect(responseBuf.Bytes()).To(HaveLen(0))
})

Expand All @@ -290,10 +293,10 @@ var _ = Describe("Server", func() {
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError))

serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).To(MatchError(errPanicked))
s.handleRequest(conn, str, qpackDecoder)
Expect(responseBuf.Bytes()).To(HaveLen(0))
})

Expand Down Expand Up @@ -378,6 +381,7 @@ var _ = Describe("Server", func() {
setRequest(append(requestData, b...))
done := make(chan struct{})
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) })

s.handleConn(conn)
Expand All @@ -393,6 +397,7 @@ var _ = Describe("Server", func() {
testErr := errors.New("stream reset")
done := make(chan struct{})
str.EXPECT().Read(gomock.Any()).Return(0, testErr)
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) })

s.handleConn(conn)
Expand Down Expand Up @@ -420,23 +425,23 @@ var _ = Describe("Server", func() {
Eventually(done).Should(BeClosed())
})

It("closes the connection when the first frame is not a HEADERS frame", func() {
It("rejects a request that has too large request headers", func() {
handlerCalled := make(chan struct{})
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(handlerCalled)
})

// use 2*DefaultMaxHeaderBytes here. qpack will compress the requiest,
// use 2*DefaultMaxHeaderBytes here. qpack will compress the request,
// but the request will still end up larger than DefaultMaxHeaderBytes.
url := bytes.Repeat([]byte{'a'}, http.DefaultMaxHeaderBytes*2)
req, err := http.NewRequest(http.MethodGet, "https://"+string(url), nil)
Expect(err).ToNot(HaveOccurred())
setRequest(encodeRequest(req))
// str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return len(p), nil
}).AnyTimes()
done := make(chan struct{})
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) })

s.handleConn(conn)
Expand All @@ -460,9 +465,9 @@ var _ = Describe("Server", func() {
return len(p), nil
}).AnyTimes()
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
str.EXPECT().Close()

serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
s.handleRequest(conn, str, qpackDecoder)
Eventually(handlerCalled).Should(BeClosed())
})

Expand All @@ -483,9 +488,9 @@ var _ = Describe("Server", func() {
return len(p), nil
}).AnyTimes()
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
str.EXPECT().Close()

serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
s.handleRequest(conn, str, qpackDecoder)
Eventually(handlerCalled).Should(BeClosed())
})
})
Expand Down

0 comments on commit 41456df

Please sign in to comment.