diff --git a/include/proxy/http2/Http2CommonSession.h b/include/proxy/http2/Http2CommonSession.h index f24ee8b0947..9030364b7d4 100644 --- a/include/proxy/http2/Http2CommonSession.h +++ b/include/proxy/http2/Http2CommonSession.h @@ -36,15 +36,17 @@ // HTTP2_SESSION_EVENT_RECV Http2Frame * Received a frame // HTTP2_SESSION_EVENT_PRIO Http2Frame * Send this priority frame // HTTP2_SESSION_EVENT_DATA Http2Frame * Send the data frames in the stream +// HTTP2_SESSION_EVENT_XMIT Http2CommonSession * Try retransmitting frames. #define HTTP2_SESSION_EVENT_INIT (HTTP2_SESSION_EVENTS_START + 1) #define HTTP2_SESSION_EVENT_FINI (HTTP2_SESSION_EVENTS_START + 2) #define HTTP2_SESSION_EVENT_RECV (HTTP2_SESSION_EVENTS_START + 3) -#define HTTP2_SESSION_EVENT_XMIT (HTTP2_SESSION_EVENTS_START + 4) +#define HTTP2_SESSION_EVENT_PRIO (HTTP2_SESSION_EVENTS_START + 4) #define HTTP2_SESSION_EVENT_DATA (HTTP2_SESSION_EVENTS_START + 5) -#define HTTP2_SESSION_EVENT_SHUTDOWN_INIT (HTTP2_SESSION_EVENTS_START + 6) -#define HTTP2_SESSION_EVENT_SHUTDOWN_CONT (HTTP2_SESSION_EVENTS_START + 7) -#define HTTP2_SESSION_EVENT_REENABLE (HTTP2_SESSION_EVENTS_START + 8) +#define HTTP2_SESSION_EVENT_XMIT (HTTP2_SESSION_EVENTS_START + 6) +#define HTTP2_SESSION_EVENT_SHUTDOWN_INIT (HTTP2_SESSION_EVENTS_START + 7) +#define HTTP2_SESSION_EVENT_SHUTDOWN_CONT (HTTP2_SESSION_EVENTS_START + 8) +#define HTTP2_SESSION_EVENT_REENABLE (HTTP2_SESSION_EVENTS_START + 9) enum class Http2SessionCod : int { NOT_PROVIDED, diff --git a/include/proxy/http2/Http2ConnectionState.h b/include/proxy/http2/Http2ConnectionState.h index 83ff405649b..4952f979fc9 100644 --- a/include/proxy/http2/Http2ConnectionState.h +++ b/include/proxy/http2/Http2ConnectionState.h @@ -158,6 +158,8 @@ class Http2ConnectionState : public Continuation void schedule_stream_to_send_priority_frames(Http2Stream *stream); void send_data_frames_depends_on_priority(); void schedule_stream_to_send_data_frames(Http2Stream *stream); + void schedule_retransmit(ink_hrtime t); + void cancel_retransmit(); void send_data_frames(Http2Stream *stream); Http2SendDataFrameResult send_a_data_frame(Http2Stream *stream, size_t &payload_length); void send_headers_frame(Http2Stream *stream); @@ -398,6 +400,7 @@ class Http2ConnectionState : public Continuation Event *shutdown_cont_event = nullptr; Event *fini_event = nullptr; Event *zombie_event = nullptr; + Event *retransmit_event = nullptr; uint32_t configured_max_settings_frames_per_minute = 0; uint32_t configured_max_ping_frames_per_minute = 0; diff --git a/src/proxy/http2/Http2ClientSession.cc b/src/proxy/http2/Http2ClientSession.cc index bf4732fb46c..97ede81730d 100644 --- a/src/proxy/http2/Http2ClientSession.cc +++ b/src/proxy/http2/Http2ClientSession.cc @@ -210,7 +210,7 @@ Http2ClientSession::main_event_handler(int event, void *edata) retval = 0; break; - case HTTP2_SESSION_EVENT_XMIT: + case HTTP2_SESSION_EVENT_PRIO: default: Http2SsnDebug("unexpected event=%d edata=%p", event, edata); ink_release_assert(0); diff --git a/src/proxy/http2/Http2CommonSession.cc b/src/proxy/http2/Http2CommonSession.cc index 56c82e40ac0..7581c47826d 100644 --- a/src/proxy/http2/Http2CommonSession.cc +++ b/src/proxy/http2/Http2CommonSession.cc @@ -158,6 +158,11 @@ Http2CommonSession::xmit(const Http2TxFrame &frame, bool flush) // A frame size can be 16MB at maximum so blocks can be added, but that's fine. if (this->_pending_sending_data_size >= this->_write_size_threshold) { flush = true; + } else { + Note("Calling schedule_transmit because write threshold is not exceeded."); + // Observe that schedule_transmit will only schedule the first time we + // don't flush because the threshold is not met. + this->connection_state.schedule_retransmit(HRTIME_MSECONDS(Http2::write_time_threshold)); } } if (flush) { @@ -170,6 +175,7 @@ Http2CommonSession::xmit(const Http2TxFrame &frame, bool flush) void Http2CommonSession::flush() { + this->connection_state.cancel_retransmit(); if (this->_pending_sending_data_size > 0) { this->_pending_sending_data_size = 0; this->_write_buffer_last_flush = ink_get_hrtime(); diff --git a/src/proxy/http2/Http2ConnectionState.cc b/src/proxy/http2/Http2ConnectionState.cc index 703586c383b..b4fcd28b68a 100644 --- a/src/proxy/http2/Http2ConnectionState.cc +++ b/src/proxy/http2/Http2ConnectionState.cc @@ -1354,6 +1354,9 @@ Http2ConnectionState::destroy() if (zombie_event) { zombie_event->cancel(); } + if (retransmit_event) { + retransmit_event->cancel(); + } // release the mutex after the events are cancelled and sessions are destroyed. mutex = nullptr; // magic happens - assigning to nullptr frees the ProxyMutex } @@ -1430,6 +1433,8 @@ Http2ConnectionState::main_event_handler(int event, void *edata) ink_release_assert(zombie_event == nullptr); } else if (edata == fini_event) { fini_event = nullptr; + } else if (edata == retransmit_event) { + retransmit_event = nullptr; } ++recursion; switch (event) { @@ -1445,7 +1450,7 @@ Http2ConnectionState::main_event_handler(int event, void *edata) SET_HANDLER(&Http2ConnectionState::state_closed); } break; - case HTTP2_SESSION_EVENT_XMIT: { + case HTTP2_SESSION_EVENT_PRIO: { REMEMBER(event, this->recursion); SCOPED_MUTEX_LOCK(lock, this->mutex, this_ethread()); send_data_frames_depends_on_priority(); @@ -1459,6 +1464,13 @@ Http2ConnectionState::main_event_handler(int event, void *edata) _data_scheduled = false; } break; + case HTTP2_SESSION_EVENT_XMIT: { + REMEMBER(event, this->recursion); + SCOPED_MUTEX_LOCK(lock, this->mutex, this_ethread()); + Note("Flushing due to XMIT event"); + this->session->flush(); + } break; + // Initiate a graceful shutdown case HTTP2_SESSION_EVENT_SHUTDOWN_INIT: { REMEMBER(event, this->recursion); @@ -1522,6 +1534,8 @@ Http2ConnectionState::state_closed(int event, void *edata) fini_event = nullptr; } else if (edata == shutdown_cont_event) { shutdown_cont_event = nullptr; + } else if (edata == retransmit_event) { + retransmit_event = nullptr; } return 0; } @@ -2037,7 +2051,7 @@ Http2ConnectionState::schedule_stream_to_send_priority_frames(Http2Stream *strea _priority_scheduled = true; SET_HANDLER(&Http2ConnectionState::main_event_handler); - this_ethread()->schedule_imm_local((Continuation *)this, HTTP2_SESSION_EVENT_XMIT); + this_ethread()->schedule_imm_local((Continuation *)this, HTTP2_SESSION_EVENT_PRIO); } } @@ -2056,6 +2070,31 @@ Http2ConnectionState::schedule_stream_to_send_data_frames(Http2Stream *stream) } } +void +Http2ConnectionState::schedule_retransmit(ink_hrtime t) +{ + Http2StreamDebug(session, 0, "Scheduling retransmitting data frames"); + SCOPED_MUTEX_LOCK(lock, this->mutex, this_ethread()); + + if (retransmit_event == nullptr) { + Note("Scheduling retransmit in %" PRId64 "ms", t / HRTIME_MSECOND); + + SET_HANDLER(&Http2ConnectionState::main_event_handler); + retransmit_event = this_ethread()->schedule_in((Continuation *)this, t, HTTP2_SESSION_EVENT_XMIT); + } +} + +void +Http2ConnectionState::cancel_retransmit() +{ + Http2StreamDebug(session, 0, "Scheduling retransmitting data frames"); + SCOPED_MUTEX_LOCK(lock, this->mutex, this_ethread()); + if (retransmit_event != nullptr) { + retransmit_event->cancel(); + retransmit_event = nullptr; + } +} + void Http2ConnectionState::send_data_frames_depends_on_priority() { @@ -2098,7 +2137,7 @@ Http2ConnectionState::send_data_frames_depends_on_priority() break; } - this_ethread()->schedule_imm_local((Continuation *)this, HTTP2_SESSION_EVENT_XMIT); + this_ethread()->schedule_imm_local((Continuation *)this, HTTP2_SESSION_EVENT_PRIO); return; } diff --git a/src/proxy/http2/Http2ServerSession.cc b/src/proxy/http2/Http2ServerSession.cc index 46b37674208..d6cec0b0746 100644 --- a/src/proxy/http2/Http2ServerSession.cc +++ b/src/proxy/http2/Http2ServerSession.cc @@ -22,6 +22,7 @@ */ #include "proxy/http2/Http2ServerSession.h" +#include "iocore/net/TLSSNISupport.h" #include "proxy/http/HttpDebugNames.h" #include "tscore/ink_base64.h" #include "proxy/http2/Http2CommonSessionInternal.h" @@ -127,6 +128,14 @@ Http2ServerSession::new_connection(NetVConnection *new_vc, MIOBuffer *iobuf, IOB this->_write_buffer_reader = this->write_buffer->alloc_reader(); this->_write_size_threshold = index_to_buffer_size(buffer_block_size_index) * Http2::write_size_threshold; + uint32_t buffer_water_mark; + if (auto snis = this->_vc->get_service(); snis && snis->hints_from_sni.http2_buffer_water_mark.has_value()) { + buffer_water_mark = snis->hints_from_sni.http2_buffer_water_mark.value(); + } else { + buffer_water_mark = Http2::buffer_water_mark; + } + this->write_buffer->water_mark = buffer_water_mark; + this->_handle_if_ssl(new_vc); do_api_callout(TS_HTTP_SSN_START_HOOK); @@ -206,7 +215,7 @@ Http2ServerSession::main_event_handler(int event, void *edata) retval = 0; break; - case HTTP2_SESSION_EVENT_XMIT: + case HTTP2_SESSION_EVENT_PRIO: default: Http2SsnDebug("unexpected event=%d edata=%p", event, edata); ink_release_assert(0); diff --git a/tests/Pipfile b/tests/Pipfile index ebdd37f1898..2a81ac8c8d2 100644 --- a/tests/Pipfile +++ b/tests/Pipfile @@ -53,5 +53,8 @@ pyyaml ="*" grpcio = "*" grpcio-tools = "*" +pyOpenSSL = "*" +eventlet = "*" + [requires] python_version = "3" diff --git a/tests/gold_tests/h2/http2_write_threshold.test.py b/tests/gold_tests/h2/http2_write_threshold.test.py new file mode 100644 index 00000000000..694037a7a56 --- /dev/null +++ b/tests/gold_tests/h2/http2_write_threshold.test.py @@ -0,0 +1,137 @@ +"""Test proxy.config.http2.write_size_threshold.""" + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from ports import get_port +import sys + + +class TestGrpc(): + """Test proxy.config.http2.write_size_threshold and its associated timeout.""" + + def __init__(self, description: str, write_threshold: int, write_timeout: int) -> None: + """Configure a TestRun for gRPC traffic. + + :param description: The description for the test runs. + """ + self._description = description + tr = Test.AddTestRun(self._description) + dns = self._configure_dns(tr) + server = self._configure_h2_server(tr, write_timeout) + ts = self._configure_traffic_server(tr, dns.Variables.Port, server.Variables.port, write_threshold, write_timeout) + + ts.StartBefore(dns) + ts.StartBefore(server) + tr.Processes.Default.StartBefore(ts) + + tr.TimeOut = 10 + + self._configure_h2_client(tr, ts.Variables.ssl_port, write_timeout) + + def _configure_dns(self, tr: 'TestRun') -> 'Process': + """Configure a locally running MicroDNS server. + + :param tr: The TestRun with which to associate the MicroDNS server. + :return: The MicroDNS server process. + """ + self._dns = tr.MakeDNServer("dns", default=['127.0.0.1']) + return self._dns + + def _configure_h2_server(self, tr: 'TestRun', write_timeout: int) -> 'Process': + """Set up the go HTTP/2 server. + + :param tr: The TestRun with which to associate the server. + :param write_timeout: The expected maximum amount of time frames should be delivered. + :return: The server process. + """ + tr.Setup.Copy('trickle_server.py') + self._server = tr.Processes.Process('server') + + server_pem = os.path.join(Test.Variables.AtsTestToolsDir, "ssl", "server.pem") + server_key = os.path.join(Test.Variables.AtsTestToolsDir, "ssl", "server.key") + self._server.Setup.Copy(server_pem) + self._server.Setup.Copy(server_key) + + port = get_port(self._server, 'port') + command = (f'{sys.executable} {tr.RunDirectory}/trickle_server.py {port} ' + f'server.pem server.key {write_timeout}') + self._server.Command = command + self._server.ReturnCode = 0 + self._server.Ready = When.PortOpen(port) + return self._server + + def _configure_traffic_server( + self, tr: 'TestRun', dns_port: int, server_port: int, write_threshold: int, write_timeout: int) -> 'Process': + """Configure the traffic server process. + + :param tr: The TestRun with which to associate the traffic server. + :param dns_port: The MicroDNS server port that traffic server should connect to. + :param server_port: The server port that traffic server should connect to. + :param write_threshold: The value to set for proxy.config.http2.write_size_threshold. + :param write_timeout: The value to set for proxy.config.http2.write_time_threshold. + :return: The traffic server process. + """ + self._ts = tr.MakeATSProcess("ts", enable_tls=True, enable_cache=False) + + self._ts.addDefaultSSLFiles() + self._ts.Disk.ssl_multicert_config.AddLine("dest_ip=* ssl_cert_name=server.pem ssl_key_name=server.key") + + self._ts.Disk.remap_config.AddLine(f"map / https://example.com:{server_port}/") + + self._ts.Disk.records_config.update( + { + "proxy.config.ssl.server.cert.path": self._ts.Variables.SSLDir, + "proxy.config.ssl.server.private_key.path": self._ts.Variables.SSLDir, + 'proxy.config.ssl.client.alpn_protocols': 'h2,http/1.1', + 'proxy.config.http.server_session_sharing.pool': 'thread', + 'proxy.config.ssl.client.verify.server.policy': 'PERMISSIVE', + 'proxy.config.dns.nameservers': f"127.0.0.1:{dns_port}", + 'proxy.config.dns.resolv_conf': "NULL", + 'proxy.config.http2.write_size_threshold': write_threshold, + 'proxy.config.http2.write_time_threshold': write_timeout, + + # Only enable debug logging during manual exectution. All the + # DATA frames get multiple logs and it makes the traffic.out too + # unwieldy. + "proxy.config.diags.debug.enabled": 0, + "proxy.config.diags.debug.tags": "http", + }) + return self._ts + + def _configure_h2_client(self, tr: 'TestRun', proxy_port: int, write_timeout: int) -> None: + """Start the HTTP/2 client. + + :param tr: The TestRun with which to associate the client. + :param proxy_port: The proxy_port to which to connect. + """ + tr.Setup.Copy('trickle_client.py') + ca = os.path.join(Test.Variables.AtsTestToolsDir, "ssl", "server.pem") + key = os.path.join(Test.Variables.AtsTestToolsDir, "ssl", "server.key") + + self._server.Setup.Copy(ca) + self._server.Setup.Copy(key) + # The cert is for example.com, so we must use that domain. + hostname = 'example.com' + command = (f'{sys.executable} {tr.RunDirectory}/trickle_client.py ' + f'{hostname} {proxy_port} server.pem {write_timeout}') + p = tr.Processes.Default + p.Command = command + p.ReturnCode = 0 + + +test = TestGrpc("Test proxy.config.http2.write_size_threshold", 0.5, 10) diff --git a/tests/gold_tests/h2/trickle_client.py b/tests/gold_tests/h2/trickle_client.py new file mode 100644 index 00000000000..5b777086ffe --- /dev/null +++ b/tests/gold_tests/h2/trickle_client.py @@ -0,0 +1,349 @@ +'''Implement a client that sends many small DATA frames.''' + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from email.message import EmailMessage as HttpHeaders +import logging +import math +import socket +import sys +import ssl +from OpenSSL.SSL import Error as SSLError +from OpenSSL.SSL import SysCallError as SSLSysCallError +import statistics +import traceback +import time + +from typing import Dict, List, Tuple + +import eventlet +from h2.connection import H2Connection +from h2.events import StreamEnded, ResponseReceived, DataReceived, TrailersReceived, StreamReset, ConnectionTerminated +from h2.exceptions import StreamClosedError + + +def get_body_text() -> bytes: + """Create a body of text to send in the request.""" + body_items: List[str] = [] + for i in range(100): + # Create a chunk of 0 padded bytes, followed by a space. + chunk_payload = f'{i:06x} '.encode("utf-8") + body_items.append(chunk_payload) + return b''.join(body_items) + + +class RequestInfo: + """POD for request headers, etc.""" + + def __init__(self, stream_id: int, headers: Dict[str, str], body: str): + self.stream_id: int = stream_id + self.headers: Dict[str, str] = headers + self.body_bytes: str = body + + +class ResponseInfo: + """POD for response headers, etc.""" + + def __init__( + self, + status: int, + headers: Dict[bytes, bytes], + body: bytes, + trailers: Dict[bytes, bytes] = None, + errors: List[str] = None): + self.status_code: int = status + self.headers: Dict[bytes, bytes] = headers + self.body_bytes: bytes = body + self.trailers: Dict[bytes, bytes] = trailers + self.errors: List[str] = errors + + +def print_transaction(request: RequestInfo, response: ResponseInfo) -> None: + """Print a description of the transaction. + + :param request: The details about the request. + :param response: The details about the response. + """ + + description = "\n==== REQUEST HEADERS ====\n" + for k, v in request.headers.items(): + if isinstance(k, bytes): + k, v = (k.decode('ascii'), v.decode('ascii')) + description += f"{k}: {v}\n" + + if request.body_bytes is not None: + description += f"\n==== REQUEST BODY ====\n{request.body_bytes}\n" + + description += "\n==== RESPONSE ====\n" + description += f"{response.status_code}\n" + + description += "\n==== RESPONSE HEADERS ====\n" + for k, v in response.headers: + if isinstance(k, bytes): + k, v = (k.decode('ascii'), v.decode('ascii')) + description += f"{k}: {v}\n" + + if response.body_bytes is not None: + description += f"\n==== RESPONSE BODY ====\n{response.body_bytes.decode()}\n" + + if response.trailers is not None: + description += "\n==== RESPONSE TRAILERS ====\n" + for k, v in response.trailers.items(): + if isinstance(k, bytes): + k, v = (k.decode('ascii'), v.decode('ascii')) + description += "{k}: {v}\n" + + description += "\n==== END ====\n" + + logging.info(description) + + +class Http2Connection: + ''' + This class manages a single HTTP/2 connection to a server. It is not + thread-safe. For our purpose though, no lock is neccessary as the streams of + each connection are processed sequentially. + ''' + + def __init__(self, sock, h2conn): + self.sock = sock + self.conn = h2conn + + def send_request(self, request: RequestInfo) -> Tuple[ResponseInfo, List[int]]: + ''' + Sends a request to the h2 connection and returns the response object containing the headers, body, and possible errors. + ''' + self.conn.send_headers(request.stream_id, request.headers.items()) + logging.info(f'Sent headers.') + # Send the data over the socket. + self.sock.sendall(self.conn.data_to_send()) + response_headers_raw = None + response_body = b'' + response_stream_ended = False + request_stream_ended = False + trailers = None + errors = [] + bytes_sent = 0 + bytes_left = len(request.body_bytes) + data_frame_differentials: List[int] = [] + time_of_last_frame = time.perf_counter_ns() + while not response_stream_ended: + + send_window = self.conn.local_flow_control_window(request.stream_id) + bytes_to_send = min(send_window, bytes_left) + # Send one byte at a time, every millisecond. + while bytes_to_send > 0: + chunk_size = 1 + byte_to_send = request.body_bytes[bytes_sent:bytes_sent + chunk_size] + logging.debug(f'Sending {byte_to_send}') + self.conn.send_data(request.stream_id, byte_to_send) + self.sock.sendall(self.conn.data_to_send()) + bytes_left -= chunk_size + bytes_sent += chunk_size + bytes_to_send -= chunk_size + time.sleep(0.001) + + if not request_stream_ended and bytes_left == 0: + logging.debug('Closing the connection') + self.conn.end_stream(request.stream_id) + request_stream_ended = True + + logging.debug('Reading any responses from the socket') + data = self.sock.recv(65536 * 1024) + if not data: + break + + # Feed raw data into h2 engine, and process resulting events. + logging.debug('Feeding the data into the connection') + events = self.conn.receive_data(data) + have_counted_data_delay = False + for event in events: + if isinstance(event, ResponseReceived): + # Received response headers. + response_headers_raw = event.headers + time_of_last_frame = time.perf_counter_ns() + logging.info('Received response headers.') + if isinstance(event, DataReceived): + # Update flow control so the server doesn't starve us. + self.conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id) + # Received more response body data. + response_body += event.data + current_time = time.perf_counter_ns() + ms_delay = (current_time - time_of_last_frame) / (1000 * 1000) + if not have_counted_data_delay: + data_frame_differentials.append(ms_delay) + time_of_last_frame = time.perf_counter_ns() + logging.debug(f"Received {len(event.data)} bytes of data after {ms_delay} ms") + have_counted_data_delay = True + if isinstance(event, TrailersReceived): + # Received trailer headers. + trailers = event.headers + if isinstance(event, StreamReset): + # Stream reset by the server. + logging.debug(f"Received RST_STREAM from the server: {event}") + errors.append('StreamReset') + response_stream_ended = True + break + if isinstance(event, ConnectionTerminated): + # Received GOAWAY frame from the server. + logging.debug(f"Received GOAWAY from the server: {event}") + errors.append('ConnectionTerminated') + response_stream_ended = True + break + if isinstance(event, StreamEnded): + # Received complete response body. + logging.info('Received stream end.') + response_stream_ended = True + break + + if not errors: + # Send any pending data to the server. + self.sock.sendall(self.conn.data_to_send()) + + # Decode the header fields. + status_code = next((t[1] for t in response_headers_raw if t[0].decode() == ':status'), None) + status_code = int(status_code) if status_code else 0 + return ResponseInfo(status_code, response_headers_raw, response_body, trailers, errors), data_frame_differentials + + def close(self): + """Tell the server we are closing the h2 connection.""" + self.conn.close_connection() + self.sock.sendall(self.conn.data_to_send()) + self.sock.close() + + +def create_ssl_context(cert): + """ + Create a SSL context with the given cert file. + """ + ctx = ssl.create_default_context() + ctx.set_alpn_protocols(['h2', 'http/1.1']) + # Load the cert file + ctx.load_cert_chain(cert) + # Do not verify the server's certificate + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + return ctx + + +def send_http2_request_to_server(hostname: str, port: int, cert_file: str, write_timeout: int) -> int: + """Establish a connection with the server and send a request. + + :param hostname: The hostname to use for the :authority header. + :param port: The port to connect to. + :param cert_file: The TLS certificate file. + :param write_timeout: The expected maximum amount of time frames should be delivered. + + :return: 0 if the request was successful, 1 otherwise. + """ + + request_headers = HttpHeaders() + request_headers.add_header(':method', 'GET') + request_headers.add_header(':path', '/some/path') + request_headers.add_header(':authority', hostname) + request_headers.add_header(':scheme', 'https') + + scheme = request_headers[':scheme'] + replay_server = f"127.0.0.1:{port}" + path = request_headers[':path'] + authority = request_headers.get(':authority', '') + + stream_id = 1 + body = get_body_text() + request: RequestInfo = RequestInfo(stream_id, request_headers, body) + + try: + # Open a socket to the server and initiate TLS/SSL. + ssl_context = create_ssl_context(cert=cert_file) + setattr(ssl_context, "old_wrap_socket", ssl_context.wrap_socket) + + def new_wrap_socket(sock, *args, **kwargs): + # Make the SNI line up with the :authority header value. + kwargs['server_hostname'] = hostname + return ssl_context.old_wrap_socket(sock, *args, **kwargs) + + setattr(ssl_context, "wrap_socket", new_wrap_socket) + # Opens a connection to the server. + logging.info(f"Connecting to '{scheme}://{replay_server}' with request to '{authority}{path}'") + sock = socket.create_connection(('127.0.0.1', port)) + sock = ssl_context.wrap_socket(sock) + + # Initiate an HTTP/2 connection. + http2_connection = H2Connection() + http2_connection.initiate_connection() + # Initial SETTINGS frame, etc. + sock.sendall(http2_connection.data_to_send()) + client = Http2Connection(sock, http2_connection) + response, data_delays = client.send_request(request) + if response.errors: + try: + if 'StreamReset' in response.errors: + http2_connection.reset_stream(stream_id) + if 'ConnectionTerminated' in response.errors: + http2_connection.close_connection(last_stream_id=0) + except StreamClosedError as err: + logging.error(err) + return 1 + else: + client.close() + except Exception as e: + logging.error(f"Connection to '{replay_server}' initiated with request to " + f"'{scheme}://{authority}{path}' failed: {e}") + traceback.print_exc(file=sys.stdout) + return 1 + + print_transaction(request, response) + logging.info(f'Smallest delay: {min(data_delays)} ms') + logging.info(f'Largest delay: {max(data_delays)} ms') + average = statistics.mean(data_delays) + logging.info(f'Average delay over {len(data_delays)} reads: {average} ms') + isclose = math.isclose(average, write_timeout, rel_tol=0.2) + if isclose: + logging.info(f'Average delay of {average} is within 20% of the expected delay: {write_timeout} ms') + else: + logging.info(f'Average delay of {average} is not within 20% of the expected delay: {write_timeout} ms') + + return 0 if isclose else 1 + + +def parse_args() -> argparse.Namespace: + """Parse the command line arguments. + + :return: The parsed arguments. + """ + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument('hostname', type=str, help='The hostname to use for the :authority header.') + parser.add_argument('port', type=int, help='The port to connect to.') + parser.add_argument('cert', type=str, help='The TLS certificate file.') + parser.add_argument('write_timeout', type=int, help='The expected maximum amount of time frames should be delivered.') + parser.add_argument('-v', '--verbose', action='store_true', help='Enable verbose logging.') + return parser.parse_args() + + +def main() -> int: + """Start the client.""" + args = parse_args() + + log_level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig(level=log_level, format='%(asctime)s - %(levelname)s - %(message)s') + + return send_http2_request_to_server(args.hostname, args.port, args.cert, args.write_timeout) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/tests/gold_tests/h2/trickle_server.py b/tests/gold_tests/h2/trickle_server.py new file mode 100644 index 00000000000..1ccfa8fc1b8 --- /dev/null +++ b/tests/gold_tests/h2/trickle_server.py @@ -0,0 +1,407 @@ +''' +Implement an HTTP/2 server that monitors DATA frame statistics. +''' +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import math +import statistics +import sys +import time +from OpenSSL.SSL import Error as SSLError +from OpenSSL.SSL import SysCallError as SSLSysCallError +import threading + +import eventlet +from eventlet.green.OpenSSL import SSL, crypto +from h2.config import H2Configuration +from h2.connection import H2Connection +from h2.events import StreamEnded, RequestReceived, DataReceived, StreamReset, ConnectionTerminated +from h2.errors import ErrorCodes as H2ErrorCodes +from h2.exceptions import StreamClosedError, StreamIDTooLowError + +from typing import Dict, List, Optional, Set + + +def get_body_text() -> bytes: + """Create a body of text to send in the response.""" + body_items: List[str] = [] + for i in range(100): + # Create a chunk of 0 padded bytes, followed by a space. + chunk_payload = f'{i:06x} '.encode("utf-8") + body_items.append(chunk_payload) + return b''.join(body_items) + + +class RequestInfo: + """POD for request headers, etc.""" + + def __init__(self, stream_id: int): + self.stream_id: int = stream_id + self.headers: Dict[bytes, bytes] = None + self.body_bytes: bytes = None + + +class ResponseInfo: + """POD for response headers, etc.""" + + def __init__( + self, + status: int, + headers: Dict[bytes, bytes], + body: bytes, + trailers: Dict[bytes, bytes] = None, + errors: List[str] = None): + self.status_code: int = status + self.headers: Dict[bytes, bytes] = headers + self.body_bytes: bytes = body + self.trailers: Dict[bytes, bytes] = trailers + self.errors: List[str] = errors + + +class Http2ConnectionManager: + """Manages a single HTTP/2 connection.""" + + def __init__(self, sock: eventlet.greenio.GreenSocket): + listening_config = H2Configuration(client_side=False, validate_inbound_headers=False) + self.tls = threading.local() + self.sock = sock + self.sock.settimeout(1.0) + self.listening_conn: H2Connection = H2Connection(config=listening_config) + self.requests: Dict[int, RequestInfo] = {} + self.completed_stream_ids: Set[int] = set() + + # Delay times in ms between each data frame. + self._data_delays: List[int] = [] + # The last time in ms since epoch that a packet was received. + self.last_packet_time: int = 0 + + def _send_responses(self, responses: Dict[int, ResponseInfo]) -> None: + """Send any responses that have been generated. + + :param responses: A dictionary of responses we wish to send. + """ + responded_streams = [] + for stream_id, response in responses.items(): + try: + self.listening_conn.send_headers(stream_id, response.headers) + + send_window = self.listening_conn.local_flow_control_window(stream_id) + body_size = len(response.body_bytes) + bytes_to_send = min(send_window, body_size) + if bytes_to_send < body_size: + raise ValueError( + f'We do not have a big enough window: body size of {body_size} bytes vs {send_window} byte window') + # Send one byte at a time, every millisecond. + bytes_sent = 0 + while bytes_to_send > 0: + chunk_size = 1 + byte_to_send = response.body_bytes[bytes_sent:bytes_sent + chunk_size] + logging.debug(f'Sending {byte_to_send}') + self.listening_conn.send_data(stream_id, byte_to_send) + self.sock.sendall(self.listening_conn.data_to_send()) + bytes_sent += chunk_size + bytes_to_send -= chunk_size + time.sleep(0.001) + + self.listening_conn.send_data(stream_id, response.body_bytes, end_stream=False if response.trailers else True) + if response.trailers is not None: + self.listening_conn.send_headers(stream_id, response.trailers, end_stream=True) + responded_streams.append(stream_id) + + except StreamClosedError as e: + logging.debug(e) + except StreamIDTooLowError as e: + logging.debug(e) + try: + # Send the responses we added to the listening_conn. + self.sock.sendall(self.listening_conn.data_to_send()) + except (SSLError, SSLSysCallError) as e: + logging.debug(f'Ignoring sock.sendall exception for now: {e}') + + # Clean up any responses we sent. + for stream_id in responded_streams: + del responses[stream_id] + + def _receive_data(self, responses: Dict[int, ResponseInfo]) -> Optional[bytes]: + """Receive data from the socket. + + :param responses: A dictionary of stream IDs to responses that have accumulated. + + :return: The data received, or None if the connection for the socket has closed. + """ + data: Optional[bytes] = None + while not data: + try: + logging.debug('Receiving data on the socket.') + data = self.sock.recv(65535) + except SSLError: + logging.debug('recv error: the socket is closed.') + return None + except TimeoutError: + # Take time to send any responses we've generated. + self._send_responses(responses) + + # Loop back around to receive more data. + logging.debug('Timeout, waiting again for more data.') + continue + return data + + def _process_events(self, events: List, responses: Dict[int, ResponseInfo]) -> None: + """Process events from the H2 connection. + + :param events: The events to process. + :param responses: A dictionary of stream IDs to responses that have accumulated. + """ + have_counted_data_delay = False + for event in events: + if hasattr(event, 'stream_id'): + stream_id = event.stream_id + if stream_id not in self.requests: + self.requests[stream_id] = RequestInfo(stream_id) + + request_info = self.requests[stream_id] + + if isinstance(event, DataReceived): + if request_info.body_bytes is None: + request_info.body_bytes = b'' + logging.debug(f'Got data for stream {stream_id}: {event.data.decode()}') + request_info.body_bytes += event.data + + if not have_counted_data_delay: + ms_since_last_packet = (time.perf_counter_ns() - self.last_packet_time) / (1000 * 1000) + logging.debug(f'Counting data delay for stream {stream_id}: {ms_since_last_packet} ms') + self._data_delays.append(ms_since_last_packet) + self.last_packet_time = time.perf_counter_ns() + have_counted_data_delay = True + + if isinstance(event, RequestReceived): + logging.info(f'Incoming request received for stream {event.stream_id}.') + logging.debug(f'Headers received: {event.headers}') + request_info.headers = event.headers + self.last_packet_time = time.perf_counter_ns() + + if isinstance(event, StreamReset): + self.completed_stream_ids.add(stream_id) + err = H2ErrorCodes(event.error_code).name + logging.debug(f'Received RST_STREAM frame with error code {err} on stream {event.stream_id}.') + if stream_id not in responses.keys(): + response = self._process_request(request_info) + if response is not None: + responses[stream_id] = response + + if isinstance(event, StreamEnded): + logging.debug('StreamEnded') + self.completed_stream_ids.add(stream_id) + if stream_id not in responses.keys(): + response = self._process_request(request_info) + if response is not None: + responses[stream_id] = response + + else: + if isinstance(event, ConnectionTerminated): + err = H2ErrorCodes(event.error_code).name + logging.debug(f'Received GOAWAY frame with error code {err} on with last stream id {event.last_stream_id}.') + self.listening_conn.close_connection() + + def _cleanup_closed_stream_ids(self) -> None: + """Clean up any closed streams.""" + for stream_id in self.completed_stream_ids: + try: + if self.listening_conn.streams[stream_id].closed: + del self.requests[stream_id] + except KeyError: + pass + try: + self.completed_stream_ids = set([id for id in self.completed_stream_ids if not self.listening_conn.streams[id].closed]) + except KeyError: + pass + + def run_forever(self): + self.listening_conn.initiate_connection() + + try: + self.sock.sendall(self.listening_conn.data_to_send()) + except (SSLError, SSLSysCallError) as e: + logging.debug(f'Initial sock.sendall exception: {e}') + return + + responses: Dict[int, ResponseInfo] = {} + while True: + data = self._receive_data(responses) + if not data: + # Connection ended. + break + + logging.debug(f'Giving {len(data)} bytes to the h2 connection') + events = self.listening_conn.receive_data(data) + self._process_events(events, responses) + self._cleanup_closed_stream_ids() + self._send_responses(responses) + + logging.debug('Sending data on the socket') + try: + self.sock.sendall(self.listening_conn.data_to_send()) + except (SSLError, SSLSysCallError) as e: + logging.debug(f'Ignoring end-loop sock.sendall exception for now: {e}') + pass + + def get_data_delays(self) -> List[int]: + """Get the DATA frame timing list. + + :return: The list of DATA frame timings. + """ + return self._data_delays + + def _process_request(self, request: RequestInfo) -> ResponseInfo: + """Handle a request from a client. + + :return: A response to send back to the client. + """ + logging.debug(f'Request received for stream id: {request.stream_id}') + response_headers = [ + (':status'.encode(), '200'.encode()), + ('content-type'.encode(), 'text/plain'.encode()), + ] + response = ResponseInfo(200, response_headers, get_body_text()) + + self._print_transaction(request, response) + return response + + def _print_transaction(self, request: RequestInfo, response: ResponseInfo) -> None: + """Print the details of the request and response.""" + + description = '' + description += '\n==== REQUEST HEADERS ====\n' + for k, v in request.headers: + if isinstance(k, bytes): + k, v = (k.decode('ascii'), v.decode('ascii')) + description += f"{k}: {v}\n" + + if request.body_bytes is not None: + description += f"\n==== REQUEST BODY ====\n{request.body_bytes.decode()}\n" + + description += "\n==== RESPONSE ====\n" + description += f"{response.status_code}\n" + + description += "\n==== RESPONSE HEADERS ====\n" + for k, v in response.headers: + if isinstance(k, bytes): + k, v = (k.decode('ascii'), v.decode('ascii')) + description += f"{k}: {v}\n" + + if response.body_bytes is not None: + description += f"\n==== RESPONSE BODY ====\n{response.body_bytes.decode()}\n" + + if response.trailers is not None: + description += "\n==== RESPONSE TRAILERS ====\n" + for k, v in response.trailers: + if isinstance(k, bytes): + k, v = (k.decode('ascii'), v.decode('ascii')) + description += f"{k}: {k}\n" + description += "\n==== END ====\n" + logging.info(description) + + +def alpn_callback(conn, protos): + """The OpenSSL callback for selecting the protocol.""" + if b'h2' in protos: + return b'h2' + + raise RuntimeError("No acceptable protocol offered!") + + +def servername_callback(conn): + """The OpenSSL callback for inspecting the SNI.""" + sni = conn.get_servername() + conn.set_app_data({'sni': sni}) + logging.info(f"Got SNI from client: {sni}") + + +def run_server(listen_port, https_pem, ca_pem) -> List[int]: + """Run the HTTP/2 server. + + :param listen_port: The port to listen on. + :param https_pem: The path to the certificate key. + :param ca_pem: The path to the CA certificate. + + :return: The list of DATA frame delays. + """ + options = (SSL.OP_NO_COMPRESSION | SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_NO_TLSv1 | SSL.OP_NO_TLSv1_1) + context = SSL.Context(SSL.TLSv1_2_METHOD) + context.set_options(options) + context.set_verify(SSL.VERIFY_NONE, lambda *args: True) + context.use_privatekey_file(https_pem) + context.use_certificate_file(https_pem) + context.set_alpn_select_callback(alpn_callback) + context.set_tlsext_servername_callback(servername_callback) + context.set_cipher_list("RSA+AESGCM".encode()) + context.set_tmp_ecdh(crypto.get_elliptic_curve('prime256v1')) + + listening_socket = eventlet.listen(('0.0.0.0', listen_port)) + listening_socket = SSL.Connection(context, listening_socket) + logging.info(f"Serving HTTP/2 Proxy on 127.0.0.1:{listen_port} with pem '{https_pem}'") + pool = eventlet.GreenPool() + + while True: + try: + new_connection_socket, _ = listening_socket.accept() + manager = Http2ConnectionManager(new_connection_socket) + manager.cert_file = https_pem + manager.ca_file = ca_pem + pool.spawn_n(manager.run_forever) + except KeyboardInterrupt as e: + logging.debug("Handling KeyboardInterrupt") + return manager.get_data_delays() + except SystemExit: + break + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument('listen_port', type=int, help='Port to listen on.') + parser.add_argument('cert_key', type=str, help='Path to the certificate key.') + parser.add_argument('ca_cert', type=str, help='Path to the CA certificate.') + parser.add_argument('write_timeout', type=int, help='The timeout between sending frames.') + parser.add_argument('-v', '--verbose', action='store_true', help='Enable verbose logging.') + return parser.parse_args() + + +def main() -> int: + """Start the HTTP/2 server.""" + args = parse_args() + + log_level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig(level=log_level, format='%(asctime)s - %(levelname)s - %(message)s') + + data_delays = run_server(args.listen_port, args.cert_key, args.ca_cert) + logging.info(f'Smallest delay: {min(data_delays)} ms') + logging.info(f'Largest delay: {max(data_delays)} ms') + average = statistics.mean(data_delays) + logging.info(f'Average delay over {len(data_delays)} reads: {average} ms') + isclose = math.isclose(average, args.write_timeout, rel_tol=0.2) + if isclose: + logging.info(f'Average delay of {average} is within 20% of the expected delay: {args.write_timeout} ms') + else: + logging.info(f'Average delay of {average} is not within 20% of the expected delay: {args.write_timeout} ms') + return 0 if isclose else 1 + + +if __name__ == '__main__': + sys.exit(main())