diff --git a/proxy/http/HttpSM.cc b/proxy/http/HttpSM.cc index 11b84888f98..5fcf0d2cf40 100644 --- a/proxy/http/HttpSM.cc +++ b/proxy/http/HttpSM.cc @@ -5814,7 +5814,7 @@ HttpSM::handle_server_setup_error(int event, void *data) ink_release_assert(0); } - if (event == VC_EVENT_INACTIVITY_TIMEOUT || event == VC_EVENT_ERROR) { + if (event == VC_EVENT_INACTIVITY_TIMEOUT || event == VC_EVENT_ERROR || event == VC_EVENT_EOS) { // Clean up the vc_table entry so any events in play to the timed out server vio // don't get handled. The connection isn't there. if (server_entry) { diff --git a/tests/gold_tests/slow_post/__init__.py b/tests/gold_tests/slow_post/__init__.py new file mode 100644 index 00000000000..13db3e8393d --- /dev/null +++ b/tests/gold_tests/slow_post/__init__.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/gold_tests/slow_post/http_utils.py b/tests/gold_tests/slow_post/http_utils.py new file mode 100644 index 00000000000..b4ffa76da63 --- /dev/null +++ b/tests/gold_tests/slow_post/http_utils.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +"""Common logic between the ad hoc client and server.""" + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import socket + + +def wait_for_headers_complete(sock: socket.socket) -> bytes: + """Wait for the headers to be complete. + + :param sock: The socket to read from. + :returns: The bytes read off the socket. + """ + headers = b"" + while True: + data = sock.recv(1024) + if not data: + print("Socket closed.") + break + print(f'Received:\n{data}') + headers += data + if b"\r\n\r\n" in headers: + break + return headers + + +def determine_outstanding_bytes_to_read(read_bytes: bytes) -> int: + """Determine how many more bytes to read from the headers. + + This parses the Content-Length header to determine how many more bytes to + read. + + :param read_bytes: The bytes read so far. + :returns: The number of bytes to read, or -1 if it is chunked encoded. + """ + headers = read_bytes.decode("utf-8").split("\r\n") + content_length_value = None + for header in headers: + if header.lower().startswith("content-length:"): + content_length_value = int(header.split(":")[1].strip()) + elif header.lower().startswith("transfer-encoding: chunked"): + return -1 + if content_length_value is None: + raise ValueError("No Content-Length header found.") + + end_of_headers = read_bytes.find(b"\r\n\r\n") + if end_of_headers == -1: + raise ValueError("No end of headers found.") + + end_of_headers += 4 + return content_length_value - (len(read_bytes) - end_of_headers) + + +def drain_socket( + sock: socket.socket, + previously_read_data: bytes, + num_bytes_to_drain: int) -> None: + """Read the rest of the request. + + :param sock: The socket to drain. + :param num_bytes_to_drain: The number of bytes to drain. If -1, then drain + bytes until the final zero-length chunk is read. + """ + + read_data = previously_read_data + num_bytes_drained = 0 + while True: + if num_bytes_to_drain > 0: + if num_bytes_drained >= num_bytes_to_drain: + break + elif b'0\r\n\r\n' == read_data[-5:]: + print("Found end of chunked data.") + break + + data = sock.recv(1024) + print(f'Received:\n{data}') + if not data: + print("Socket closed.") + break + num_bytes_drained += len(data) + read_data += data diff --git a/tests/gold_tests/slow_post/quick_server.py b/tests/gold_tests/slow_post/quick_server.py new file mode 100644 index 00000000000..7a5140a0b2e --- /dev/null +++ b/tests/gold_tests/slow_post/quick_server.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +"""A server that replies without waiting for the entire request.""" + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from http_utils import (wait_for_headers_complete, + determine_outstanding_bytes_to_read, + drain_socket) + +import argparse +import socket +import sys + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "address", + help="Address to listen on") + parser.add_argument( + "port", + type=int, + default=8080, + help="The port to listen on") + parser.add_argument( + '--drain-request', + action='store_true', + help="Drain the entire request before closing the connection") + parser.add_argument( + '--abort-response-headers', + action='store_true', + help="Abort the response in the midst of sending the response headers") + return parser.parse_args() + + +def get_listening_socket(address: str, port: int) -> socket.socket: + """Create a listening socket. + + :param address: The address to listen on. + :param port: The port to listen on. + :returns: A listening socket. + """ + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((address, port)) + sock.listen(1) + return sock + + +def accept_connection(sock: socket.socket) -> socket.socket: + """Accept a connection. + + :param sock: The socket to accept a connection on. + :returns: The accepted socket. + """ + return sock.accept()[0] + + +def send_response(sock: socket.socket, abort_early: bool) -> None: + """Send an HTTP response. + + :param sock: The socket to write to. + :param abort_early: If True, abort the response before sending the body. + """ + if abort_early: + response = "HTTP/1." + else: + response = ( + r"HTTP/1.1 200 OK\r\n" + r"Content-Length: 0\r\n" + r"\r\n" + ) + print(f'Sending:\n{response}') + sock.sendall(response.encode("utf-8")) + + +def main() -> int: + """Run the server.""" + args = parse_args() + + # Configure a listening socket on args.address and args.port. + with get_listening_socket(args.address, args.port) as listening_sock: + print(f"Listening on {args.address}:{args.port}") + + read_bytes: bytes = b"" + while len(read_bytes) == 0: + with accept_connection(listening_sock) as sock: + read_bytes = wait_for_headers_complete(sock) + if len(read_bytes) == 0: + # This is probably the PortOpenv4 test. Try again. + print("No bytes read on this connection. Trying again.") + sock.close() + continue + + # Send a response now, before headers are read. This implements + # the "quick" attribute of this quick_server. + send_response(sock, args.abort_response_headers) + + if args.abort_response_headers: + # We're done. + break + + if args.drain_request: + num_bytes_to_drain = determine_outstanding_bytes_to_read( + read_bytes) + print(f'Read {len(read_bytes)} bytes. ' + f'Draining {num_bytes_to_drain} bytes.') + drain_socket(sock, read_bytes, num_bytes_to_drain) + else: + print(f'Read {len(read_bytes)} bytes. ' + f'Not draining per configuration.') + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/gold_tests/slow_post/quick_server.test.py b/tests/gold_tests/slow_post/quick_server.test.py new file mode 100644 index 00000000000..763abce385d --- /dev/null +++ b/tests/gold_tests/slow_post/quick_server.test.py @@ -0,0 +1,132 @@ +"""Verify ATS handles a server that replies before receiving a full request.""" + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ports import get_port +import sys + + +Test.Summary = __doc__ + + +class QuickServerTest: + """Verify that ATS doesn't delay responses behind slow posts.""" + + _init_file = '__init__.py' + _http_utils = 'http_utils.py' + _slow_post_client = 'slow_post_client.py' + _quick_server = 'quick_server.py' + + _dns_counter = 0 + _server_counter = 0 + _ts_counter = 0 + + def __init__(self, abort_request: bool, drain_request: bool, abort_response_headers: bool): + """Initialize the test. + + :param drain_request: Whether the server should drain the request body. + :param abort_request: Whether the client should abort the request body. + before disconnecting. + """ + self._should_drain_request = drain_request + self._should_abort_request = abort_request + self._should_abort_response_headers = abort_response_headers + + def _configure_dns(self, tr: 'TestRun') -> None: + """Configure the DNS. + + :param tr: The test run to associate with the DNS process with. + """ + self._dns = tr.MakeDNServer( + f'dns-{QuickServerTest._dns_counter}', + default='127.0.0.1') + QuickServerTest._dns_counter += 1 + + def _configure_server(self, tr: 'TestRun'): + """Configure the origin server. + + This server replies with a response immediately after receiving the + request headers. + + :param tr: The test run to associate with the server process with. + """ + server = tr.Processes.Process(f'server-{QuickServerTest._server_counter}') + QuickServerTest._server_counter += 1 + port = get_port(server, "http_port") + server.Command = \ + f'{sys.executable} {self._quick_server} 127.0.0.1 {port} ' + if self._should_drain_request: + server.Command += '--drain-request ' + if self._should_abort_response_headers: + server.Command += '--abort-response-headers ' + server.Ready = When.PortOpenv4(port) + self._server = server + + def _configure_traffic_server(self, tr: 'TestRun'): + """Configure ATS. + + :param tr: The test run to associate with the ATS process with. + """ + self._ts = tr.MakeATSProcess(f'ts-{QuickServerTest._ts_counter}') + QuickServerTest._ts_counter += 1 + self._ts.Disk.remap_config.AddLine( + f'map / http://quick.server.com:{self._server.Variables.http_port}' + ) + self._ts.Disk.records_config.update({ + 'proxy.config.diags.debug.enabled': 1, + 'proxy.config.diags.debug.tags': 'http|dns|hostdb', + 'proxy.config.dns.nameservers': f'127.0.0.1:{self._dns.Variables.Port}', + 'proxy.config.dns.resolv_conf': 'NULL', + }) + + def run(self): + """Run the test.""" + tr = Test.AddTestRun() + + self._configure_dns(tr) + self._configure_server(tr) + self._configure_traffic_server(tr) + + tr.Setup.CopyAs(self._init_file, Test.RunDirectory) + tr.Setup.CopyAs(self._http_utils, Test.RunDirectory) + tr.Setup.CopyAs(self._slow_post_client, Test.RunDirectory) + tr.Setup.CopyAs(self._quick_server, Test.RunDirectory) + + client_command = ( + f'{sys.executable} {self._slow_post_client} ' + '127.0.0.1 ' + f'{self._ts.Variables.port} ' + ) + if not self._should_abort_request: + client_command += '--finish-request ' + tr.Processes.Default.Command = client_command + + tr.Processes.Default.ReturnCode = 0 + self._ts.StartBefore(self._dns) + self._ts.StartBefore(self._server) + tr.Processes.Default.StartBefore(self._ts) + tr.Timeout = 10 + + +for abort_request in [True, False]: + for drain_request in [True, False]: + for abort_response_headers in [True, False]: + test = QuickServerTest( + abort_request, + drain_request, + abort_response_headers) + test.run() diff --git a/tests/gold_tests/slow_post/slow_post.test.py b/tests/gold_tests/slow_post/slow_post.test.py index b2c4262dbc1..4e04f936a73 100644 --- a/tests/gold_tests/slow_post/slow_post.test.py +++ b/tests/gold_tests/slow_post/slow_post.test.py @@ -27,7 +27,7 @@ class SlowPostAttack: def __init__(cls): Test.Summary = 'Test how ATS handles the slow-post attack' cls._origin_max_connections = 3 - cls._slow_post_client = 'slow_post_client.py' + cls._slow_post_client = 'slow_post_clients.py' cls.setupOriginServer() cls.setupTS() cls._ts.Setup.CopyAs(cls._slow_post_client, Test.RunDirectory) diff --git a/tests/gold_tests/slow_post/slow_post_client.py b/tests/gold_tests/slow_post/slow_post_client.py index f84796e05f7..4340314a494 100644 --- a/tests/gold_tests/slow_post/slow_post_client.py +++ b/tests/gold_tests/slow_post/slow_post_client.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 +"""Implements a client which slowly POSTs a request.""" -''' -''' # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -18,44 +17,128 @@ # See the License for the specific language governing permissions and # limitations under the License. -import time -import threading -import requests +from http_utils import (wait_for_headers_complete, + determine_outstanding_bytes_to_read, + drain_socket) + import argparse +import socket +import sys + +def parse_args() -> argparse.Namespace: + """Parse the command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "proxy_address", + help="Address of the proxy to connect to.") + parser.add_argument( + "proxy_port", + type=int, + help="The port of the proxy to connect to.") + parser.add_argument( + '-s', '--server-hostname', + dest="server_hostname", + default="some.server.com", + help="The hostname of the server to connect to.") + parser.add_argument( + "-t", "--send_time", + dest="send_time", + type=int, + default=3, + help="The number of seconds to send the POST.") + parser.add_argument( + '--finish-request', + dest="finish_request", + action='store_true', + help="Finish sending the request before closing the connection.") -def gen(slow_time): - for _ in range(slow_time): - yield b'a' - time.sleep(1) + return parser.parse_args() -def slow_post(port, slow_time): - requests.post('http://127.0.0.1:{0}/'.format(port, ), data=gen(slow_time)) +def open_connection(address: str, port: int) -> socket.socket: + """Open a connection to the desired host. + :param address: The address of the host to connect to. + :param port: The port of the host to connect to. + :return: The socket connected to the host. + """ + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect((address, port)) + print(f'Connected to {address}:{port}') + return sock -def makerequest(port, connection_limit): - client_timeout = 3 - for _ in range(connection_limit): - t = threading.Thread(target=slow_post, args=(port, client_timeout + 10)) - t.daemon = True - t.start() - time.sleep(1) - r = requests.get('http://127.0.0.1:{0}/'.format(port,)) - print(r.status_code) +def send_slow_post( + sock: socket.socket, + server_hostname: str, + send_time: int, + finish_request: bool) -> None: + """Send a slow POST request. -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--port", "-p", - type=int, - help="Port to use") - parser.add_argument("--connectionlimit", "-c", - type=int, - help="connection limit") - args = parser.parse_args() - makerequest(args.port, args.connectionlimit) + :param sock: The socket to send the request on. + :param server_hostname: The hostname of the server to connect to. + :param send_time: The number of seconds to send the request. + :param finish_request: Whether to finish sending the request before closing + the connection. + """ + # Send the POST request. + host_header = f'Host: {server_hostname}\r\n'.encode() + request = ( + b"POST / HTTP/1.1\r\n" + + host_header + + b"Transfer-Encoding: chunked\r\n" + b"\r\n") + sock.sendall(request) + print('Sent request headers:') + print(request.decode()) + + print(f'Sending POST body for {send_time} seconds.') + counter = 0 + while counter < send_time: + # Send zero padded hex string of the counter. + chunk = f'8\r\n{counter:08x}\r\n'.encode() + sock.sendall(chunk) + print(f'Sent chunk: {chunk.decode()}') + counter += 1 + + if finish_request: + # Send the last chunk. + sock.sendall(b'0\r\n\r\n') + else: + print('Aborting the request before sending the last chunk.') + sock.close() + + +def drain_response(sock: socket.socket) -> None: + """Drain the response from the server. + + :param sock: The socket to read the response from. + """ + print('Waiting for the response to complete.') + read_bytes = wait_for_headers_complete(sock) + num_bytes_to_drain = determine_outstanding_bytes_to_read(read_bytes) + drain_socket(sock, read_bytes, num_bytes_to_drain) + print('Response complete.') + + +def main() -> int: + """Run the client.""" + args = parse_args() + print(args) + + with open_connection(args.proxy_address, args.proxy_port) as sock: + send_slow_post( + sock, + args.server_hostname, + args.send_time, + args.finish_request) + + if args.finish_request: + drain_response(sock) + print('Done.') + return 0 if __name__ == '__main__': - main() + sys.exit(main()) diff --git a/tests/gold_tests/slow_post/slow_post_clients.py b/tests/gold_tests/slow_post/slow_post_clients.py new file mode 100644 index 00000000000..f84796e05f7 --- /dev/null +++ b/tests/gold_tests/slow_post/slow_post_clients.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 + +''' +''' +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import threading +import requests +import argparse + + +def gen(slow_time): + for _ in range(slow_time): + yield b'a' + time.sleep(1) + + +def slow_post(port, slow_time): + requests.post('http://127.0.0.1:{0}/'.format(port, ), data=gen(slow_time)) + + +def makerequest(port, connection_limit): + client_timeout = 3 + for _ in range(connection_limit): + t = threading.Thread(target=slow_post, args=(port, client_timeout + 10)) + t.daemon = True + t.start() + time.sleep(1) + r = requests.get('http://127.0.0.1:{0}/'.format(port,)) + print(r.status_code) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--port", "-p", + type=int, + help="Port to use") + parser.add_argument("--connectionlimit", "-c", + type=int, + help="connection limit") + args = parser.parse_args() + makerequest(args.port, args.connectionlimit) + + +if __name__ == '__main__': + main()