diff --git a/iocore/net/ProxyProtocol.cc b/iocore/net/ProxyProtocol.cc index 0b993557c42..c18f563b886 100644 --- a/iocore/net/ProxyProtocol.cc +++ b/iocore/net/ProxyProtocol.cc @@ -356,6 +356,7 @@ proxy_protocol_v1_build(uint8_t *buf, size_t max_buf_len, const ProxyProtocol &p bw.fill(len); } + Debug("proxyprotocol_v1", "Proxy Protocol v1: %.*s", static_cast(bw.size()), bw.data()); bw.write("\r\n"); return bw.size(); @@ -441,6 +442,7 @@ proxy_protocol_v2_build(uint8_t *buf, size_t max_buf_len, const ProxyProtocol &p // Set len field (number of following bytes part of the header) in the hdr uint16_t len = htons(bw.size() - PPv2_CONNECTION_HEADER_LEN); memcpy(buf + len_field_offset, &len, sizeof(uint16_t)); + Debug("proxyprotocol_v2", "Proxy Protocol v2 of %zu bytes", bw.size()); return bw.size(); } diff --git a/proxy/http/HttpSM.cc b/proxy/http/HttpSM.cc index 3eded703959..97519bea861 100644 --- a/proxy/http/HttpSM.cc +++ b/proxy/http/HttpSM.cc @@ -127,9 +127,10 @@ do_outbound_proxy_protocol(MIOBuffer *miob, NetVConnection *vc_out, NetVConnecti // nothing to forward return 0; } else { + Debug("proxyprotocol", "vc_in had no Proxy Protocol. Manufacturing from the vc_in socket."); // set info from incoming NetVConnection IpEndpoint local = vc_in->get_local_endpoint(); - info = ProxyProtocol{pp_version, local.family(), local, vc_in->get_remote_endpoint()}; + info = ProxyProtocol{pp_version, local.family(), vc_in->get_remote_endpoint(), local}; } } @@ -6982,7 +6983,13 @@ HttpSM::setup_blind_tunnel(bool send_response_hdr, IOBufferReader *initial) client_response_hdr_bytes = 0; } - client_request_body_bytes = 0; + int64_t nbytes = 0; + if (t_state.txn_conf->proxy_protocol_out >= 0) { + nbytes = do_outbound_proxy_protocol(from_ua_buf, static_cast(server_entry->vc), ua_txn->get_netvc(), + t_state.txn_conf->proxy_protocol_out); + } + + client_request_body_bytes = nbytes; if (ua_raw_buffer_reader != nullptr) { client_request_body_bytes += from_ua_buf->write(ua_raw_buffer_reader, client_request_hdr_bytes); ua_raw_buffer_reader->dealloc(); diff --git a/tests/gold_tests/autest-site/trafficserver.test.ext b/tests/gold_tests/autest-site/trafficserver.test.ext index 43f768ef184..1fe9eed8d41 100755 --- a/tests/gold_tests/autest-site/trafficserver.test.ext +++ b/tests/gold_tests/autest-site/trafficserver.test.ext @@ -41,7 +41,7 @@ default_log_data = { def MakeATSProcess(obj, name, command='traffic_server', select_ports=True, enable_tls=False, enable_cache=True, enable_quic=False, block_for_debug=False, log_data=default_log_data, - use_traffic_out=True): + use_traffic_out=True, enable_proxy_protocol=False): ##################################### # common locations @@ -328,6 +328,14 @@ def MakeATSProcess(obj, name, command='traffic_server', select_ports=True, if enable_tls: get_port(p, "ssl_port") get_port(p, "ssl_portv6") + + if enable_proxy_protocol: + get_port(p, "proxy_protocol_port") + get_port(p, "proxy_protocol_portv6") + + if enable_tls: + get_port(p, "proxy_protocol_ssl_port") + get_port(p, "proxy_protocol_ssl_portv6") else: p.Variables.port = 8080 p.Variables.portv6 = 8080 @@ -382,6 +390,10 @@ def MakeATSProcess(obj, name, command='traffic_server', select_ports=True, if enable_quic: port_str += " {ssl_port}:quic {ssl_portv6}:quic:ipv6".format( ssl_port=p.Variables.ssl_port, ssl_portv6=p.Variables.ssl_portv6) + if enable_proxy_protocol: + port_str += f" {p.Variables.proxy_protocol_port}:pp {p.Variables.proxy_protocol_portv6}:pp:ipv6" + if enable_tls: + port_str += f" {p.Variables.proxy_protocol_ssl_port}:pp:ssl {p.Variables.proxy_protocol_ssl_portv6}:pp:ssl:ipv6" #p.Env['PROXY_CONFIG_HTTP_SERVER_PORTS'] = port_str p.Disk.records_config.update({ 'proxy.config.http.server_ports': port_str, diff --git a/tests/gold_tests/proxy_protocol/proxy_protocol.test.py b/tests/gold_tests/proxy_protocol/proxy_protocol.test.py index 4125ba03ca9..ed58315f604 100644 --- a/tests/gold_tests/proxy_protocol/proxy_protocol.test.py +++ b/tests/gold_tests/proxy_protocol/proxy_protocol.test.py @@ -17,6 +17,7 @@ # limitations under the License. import os +from ports import get_port import sys Test.Summary = 'Test PROXY Protocol' @@ -27,6 +28,8 @@ class ProxyProtocolTest: + """Test that ATS can receive Proxy Protocol.""" + def __init__(self): self.setupOriginServer() self.setupTS() @@ -39,7 +42,7 @@ def setupOriginServer(self): ''' def setupTS(self): - self.ts = Test.MakeATSProcess("ts", enable_tls=True, enable_cache=False) + self.ts = Test.MakeATSProcess("ts_in", enable_tls=True, enable_cache=False, enable_proxy_protocol=True) self.ts.addDefaultSSLFiles() self.ts.Disk.ssl_multicert_config.AddLine("dest_ip=* ssl_cert_name=server.pem ssl_key_name=server.key") @@ -48,7 +51,6 @@ def setupTS(self): f"map / http://127.0.0.1:{self.httpbin.Variables.Port}/") self.ts.Disk.records_config.update({ - "proxy.config.http.server_ports": f"{self.ts.Variables.port}:pp {self.ts.Variables.ssl_port}:ssl:pp", "proxy.config.http.proxy_protocol_allowlist": "127.0.0.1", "proxy.config.http.insert_forwarded": "for|by=ip|proto", "proxy.config.ssl.server.cert.path": f"{self.ts.Variables.SSLDir}", @@ -76,7 +78,7 @@ def addTestCase0(self): tr = Test.AddTestRun() tr.Processes.Default.StartBefore(self.httpbin) tr.Processes.Default.StartBefore(self.ts) - tr.Processes.Default.Command = f"curl -vs --haproxy-protocol http://localhost:{self.ts.Variables.port}/get | {self.json_printer}" + tr.Processes.Default.Command = f"curl -vs --haproxy-protocol http://localhost:{self.ts.Variables.proxy_protocol_port}/get | {self.json_printer}" tr.Processes.Default.ReturnCode = 0 tr.Processes.Default.Streams.stdout = "gold/test_case_0_stdout.gold" tr.Processes.Default.Streams.stderr = "gold/test_case_0_stderr.gold" @@ -88,7 +90,7 @@ def addTestCase1(self): Incoming PROXY Protocol v1 on SSL port """ tr = Test.AddTestRun() - tr.Processes.Default.Command = f"curl -vsk --haproxy-protocol --http1.1 https://localhost:{self.ts.Variables.ssl_port}/get | {self.json_printer}" + tr.Processes.Default.Command = f"curl -vsk --haproxy-protocol --http1.1 https://localhost:{self.ts.Variables.proxy_protocol_ssl_port}/get | {self.json_printer}" tr.Processes.Default.ReturnCode = 0 tr.Processes.Default.Streams.stdout = "gold/test_case_1_stdout.gold" tr.Processes.Default.Streams.stderr = "gold/test_case_1_stderr.gold" @@ -100,7 +102,7 @@ def addTestCase2(self): Test with netcat """ tr = Test.AddTestRun() - tr.Processes.Default.Command = f"echo 'PROXY TCP4 198.51.100.1 198.51.100.2 51137 80\r\nGET /get HTTP/1.1\r\nHost: 127.0.0.1:80\r\n' | nc localhost {self.ts.Variables.port}" + tr.Processes.Default.Command = f"echo 'PROXY TCP4 198.51.100.1 198.51.100.2 51137 80\r\nGET /get HTTP/1.1\r\nHost: 127.0.0.1:80\r\n' | nc localhost {self.ts.Variables.proxy_protocol_port}" tr.Processes.Default.ReturnCode = 0 tr.Processes.Default.Streams.stdout = "gold/test_case_2_stdout.gold" tr.StillRunningAfter = self.httpbin @@ -127,4 +129,160 @@ def run(self): self.addTestCase99() +class ProxyProtocolOutTest: + """Test that ATS can send Proxy Protocol.""" + + _pp_server = 'proxy_protocol_server.py' + + _dns_counter = 0 + _server_counter = 0 + _ts_counter = 0 + + def __init__(self, pp_version: int, is_tunnel: bool) -> None: + """Initialize a ProxyProtocolOutTest. + + :param pp_version: The Proxy Protocol version to use (1 or 2). + :param is_tunnel: Whether ATS should tunnel to the origin. + """ + + if pp_version not in (-1, 1, 2): + raise ValueError( + f'Invalid Proxy Protocol version (not 1 or 2): {pp_version}') + self._pp_version = pp_version + self._is_tunnel = is_tunnel + + def setupOriginServer(self, tr: 'TestRun') -> None: + """Configure the origin server. + + :param tr: The TestRun to associate the origin's Process with. + """ + tr.Setup.CopyAs(self._pp_server, tr.RunDirectory) + cert_file = os.path.join(Test.Variables.AtsTestToolsDir, "ssl", "server.pem") + key_file = os.path.join(Test.Variables.AtsTestToolsDir, "ssl", "server.key") + tr.Setup.Copy(cert_file) + tr.Setup.Copy(key_file) + server = tr.Processes.Process( + f'server-{ProxyProtocolOutTest._server_counter}') + ProxyProtocolOutTest._server_counter += 1 + server_port = get_port(server, "external_port") + internal_port = get_port(server, "internal_port") + command = ( + f'{sys.executable} {self._pp_server} ' + f'server.pem server.key 127.0.0.1 {server_port} {internal_port}') + if not self._is_tunnel: + command += ' --plaintext' + server.Command = command + server.Ready = When.PortOpenv4(server_port) + + self._server = server + + def setupDNS(self, tr: 'TestRun') -> None: + """Configure the DNS server. + + :param tr: The TestRun to associate the DNS's Process with. + """ + self._dns = tr.MakeDNServer( + f'dns-{ProxyProtocolOutTest._dns_counter}', + default='127.0.0.1') + ProxyProtocolOutTest._dns_counter += 1 + + def setupTS(self, tr: 'TestRun') -> None: + """Configure Traffic Server.""" + process_name = f'ts-out-{ProxyProtocolOutTest._ts_counter}' + ProxyProtocolOutTest._ts_counter += 1 + self._ts = tr.MakeATSProcess(process_name, enable_tls=True, + enable_cache=False) + + self._ts.addDefaultSSLFiles() + self._ts.Disk.ssl_multicert_config.AddLine( + "dest_ip=* ssl_cert_name=server.pem ssl_key_name=server.key" + ) + + self._ts.Disk.remap_config.AddLine( + f"map / http://backend.pp.origin.com:{self._server.Variables.external_port}/") + + self._ts.Disk.records_config.update({ + "proxy.config.ssl.server.cert.path": f"{self._ts.Variables.SSLDir}", + "proxy.config.ssl.server.private_key.path": f"{self._ts.Variables.SSLDir}", + "proxy.config.diags.debug.enabled": 1, + "proxy.config.diags.debug.tags": "http|proxyprotocol", + "proxy.config.http.proxy_protocol_out": self._pp_version, + "proxy.config.dns.nameservers": f"127.0.0.1:{self._dns.Variables.Port}", + "proxy.config.dns.resolv_conf": 'NULL' + }) + + if self._is_tunnel: + self._ts.Disk.records_config.update({ + "proxy.config.http.connect_ports": f'{self._server.Variables.external_port}', + }) + + self._ts.Disk.sni_yaml.AddLines([ + 'sni:', + '- fqdn: pp.origin.com', + f' tunnel_route: backend.pp.origin.com:{self._server.Variables.external_port}', + ]) + + def setLogExpectations(self, tr: 'TestRun') -> None: + + tr.Processes.Default.Streams.All += Testers.ContainsExpression( + "HTTP/1.1 200 OK", + "Verify that curl got a 200 response") + + if self._pp_version in (1, 2): + expected_pp = ( + 'PROXY TCP4 127.0.0.1 127.0.0.1 ' + rf'\d+ {self._ts.Variables.ssl_port}' + ) + self._server.Streams.All += Testers.ContainsExpression( + expected_pp, + "Verify the server got the expected Proxy Protocol string.") + + self._server.Streams.All += Testers.ContainsExpression( + f'Received Proxy Protocol v{self._pp_version}', + "Verify the server got the expected Proxy Protocol version.") + + if self._pp_version == -1: + self._server.Streams.All += Testers.ContainsExpression( + 'No Proxy Protocol string found', + 'There should be no Proxy Protocol string.') + + def run(self) -> None: + """Run the test.""" + description = f'Proxy Protocol v{self._pp_version} ' + if self._is_tunnel: + description += "with blind tunneling" + else: + description += "without blind tunneling" + tr = Test.AddTestRun(description) + + self.setupDNS(tr) + self.setupOriginServer(tr) + self.setupTS(tr) + + self._ts.StartBefore(self._server) + self._ts.StartBefore(self._dns) + tr.Processes.Default.StartBefore(self._ts) + + origin = f'pp.origin.com:{self._ts.Variables.ssl_port}' + command = ( + 'sleep1; curl -vsk --http1.1 ' + f'--resolve "{origin}:127.0.0.1" ' + f'https://{origin}/get' + ) + + tr.Processes.Default.Command = command + tr.Processes.Default.ReturnCode = 0 + # Its only one transaction, so this should complete quickly. The test + # server often hangs if there are issues parsing the Proxy Protocol + # string. + tr.TimeOut = 5 + self.setLogExpectations(tr) + + ProxyProtocolTest().run() + +ProxyProtocolOutTest(pp_version=-1, is_tunnel=False).run() +ProxyProtocolOutTest(pp_version=1, is_tunnel=False).run() +ProxyProtocolOutTest(pp_version=2, is_tunnel=False).run() +ProxyProtocolOutTest(pp_version=1, is_tunnel=True).run() +ProxyProtocolOutTest(pp_version=2, is_tunnel=True).run() diff --git a/tests/gold_tests/proxy_protocol/proxy_protocol_server.py b/tests/gold_tests/proxy_protocol/proxy_protocol_server.py new file mode 100644 index 00000000000..af1a52820f2 --- /dev/null +++ b/tests/gold_tests/proxy_protocol/proxy_protocol_server.py @@ -0,0 +1,381 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A simple server that expects and prints out the Proxy Protocol string.""" + +import argparse +import logging +import socket +import ssl +import struct +import sys +import threading + + +# Set a 10ms timeout for socket operations. +TIMEOUT = .010 + +PP_V2_PREFIX = b'\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a' + + +# Create a condition variable for thread initialization. +internal_thread_is_ready = threading.Condition() + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description=__doc__) + + parser.add_argument( + "certfile", + help="The path to the certificate file to use for TLS.") + parser.add_argument( + "keyfile", + help="The path to the key file to use for TLS.") + parser.add_argument( + "address", + help="The IP address to listen on.") + parser.add_argument( + "port", + type=int, + help="The port to listen on.") + parser.add_argument( + "internal_port", + type=int, + help="The internal port used to parse the TLS content.") + parser.add_argument( + "--plaintext", + action="store_true", + help="Listen for plaintext connections instead of TLS.") + + return parser.parse_args() + + +def receive_and_send_http(sock: socket.socket) -> None: + """Receive and send an HTTP request and response. + + :param sock: The socket to receive the request on. + """ + sock.settimeout(TIMEOUT) + + # Read the request until the final CRLF is received. + received_request = b'' + while True: + data = None + try: + data = sock.recv(1024) + logging.debug(f'Internal: received {len(data)} bytes') + except socket.timeout: + continue + if not data: + break + received_request += data + + if b'\r\n\r\n' in received_request: + break + logging.info("Received request:") + logging.info(received_request.decode("utf-8")) + + # Send a response. + response = ( + "HTTP/1.1 200 OK\r\n" + "Content-Length: 0\r\n" + "Connection: close\r\n" + "\r\n" + ) + logging.info(f'Sending:\n{response}') + try: + sock.sendall(response.encode("utf-8")) + except socket.timeout: + logging.error("Timeout sending a response.") + + +def run_internal_server(cert_file: str, key_file: str, + address: str, port: int, + plaintext: bool) -> None: + """Run the internal server. + + This server is receives the HTTP content with the Proxy Protocol prefix + stripped off by the client. + + :param cert_file: The path to the certificate file to use for TLS. + :param key_file: The path to the key file to use for TLS. + :param address: The IP address to listen on. + :param port: The port to listen on. + :param plaintext: Whether to listen for HTTP rather than HTTPS traffic. + """ + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((address, port)) + sock.listen() + logging.info(f"Internal HTTPS server listening on {address}:{port}") + + if plaintext: + # Notify the waiting thread that the internal server is ready. + with internal_thread_is_ready: + internal_thread_is_ready.notify() + conn, addr_in = sock.accept() + logging.info(f"Internal server accepted plaintext connection from {addr_in}") + with conn: + receive_and_send_http(conn) + else: + # Wrap the server socket to handle TLS. + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + context.load_cert_chain(certfile=cert_file, keyfile=key_file) + + with context.wrap_socket(sock, server_side=True) as ssock: + with internal_thread_is_ready: + internal_thread_is_ready.notify() + conn, addr_in = ssock.accept() + logging.info(f"Internal server accepted TLS connection from {addr_in}") + with conn: + receive_and_send_http(conn) + + +def parse_pp_v1(pp_bytes: bytes) -> int: + """Parse and print the Proxy Protocol v1 string. + + :param pp_bytes: The bytes containing the Proxy Protocol string. There may + be more bytes than the Proxy Protocol string. + + :returns: The number of bytes occupied by the proxy v1 protcol. + """ + # Proxy Protocol v1 string ends with CRLF. + end = pp_bytes.find(b'\r\n') + if end == -1: + raise ValueError("Proxy Protocol v1 string ending not found") + logging.info(pp_bytes[:end].decode("utf-8")) + return end + 2 + + +def parse_pp_v2(pp_bytes: bytes) -> int: + """Parse and print the Proxy Protocol v2 string. + + :param pp_bytes: The bytes containing the Proxy Protocol string. There may + be more bytes than the Proxy Protocol string. + + :returns: The number of bytes occupied by the proxy v2 protocol string. + """ + + # Skip the 12 byte header. + pp_bytes = pp_bytes[12:] + version_command = pp_bytes[0] + pp_bytes = pp_bytes[1:] + family_protocol = pp_bytes[0] + pp_bytes = pp_bytes[1:] + tuple_length = int.from_bytes(pp_bytes[:2], byteorder='big') + pp_bytes = pp_bytes[2:] + + # Of version_command, the highest 4 bits is the version and the lowest is + # the command. + version = version_command >> 4 + command = version_command & 0x0F + + if version != 2: + raise ValueError( + f'Invalid version: {version} (by spec, should always be 0x02)') + + if command == 0x0: + command_description = 'LOCAL' + elif command == 0x1: + command_description = 'PROXY' + else: + raise ValueError( + f'Invalid command: {command} (by spec, should be 0x00 or 0x01)') + + # Of address_family, the highest 4 bits is the address family and the + # lowest is the transport protocol. + if family_protocol == 0x0: + transport_protocol_description = 'UNSPEC' + elif family_protocol == 0x11: + transport_protocol_description = 'TCP4' + elif family_protocol == 0x12: + transport_protocol_description = 'UDP4' + elif family_protocol == 0x21: + transport_protocol_description = 'TCP6' + elif family_protocol == 0x22: + transport_protocol_description = 'UDP6' + elif family_protocol == 0x31: + transport_protocol_description = 'UNIX_STREAM' + elif family_protocol == 0x32: + transport_protocol_description = 'UNIX_DGRAM' + else: + raise ValueError( + f'Invalid address family: {family_protocol} (by spec, should be ' + '0x00, 0x11, 0x12, 0x21, 0x22, 0x31, or 0x32)') + + if family_protocol in (0x11, 0x12): + if tuple_length != 12: + raise ValueError( + "Unexpected tuple length for TCP4/UDP4: " + f"{tuple_length} (by spec, should be 12)" + ) + src_addr = socket.inet_ntop(socket.AF_INET, pp_bytes[:4]) + pp_bytes = pp_bytes[4:] + dst_addr = socket.inet_ntop(socket.AF_INET, pp_bytes[:4]) + pp_bytes = pp_bytes[4:] + src_port = int.from_bytes(pp_bytes[:2], byteorder='big') + pp_bytes = pp_bytes[2:] + dst_port = int.from_bytes(pp_bytes[:2], byteorder='big') + pp_bytes = pp_bytes[2:] + + tuple_description = f'{src_addr} {dst_addr} {src_port} {dst_port}' + logging.info( + f'{command_description} {transport_protocol_description} ' + f'{tuple_description}') + + return 16 + tuple_length + + +def accept_pp_connection(sock: socket.socket, address: str, internal_port: int) -> bool: + """Accept a connection and parse the proxy protocol header. + + :param sock: The socket to accept the connection on. + :param address: The address of the internal server to connect to. + :param internal_port: The port of the internal server to connect to. + + :returns: True if the connection had a payload, False otherwise. + """ + client_conn, address_in = sock.accept() + logging.info(f'Accepted connection from {address_in}') + with client_conn: + has_pp = False + pp_length = 0 + # Read the Proxy Protocol prefix, which ends with the first CRLF. + received_data = b'' + while True: + data = client_conn.recv(1024) + if data: + logging.debug(f"Received: {len(data)} bytes") + else: + logging.info("No data received while waiting for " + "Proxy Protocol prefix") + return False + received_data += data + + if (received_data.startswith(b'PROXY') and + b'\r\n' in received_data): + logging.info("Received Proxy Protocol v1") + pp_length = parse_pp_v1(received_data) + has_pp = True + break + + if received_data.startswith(PP_V2_PREFIX): + logging.info("Received Proxy Protocol v2") + pp_length = parse_pp_v2(received_data) + has_pp = True + break + + if len(received_data) > 108: + # The spec gaurantees that the prefix will be no more than + # 108 bytes. + logging.info("No Proxy Protocol string found.") + break + if has_pp: + # Now, strip the received_data of the prefix and blind tunnel + # the rest of the content. + for_internal = received_data[pp_length:] + logging.debug( + f"Stripped the prefix, now thare are {len(for_internal)} " + "bytes for the internal server.") + else: + for_internal = received_data + client_conn.settimeout(TIMEOUT) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as internal_sock: + logging.debug(f"Connecting to internal server on {address}:{internal_port}") + internal_sock.connect((address, internal_port)) + internal_sock.settimeout(TIMEOUT) + if for_internal: + logging.debug('Sending remaining data to internal server: ' + f'{len(for_internal)} bytes') + internal_sock.sendall(for_internal) + while True: + + logging.debug("entering loop") + + try: + from_internal = internal_sock.recv(1024) + logging.debug(f'Received {len(from_internal)} bytes from internal server') + if not from_internal: + logging.debug('No more data from internal server, closing connection') + break + client_conn.sendall(from_internal) + logging.debug(f'Sent {len(from_internal)} bytes to client') + except socket.timeout: + pass + + try: + for_internal = client_conn.recv(1024) + logging.debug(f'Received {len(for_internal)} bytes from client') + if not for_internal: + logging.debug('No more data from client, closing connection') + break + internal_sock.sendall(for_internal) + logging.debug(f'Sent {len(for_internal)} bytes to internal server') + except socket.timeout: + pass + + +def receive_pp_request(address: str, port: int, internal_port: int) -> None: + """Start a server to receive a connection which may have a proxy protocol + header. + + :param address: The address to listen on. + :param port: The port to listen on. + :param internal_port: The port of the internal server to connect to. + """ + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((address, port)) + sock.listen() + + # The PortOpen logic will create an empty request to the server. Ignore + # those until we have a connection with a real request which comes in. + request_received = False + while not request_received: + request_received = accept_pp_connection(sock, address, + internal_port) + + +def main() -> int: + """Run the server listening for Proxy Protocol.""" + args = parse_args() + + with internal_thread_is_ready: + """Start the threads to receive requests.""" + internal_server = threading.Thread( + target=run_internal_server, + args=(args.certfile, args.keyfile, args.address, + args.internal_port, args.plaintext)) + internal_server.start() + + # Wait for the internal server to start before proceeding. + internal_thread_is_ready.wait() + + receive_pp_request(args.address, args.port, args.internal_port) + internal_server.join() + + return 0 + + +if __name__ == "__main__": + logging.basicConfig( + level=logging.DEBUG, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + sys.exit(main())