From c0ff6953320861f7d96153a26eca8b628d5056d3 Mon Sep 17 00:00:00 2001 From: Masakazu Kitajo Date: Sat, 19 Dec 2015 21:52:05 +0900 Subject: [PATCH] TS-3967: Set stream state to close after a RST_STEAM has been sent --- proxy/http2/HTTP2.cc | 27 +++++++++------ proxy/http2/HTTP2.h | 2 +- proxy/http2/Http2ConnectionState.cc | 52 ++++++++++++++++++++++------- proxy/http2/Http2Stream.h | 12 +++++-- proxy/http2/RegressionHPACK.cc | 8 +++-- 5 files changed, 72 insertions(+), 29 deletions(-) diff --git a/proxy/http2/HTTP2.cc b/proxy/http2/HTTP2.cc index 27ee49c58e2..4f1fb4bd87d 100644 --- a/proxy/http2/HTTP2.cc +++ b/proxy/http2/HTTP2.cc @@ -614,12 +614,14 @@ http2_write_header_fragment(HTTPHdr *in, MIMEFieldIter &field_iter, uint8_t *out * Decode Header Blocks to Header List. */ int64_t -http2_decode_header_blocks(HTTPHdr *hdr, const uint8_t *buf_start, const uint8_t *buf_end, Http2IndexingTable &indexing_table) +http2_decode_header_blocks(HTTPHdr *hdr, const uint8_t *buf_start, const uint8_t *buf_end, Http2IndexingTable &indexing_table, + bool &trailing_header) { const uint8_t *cursor = buf_start; HdrHeap *heap = hdr->m_heap; HTTPHdrImpl *hh = hdr->m_http; bool header_field_started = false; + bool is_trailing_header = trailing_header; while (cursor < buf_end) { int64_t read_bytes = 0; @@ -684,22 +686,29 @@ http2_decode_header_blocks(HTTPHdr *hdr, const uint8_t *buf_start, const uint8_t } } - // when The TE header field is received, it MUST NOT contain any - // value other than "trailers". + // when The TE header field is received, it MUST NOT contain any value other than "trailers". if (name_len == MIME_LEN_TE && strncmp(name, MIME_FIELD_TE, name_len) == 0) { int value_len = 0; const char *value = field->value_get(&value_len); - char trailers[] = "trailers"; + const char trailers[] = "trailers"; if (!(value_len == (sizeof(trailers) - 1) && memcmp(value, trailers, value_len) == 0)) { return HPACK_ERROR_HTTP2_PROTOCOL_ERROR; } } + // turn on that we have a trailer header + const char trailer_name[] = "trailer"; + if (name_len == (sizeof(trailer_name) - 1) && strncmp(name, trailer_name, sizeof(trailer_name) - 1) == 0) { + trailing_header = true; + } + // Store to HdrHeap mime_hdr_field_attach(hh->m_fields_impl, field, 1, NULL); + } + if (!is_trailing_header) { // Check psuedo headers - if (hdr->fields_count() == 4) { + if (hdr->fields_count() >= 4) { if (hdr->field_find(HPACK_VALUE_SCHEME, HPACK_LEN_SCHEME) == NULL || hdr->field_find(HPACK_VALUE_METHOD, HPACK_LEN_METHOD) == NULL || hdr->field_find(HPACK_VALUE_PATH, HPACK_LEN_PATH) == NULL || @@ -707,14 +716,12 @@ http2_decode_header_blocks(HTTPHdr *hdr, const uint8_t *buf_start, const uint8_t // Decoded header field is invalid return HPACK_ERROR_HTTP2_PROTOCOL_ERROR; } + } else { + // Psuedo headers is insufficient + return HPACK_ERROR_HTTP2_PROTOCOL_ERROR; } } - // Psuedo headers is insufficient - if (hdr->fields_count() < 4) { - return HPACK_ERROR_HTTP2_PROTOCOL_ERROR; - } - // Parsing all headers is done return cursor - buf_start; } diff --git a/proxy/http2/HTTP2.h b/proxy/http2/HTTP2.h index 21f29e50954..814b42e4034 100644 --- a/proxy/http2/HTTP2.h +++ b/proxy/http2/HTTP2.h @@ -326,7 +326,7 @@ bool http2_parse_goaway(IOVec, Http2Goaway &); bool http2_parse_window_update(IOVec, uint32_t &); -int64_t http2_decode_header_blocks(HTTPHdr *, const uint8_t *, const uint8_t *, Http2IndexingTable &); +int64_t http2_decode_header_blocks(HTTPHdr *, const uint8_t *, const uint8_t *, Http2IndexingTable &, bool &); MIMEParseResult convert_from_2_to_1_1_header(HTTPHdr *); diff --git a/proxy/http2/Http2ConnectionState.cc b/proxy/http2/Http2ConnectionState.cc index c073e2e9cf2..36969d30533 100644 --- a/proxy/http2/Http2ConnectionState.cc +++ b/proxy/http2/Http2ConnectionState.cc @@ -179,14 +179,18 @@ rcv_headers_frame(Http2ClientSession &cs, Http2ConnectionState &cstate, const Ht return Http2Error(HTTP2_ERROR_CLASS_CONNECTION, HTTP2_ERROR_PROTOCOL_ERROR); } + Http2Stream *stream = NULL; if (stream_id <= cstate.get_latest_stream_id()) { - return Http2Error(HTTP2_ERROR_CLASS_STREAM, HTTP2_ERROR_STREAM_CLOSED); - } - - // Create new stream - Http2Stream *stream = cstate.create_stream(stream_id); - if (!stream) { - return Http2Error(HTTP2_ERROR_CLASS_CONNECTION, HTTP2_ERROR_PROTOCOL_ERROR); + stream = cstate.find_stream(stream_id); + if (stream == NULL || !stream->has_trailing_header()) { + return Http2Error(HTTP2_ERROR_CLASS_STREAM, HTTP2_ERROR_STREAM_CLOSED); + } + } else { + // Create new stream + stream = cstate.create_stream(stream_id); + if (!stream) { + return Http2Error(HTTP2_ERROR_CLASS_CONNECTION, HTTP2_ERROR_PROTOCOL_ERROR); + } } // keep track of how many bytes we get in the frame @@ -247,10 +251,22 @@ rcv_headers_frame(Http2ClientSession &cs, Http2ConnectionState &cstate, const Ht if (frame.header().flags & HTTP2_FLAGS_HEADERS_END_HEADERS) { // NOTE: If there are END_HEADERS flag, decode stored Header Blocks. - if (!stream->change_state(HTTP2_FRAME_TYPE_HEADERS, frame.header().flags)) { + if (!stream->change_state(HTTP2_FRAME_TYPE_HEADERS, frame.header().flags) && stream->has_trailing_header() == false) { return Http2Error(HTTP2_ERROR_CLASS_CONNECTION, HTTP2_ERROR_PROTOCOL_ERROR); } + bool skip_fetcher = false; + if (stream->has_trailing_header()) { + if (!(frame.header().flags & HTTP2_FLAGS_HEADERS_END_STREAM)) { + return Http2Error(HTTP2_ERROR_CLASS_STREAM, HTTP2_ERROR_PROTOCOL_ERROR); + } + // If the flag has already been set before decoding header blocks, this is the trailing header. + // Set a flag to avoid initializing fetcher for now. + // Decoding header blocks is stil needed to maintain a HPACK dynamic table. + // TODO: TS-3812 + skip_fetcher = true; + } + const int64_t decoded_bytes = stream->decode_header_blocks(*cstate.local_indexing_table); if (decoded_bytes == 0 || decoded_bytes == HPACK_ERROR_COMPRESSION_ERROR) { @@ -259,8 +275,10 @@ rcv_headers_frame(Http2ClientSession &cs, Http2ConnectionState &cstate, const Ht return Http2Error(HTTP2_ERROR_CLASS_STREAM, HTTP2_ERROR_PROTOCOL_ERROR); } - if (!stream->init_fetcher(cstate)) { - return Http2Error(HTTP2_ERROR_CLASS_STREAM, HTTP2_ERROR_PROTOCOL_ERROR); + if (!skip_fetcher) { + if (!stream->init_fetcher(cstate)) { + return Http2Error(HTTP2_ERROR_CLASS_STREAM, HTTP2_ERROR_PROTOCOL_ERROR); + } } } else { // NOTE: Expect CONTINUATION Frame. Do NOT change state of stream or decode @@ -573,7 +591,7 @@ rcv_window_update_frame(Http2ClientSession &cs, Http2ConnectionState &cstate, co stream->client_rwnd += size; ssize_t wnd = min(cstate.client_rwnd, stream->client_rwnd); - if (wnd > 0) { + if (stream->get_state() == HTTP2_STREAM_STATE_HALF_CLOSED_REMOTE && wnd > 0) { cstate.send_data_frame(stream->get_fetcher()); } } @@ -834,7 +852,7 @@ Http2ConnectionState::restart_streams() Http2Stream *s = stream_list.head; while (s) { Http2Stream *next = s->link.next; - if (min(this->client_rwnd, s->client_rwnd) > 0) { + if (s->get_state() == HTTP2_STREAM_STATE_HALF_CLOSED_REMOTE && min(this->client_rwnd, s->client_rwnd) > 0) { this->send_data_frame(s->get_fetcher()); } s = next; @@ -894,6 +912,10 @@ Http2ConnectionState::send_data_frame(FetchSM *fetch_sm) Http2Stream *stream = static_cast(fetch_sm->ext_get_user_data()); + if (stream->get_state() == HTTP2_STREAM_STATE_CLOSED) { + return; + } + for (;;) { uint8_t flags = 0x00; @@ -1015,6 +1037,12 @@ Http2ConnectionState::send_rst_stream_frame(Http2StreamId id, Http2ErrorCode ec) http2_write_rst_stream(static_cast(ec), rst_stream.write()); rst_stream.finalize(HTTP2_RST_STREAM_LEN); + // change state to closed + Http2Stream *stream = find_stream(id); + if (stream != NULL) { + stream->change_state(HTTP2_FRAME_TYPE_RST_STREAM, 0); + } + // xmit event SCOPED_MUTEX_LOCK(lock, this->ua_session->mutex, this_ethread()); this->ua_session->handleEvent(HTTP2_SESSION_EVENT_XMIT, &rst_stream); diff --git a/proxy/http2/Http2Stream.h b/proxy/http2/Http2Stream.h index 31f1d756dd6..4bc196dd57e 100644 --- a/proxy/http2/Http2Stream.h +++ b/proxy/http2/Http2Stream.h @@ -34,8 +34,8 @@ class Http2Stream public: Http2Stream(Http2StreamId sid = 0, ssize_t initial_rwnd = Http2::initial_window_size) : client_rwnd(initial_rwnd), server_rwnd(Http2::initial_window_size), header_blocks(NULL), header_blocks_length(0), - request_header_length(0), end_stream(false), _id(sid), _state(HTTP2_STREAM_STATE_IDLE), _fetch_sm(NULL), body_done(false), - data_length(0) + request_header_length(0), end_stream(false), _id(sid), _state(HTTP2_STREAM_STATE_IDLE), _fetch_sm(NULL), + trailing_header(false), body_done(false), data_length(0) { _thread = this_ethread(); HTTP2_INCREMENT_THREAD_DYN_STAT(HTTP2_STAT_CURRENT_CLIENT_STREAM_COUNT, _thread); @@ -90,12 +90,17 @@ class Http2Stream return _state; } bool change_state(uint8_t type, uint8_t flags); + bool + has_trailing_header() const + { + return trailing_header; + } int64_t decode_header_blocks(Http2IndexingTable &indexing_table) { return http2_decode_header_blocks(&_req_header, (const uint8_t *)header_blocks, - (const uint8_t *)header_blocks + header_blocks_length, indexing_table); + (const uint8_t *)header_blocks + header_blocks_length, indexing_table, trailing_header); } // Check entire DATA payload length if content-length: header is exist @@ -131,6 +136,7 @@ class Http2Stream HTTPHdr _req_header; FetchSM *_fetch_sm; + bool trailing_header; bool body_done; uint64_t data_length; }; diff --git a/proxy/http2/RegressionHPACK.cc b/proxy/http2/RegressionHPACK.cc index aee2f3c3e67..e8826ff6d47 100644 --- a/proxy/http2/RegressionHPACK.cc +++ b/proxy/http2/RegressionHPACK.cc @@ -559,14 +559,16 @@ REGRESSION_TEST(HPACK_Decode)(RegressionTest *t, int, int *pstatus) box = REGRESSION_TEST_PASSED; Http2IndexingTable indexing_table; + bool trailing_header = false; for (unsigned int i = 0; i < sizeof(encoded_field_request_test_case) / sizeof(encoded_field_request_test_case[0]); i++) { ats_scoped_obj headers(new HTTPHdr); headers->create(HTTP_TYPE_REQUEST); - http2_decode_header_blocks( - headers, encoded_field_request_test_case[i].encoded_field, - encoded_field_request_test_case[i].encoded_field + encoded_field_request_test_case[i].encoded_field_len, indexing_table); + http2_decode_header_blocks(headers, encoded_field_request_test_case[i].encoded_field, + encoded_field_request_test_case[i].encoded_field + + encoded_field_request_test_case[i].encoded_field_len, + indexing_table, trailing_header); for (unsigned int j = 0; j < sizeof(raw_field_request_test_case[i]) / sizeof(raw_field_request_test_case[i][0]); j++) { const char *expected_name = raw_field_request_test_case[i][j].raw_name;