Skip to content

Commit

Permalink
Top-level notion of work not client (#695)
Browse files Browse the repository at this point in the history
* Top-level notion of work not client

* Update ssl echo server example
  • Loading branch information
abhinavsingh authored Nov 7, 2021
1 parent d3cee32 commit f48771f
Show file tree
Hide file tree
Showing 13 changed files with 76 additions and 73 deletions.
4 changes: 2 additions & 2 deletions examples/https_connect_tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:

# Drop the request if not a CONNECT request
if self.request.method != httpMethods.CONNECT:
self.client.queue(
self.work.queue(
HttpsConnectTunnelHandler.PROXY_TUNNEL_UNSUPPORTED_SCHEME,
)
return True
Expand All @@ -66,7 +66,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
self.connect_upstream()

# Queue tunnel established response to client
self.client.queue(
self.work.queue(
HttpsConnectTunnelHandler.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
)

Expand Down
8 changes: 4 additions & 4 deletions examples/ssl_echo_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,19 @@ def initialize(self) -> None:
# here using wrap_socket() utility.
assert self.flags.keyfile is not None and self.flags.certfile is not None
conn = wrap_socket(
self.client.connection,
self.work.connection,
self.flags.keyfile,
self.flags.certfile,
)
conn.setblocking(False)
# Upgrade plain TcpClientConnection to SSL connection object
self.client = TcpClientConnection(
conn=conn, addr=self.client.addr,
self.work = TcpClientConnection(
conn=conn, addr=self.work.addr,
)

def handle_data(self, data: memoryview) -> Optional[bool]:
# echo back to client
self.client.queue(data)
self.work.queue(data)
return None


Expand Down
4 changes: 2 additions & 2 deletions examples/tcp_echo_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ class EchoServerHandler(BaseTcpServerHandler):
"""Sets client socket to non-blocking during initialization."""

def initialize(self) -> None:
self.client.connection.setblocking(False)
self.work.connection.setblocking(False)

def handle_data(self, data: memoryview) -> Optional[bool]:
# echo back to client
self.client.queue(data)
self.work.queue(data)
return None


Expand Down
9 changes: 6 additions & 3 deletions proxy/core/acceptor/work.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,18 @@ class Work(ABC):

def __init__(
self,
client: TcpClientConnection,
work: TcpClientConnection,
flags: argparse.Namespace,
event_queue: Optional[EventQueue] = None,
uid: Optional[UUID] = None,
) -> None:
self.client = client
# Work uuid
self.uid: UUID = uid if uid is not None else uuid4()
self.flags = flags
# Eventing core queue
self.event_queue = event_queue
self.uid: UUID = uid if uid is not None else uuid4()
# Accept work
self.work = work

@abstractmethod
def get_events(self) -> Dict[socket.socket, int]:
Expand Down
34 changes: 17 additions & 17 deletions proxy/core/base/tcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class BaseTcpServerHandler(Work):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.must_flush_before_shutdown = False
logger.debug('Connection accepted from {0}'.format(self.client.addr))
logger.debug('Connection accepted from {0}'.format(self.work.addr))

@abstractmethod
def handle_data(self, data: memoryview) -> Optional[bool]:
Expand All @@ -57,14 +57,14 @@ def get_events(self) -> Dict[socket.socket, int]:
# We always want to read from client
# Register for EVENT_READ events
if self.must_flush_before_shutdown is False:
events[self.client.connection] = selectors.EVENT_READ
events[self.work.connection] = selectors.EVENT_READ
# If there is pending buffer for client
# also register for EVENT_WRITE events
if self.client.has_buffer():
if self.client.connection in events:
events[self.client.connection] |= selectors.EVENT_WRITE
if self.work.has_buffer():
if self.work.connection in events:
events[self.work.connection] |= selectors.EVENT_WRITE
else:
events[self.client.connection] = selectors.EVENT_WRITE
events[self.work.connection] = selectors.EVENT_WRITE
return events

def handle_events(
Expand All @@ -79,32 +79,32 @@ def handle_events(
if teardown:
logger.debug(
'Shutting down client {0} connection'.format(
self.client.addr,
self.work.addr,
),
)
return teardown

def handle_writables(self, writables: Writables) -> bool:
teardown = False
if self.client.connection in writables and self.client.has_buffer():
if self.work.connection in writables and self.work.has_buffer():
logger.debug(
'Flushing buffer to client {0}'.format(self.client.addr),
'Flushing buffer to client {0}'.format(self.work.addr),
)
self.client.flush()
self.work.flush()
if self.must_flush_before_shutdown is True:
if not self.client.has_buffer():
if not self.work.has_buffer():
teardown = True
self.must_flush_before_shutdown = False
return teardown

def handle_readables(self, readables: Readables) -> bool:
teardown = False
if self.client.connection in readables:
data = self.client.recv(self.flags.client_recvbuf_size)
if self.work.connection in readables:
data = self.work.recv(self.flags.client_recvbuf_size)
if data is None:
logger.debug(
'Connection closed by client {0}'.format(
self.client.addr,
self.work.addr,
),
)
teardown = True
Expand All @@ -113,13 +113,13 @@ def handle_readables(self, readables: Readables) -> bool:
if isinstance(r, bool) and r is True:
logger.debug(
'Implementation signaled shutdown for client {0}'.format(
self.client.addr,
self.work.addr,
),
)
if self.client.has_buffer():
if self.work.has_buffer():
logger.debug(
'Client {0} has pending buffer, will be flushed before shutting down'.format(
self.client.addr,
self.work.addr,
),
)
self.must_flush_before_shutdown = True
Expand Down
4 changes: 2 additions & 2 deletions proxy/core/base/tcp_tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
pass # pragma: no cover

def initialize(self) -> None:
self.client.connection.setblocking(False)
self.work.connection.setblocking(False)

def shutdown(self) -> None:
if self.upstream:
Expand Down Expand Up @@ -87,7 +87,7 @@ def handle_events(
print('Connection closed by server')
return True
# tunnel data to client
self.client.queue(data)
self.work.queue(data)
if self.upstream and self.upstream.connection in writables:
self.upstream.flush()
return False
Expand Down
46 changes: 23 additions & 23 deletions proxy/http/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,25 +89,25 @@ def __init__(self, *args: Any, **kwargs: Any):

def initialize(self) -> None:
"""Optionally upgrades connection to HTTPS, set conn in non-blocking mode and initializes plugins."""
conn = self._optionally_wrap_socket(self.client.connection)
conn = self._optionally_wrap_socket(self.work.connection)
conn.setblocking(False)
# Update client connection reference if connection was wrapped
if self._encryption_enabled():
self.client = TcpClientConnection(conn=conn, addr=self.client.addr)
self.work = TcpClientConnection(conn=conn, addr=self.work.addr)
if b'HttpProtocolHandlerPlugin' in self.flags.plugins:
for klass in self.flags.plugins[b'HttpProtocolHandlerPlugin']:
instance: HttpProtocolHandlerPlugin = klass(
self.uid,
self.flags,
self.client,
self.work,
self.request,
self.event_queue,
)
self.plugins[instance.name()] = instance
logger.debug('Handling connection %r' % self.client.connection)
logger.debug('Handling connection %r' % self.work.connection)

def is_inactive(self) -> bool:
if not self.client.has_buffer() and \
if not self.work.has_buffer() and \
self._connection_inactive_for() > self.flags.timeout:
return True
return False
Expand All @@ -127,20 +127,20 @@ def shutdown(self) -> None:
logger.debug(
'Closing client connection %r '
'at address %r has buffer %s' %
(self.client.connection, self.client.addr, self.client.has_buffer()),
(self.work.connection, self.work.addr, self.work.has_buffer()),
)

conn = self.client.connection
conn = self.work.connection
# Unwrap if wrapped before shutdown.
if self._encryption_enabled() and \
isinstance(self.client.connection, ssl.SSLSocket):
conn = self.client.connection.unwrap()
isinstance(self.work.connection, ssl.SSLSocket):
conn = self.work.connection.unwrap()
conn.shutdown(socket.SHUT_WR)
logger.debug('Client connection shutdown successful')
except OSError:
pass
finally:
self.client.connection.close()
self.work.connection.close()
logger.debug('Client connection closed')
super().shutdown()

Expand Down Expand Up @@ -196,7 +196,7 @@ def handle_events(
def handle_data(self, data: memoryview) -> Optional[bool]:
if data is None:
logger.debug('Client closed connection, tearing down...')
self.client.closed = True
self.work.closed = True
return True

try:
Expand Down Expand Up @@ -227,7 +227,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
logger.debug(
'Updated client conn to %s', upgraded_sock,
)
self.client._conn = upgraded_sock
self.work._conn = upgraded_sock
for plugin_ in self.plugins.values():
if plugin_ != plugin:
plugin_.client._conn = upgraded_sock
Expand All @@ -237,20 +237,20 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
logger.debug('HttpProtocolException raised')
response: Optional[memoryview] = e.response(self.request)
if response:
self.client.queue(response)
self.work.queue(response)
return True
return False

def handle_writables(self, writables: Writables) -> bool:
if self.client.connection in writables and self.client.has_buffer():
if self.work.connection in writables and self.work.has_buffer():
logger.debug('Client is ready for writes, flushing buffer')
self.last_activity = time.time()

# TODO(abhinavsingh): This hook could just reside within server recv block
# instead of invoking when flushed to client.
#
# Invoke plugin.on_response_chunk
chunk = self.client.buffer
chunk = self.work.buffer
for plugin in self.plugins.values():
chunk = plugin.on_response_chunk(chunk)
if chunk is None:
Expand All @@ -272,7 +272,7 @@ def handle_writables(self, writables: Writables) -> bool:
return False

def handle_readables(self, readables: Readables) -> bool:
if self.client.connection in readables:
if self.work.connection in readables:
logger.debug('Client is ready for reads, reading')
self.last_activity = time.time()
try:
Expand All @@ -290,7 +290,7 @@ def handle_readables(self, readables: Readables) -> bool:
else:
logger.exception(
'Exception while receiving from %s connection %r with reason %r' %
(self.client.tag, self.client.connection, e),
(self.work.tag, self.work.connection, e),
)
return True
return False
Expand Down Expand Up @@ -324,7 +324,7 @@ def run(self) -> None:
except Exception as e:
logger.exception(
'Exception while handling connection %r' %
self.client.connection, exc_info=e,
self.work.connection, exc_info=e,
)
finally:
self.shutdown()
Expand Down Expand Up @@ -377,24 +377,24 @@ def _run_once(self) -> bool:

def _flush(self) -> None:
assert self.selector
if not self.client.has_buffer():
if not self.work.has_buffer():
return
try:
self.selector.register(
self.client.connection,
self.work.connection,
selectors.EVENT_WRITE,
)
while self.client.has_buffer():
while self.work.has_buffer():
ev: List[
Tuple[selectors.SelectorKey, int]
] = self.selector.select(timeout=1)
if len(ev) == 0:
continue
self.client.flush()
self.work.flush()
except BrokenPipeError:
pass
finally:
self.selector.unregister(self.client.connection)
self.selector.unregister(self.work.connection)

def _connection_inactive_for(self) -> float:
return time.time() - self.last_activity
12 changes: 6 additions & 6 deletions tests/http/exceptions/test_http_proxy_auth_failed.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ def test_proxy_auth_fails_without_cred(self, mock_server_conn: mock.Mock) -> Non

self.protocol_handler._run_once()
mock_server_conn.assert_not_called()
self.assertEqual(self.protocol_handler.client.has_buffer(), True)
self.assertEqual(self.protocol_handler.work.has_buffer(), True)
self.assertEqual(
self.protocol_handler.client.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
self.protocol_handler.work.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
)
self._conn.send.assert_not_called()

Expand All @@ -92,9 +92,9 @@ def test_proxy_auth_fails_with_invalid_cred(self, mock_server_conn: mock.Mock) -

self.protocol_handler._run_once()
mock_server_conn.assert_not_called()
self.assertEqual(self.protocol_handler.client.has_buffer(), True)
self.assertEqual(self.protocol_handler.work.has_buffer(), True)
self.assertEqual(
self.protocol_handler.client.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
self.protocol_handler.work.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
)
self._conn.send.assert_not_called()

Expand All @@ -121,7 +121,7 @@ def test_proxy_auth_works_with_valid_cred(self, mock_server_conn: mock.Mock) ->

self.protocol_handler._run_once()
mock_server_conn.assert_called_once()
self.assertEqual(self.protocol_handler.client.has_buffer(), False)
self.assertEqual(self.protocol_handler.work.has_buffer(), False)

@mock.patch('proxy.http.proxy.server.TcpServerConnection')
def test_proxy_auth_works_with_mixed_case_basic_string(self, mock_server_conn: mock.Mock) -> None:
Expand All @@ -146,4 +146,4 @@ def test_proxy_auth_works_with_mixed_case_basic_string(self, mock_server_conn: m

self.protocol_handler._run_once()
mock_server_conn.assert_called_once()
self.assertEqual(self.protocol_handler.client.has_buffer(), False)
self.assertEqual(self.protocol_handler.work.has_buffer(), False)
2 changes: 1 addition & 1 deletion tests/http/test_http_proxy_tls_interception.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def mock_connection() -> Any:
)
self.assertEqual(self._conn.setblocking.call_count, 2)
self.assertEqual(
self.protocol_handler.client.connection,
self.protocol_handler.work.connection,
self.mock_ssl_wrap.return_value,
)

Expand Down
Loading

0 comments on commit f48771f

Please sign in to comment.