Skip to content

Commit 4fbe406

Browse files
committed
Merge pull request mitmproxy#67 from Kriechi/http2-wip
HTTP/2: preparations for pathod
2 parents 0595585 + 0d137ea commit 4fbe406

File tree

4 files changed

+223
-51
lines changed

4 files changed

+223
-51
lines changed

netlib/http2/protocol.py

Lines changed: 75 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,55 +26,80 @@ class HTTP2Protocol(object):
2626
)
2727

2828
# "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"
29-
CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'
29+
CLIENT_CONNECTION_PREFACE =\
30+
'505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex')
3031

3132
ALPN_PROTO_H2 = 'h2'
3233

33-
def __init__(self, tcp_client):
34-
self.tcp_client = tcp_client
34+
def __init__(self, tcp_handler, is_server=False):
35+
self.tcp_handler = tcp_handler
36+
self.is_server = is_server
3537

3638
self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy()
3739
self.current_stream_id = None
3840
self.encoder = Encoder()
3941
self.decoder = Decoder()
42+
self.connection_preface_performed = False
4043

4144
def check_alpn(self):
42-
alp = self.tcp_client.get_alpn_proto_negotiated()
45+
alp = self.tcp_handler.get_alpn_proto_negotiated()
4346
if alp != self.ALPN_PROTO_H2:
4447
raise NotImplementedError(
4548
"HTTP2Protocol can not handle unknown ALP: %s" % alp)
4649
return True
4750

48-
def perform_connection_preface(self):
49-
self.tcp_client.wfile.write(
50-
bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex')))
51-
self.send_frame(frame.SettingsFrame(state=self))
52-
53-
# read server settings frame
54-
frm = frame.Frame.from_file(self.tcp_client.rfile, self)
51+
def _receive_settings(self):
52+
frm = frame.Frame.from_file(self.tcp_handler.rfile, self)
5553
assert isinstance(frm, frame.SettingsFrame)
5654
self._apply_settings(frm.settings)
5755

58-
# read setting ACK frame
56+
def _read_settings_ack(self):
5957
settings_ack_frame = self.read_frame()
6058
assert isinstance(settings_ack_frame, frame.SettingsFrame)
6159
assert settings_ack_frame.flags & frame.Frame.FLAG_ACK
6260
assert len(settings_ack_frame.settings) == 0
6361

62+
def perform_server_connection_preface(self, force=False):
63+
if force or not self.connection_preface_performed:
64+
self.connection_preface_performed = True
65+
66+
magic_length = len(self.CLIENT_CONNECTION_PREFACE)
67+
magic = self.tcp_handler.rfile.safe_read(magic_length)
68+
assert magic == self.CLIENT_CONNECTION_PREFACE
69+
70+
self.send_frame(frame.SettingsFrame(state=self))
71+
self._receive_settings()
72+
self._read_settings_ack()
73+
74+
def perform_client_connection_preface(self, force=False):
75+
if force or not self.connection_preface_performed:
76+
self.connection_preface_performed = True
77+
78+
self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE)
79+
80+
self.send_frame(frame.SettingsFrame(state=self))
81+
self._receive_settings()
82+
self._read_settings_ack()
83+
6484
def next_stream_id(self):
6585
if self.current_stream_id is None:
66-
self.current_stream_id = 1
86+
if self.is_server:
87+
# servers must use even stream ids
88+
self.current_stream_id = 2
89+
else:
90+
# clients must use odd stream ids
91+
self.current_stream_id = 1
6792
else:
6893
self.current_stream_id += 2
6994
return self.current_stream_id
7095

7196
def send_frame(self, frame):
7297
raw_bytes = frame.to_bytes()
73-
self.tcp_client.wfile.write(raw_bytes)
74-
self.tcp_client.wfile.flush()
98+
self.tcp_handler.wfile.write(raw_bytes)
99+
self.tcp_handler.wfile.flush()
75100

76101
def read_frame(self):
77-
frm = frame.Frame.from_file(self.tcp_client.rfile, self)
102+
frm = frame.Frame.from_file(self.tcp_handler.rfile, self)
78103
if isinstance(frm, frame.SettingsFrame):
79104
self._apply_settings(frm.settings)
80105

@@ -127,10 +152,13 @@ def create_request(self, method, path, headers=None, body=None):
127152
if headers is None:
128153
headers = []
129154

155+
authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
130156
headers = [
131157
(b':method', bytes(method)),
132158
(b':path', bytes(path)),
133-
(b':scheme', b'https')] + headers
159+
(b':scheme', b'https'),
160+
(b':authority', authority),
161+
] + headers
134162

135163
stream_id = self.next_stream_id()
136164

@@ -139,25 +167,50 @@ def create_request(self, method, path, headers=None, body=None):
139167
self._create_body(body, stream_id)))
140168

141169
def read_response(self):
170+
headers, body = self._receive_transmission()
171+
return headers[':status'], headers, body
172+
173+
def read_request(self):
174+
return self._receive_transmission()
175+
176+
def _receive_transmission(self):
177+
body_expected = True
178+
142179
header_block_fragment = b''
143180
body = b''
144181

145182
while True:
146183
frm = self.read_frame()
147-
if isinstance(frm, frame.HeadersFrame):
184+
if isinstance(frm, frame.HeadersFrame)\
185+
or isinstance(frm, frame.ContinuationFrame):
148186
header_block_fragment += frm.header_block_fragment
149-
if frm.flags | frame.Frame.FLAG_END_HEADERS:
187+
if frm.flags & frame.Frame.FLAG_END_HEADERS:
188+
if frm.flags & frame.Frame.FLAG_END_STREAM:
189+
body_expected = False
150190
break
151191

152-
while True:
192+
while body_expected:
153193
frm = self.read_frame()
154194
if isinstance(frm, frame.DataFrame):
155195
body += frm.payload
156-
if frm.flags | frame.Frame.FLAG_END_STREAM:
196+
if frm.flags & frame.Frame.FLAG_END_STREAM:
157197
break
198+
# TODO: implement window update & flow
158199

159200
headers = {}
160201
for header, value in self.decoder.decode(header_block_fragment):
161202
headers[header] = value
162203

163-
return headers[':status'], headers, body
204+
return headers, body
205+
206+
def create_response(self, code, headers=None, body=None):
207+
if headers is None:
208+
headers = []
209+
210+
headers = [(b':status', bytes(str(code)))] + headers
211+
212+
stream_id = self.next_stream_id()
213+
214+
return list(itertools.chain(
215+
self._create_headers(headers, stream_id, end_stream=(body is None)),
216+
self._create_body(body, stream_id)))

netlib/tcp.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
SSLv3_METHOD = SSL.SSLv3_METHOD
2020
SSLv23_METHOD = SSL.SSLv23_METHOD
2121
TLSv1_METHOD = SSL.TLSv1_METHOD
22+
TLSv1_1_METHOD = SSL.TLSv1_1_METHOD
23+
TLSv1_2_METHOD = SSL.TLSv1_2_METHOD
24+
2225
OP_NO_SSLv2 = SSL.OP_NO_SSLv2
2326
OP_NO_SSLv3 = SSL.OP_NO_SSLv3
2427

@@ -376,7 +379,7 @@ def _create_ssl_context(self,
376379
alpn_select=None,
377380
):
378381
"""
379-
:param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD
382+
:param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, TLSv1_1_METHOD, or TLSv1_2_METHOD
380383
:param options: A bit field consisting of OpenSSL.SSL.OP_* values
381384
:param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html
382385
:rtype : SSL.Context
@@ -404,16 +407,17 @@ def _create_ssl_context(self,
404407
context.set_info_callback(log_ssl_key)
405408

406409
if OpenSSL._util.lib.Cryptography_HAS_ALPN:
407-
# advertise application layer protocols
408410
if alpn_protos is not None:
411+
# advertise application layer protocols
409412
context.set_alpn_protos(alpn_protos)
410-
411-
# select application layer protocol
412-
if alpn_select is not None:
413-
def alpn_select_f(conn, options):
414-
return bytes(alpn_select)
415-
416-
context.set_alpn_select_callback(alpn_select_f)
413+
elif alpn_select is not None:
414+
# select application layer protocol
415+
def alpn_select_callback(conn, options):
416+
if alpn_select in options:
417+
return bytes(alpn_select)
418+
else: # pragma no cover
419+
return options[0]
420+
context.set_alpn_select_callback(alpn_select_callback)
417421

418422
return context
419423

@@ -499,9 +503,9 @@ def gettimeout(self):
499503
return self.connection.gettimeout()
500504

501505
def get_alpn_proto_negotiated(self):
502-
if OpenSSL._util.lib.Cryptography_HAS_ALPN:
506+
if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established:
503507
return self.connection.get_alpn_proto_negotiated()
504-
else: # pragma no cover
508+
else:
505509
return None
506510

507511

@@ -531,7 +535,6 @@ def create_ssl_context(self,
531535
request_client_cert=None,
532536
chain_file=None,
533537
dhparams=None,
534-
alpn_select=None,
535538
**sslctx_kwargs):
536539
"""
537540
cert: A certutils.SSLCert object.
@@ -558,9 +561,7 @@ def create_ssl_context(self,
558561
until then we're conservative.
559562
"""
560563

561-
context = self._create_ssl_context(
562-
alpn_select=alpn_select,
563-
**sslctx_kwargs)
564+
context = self._create_ssl_context(**sslctx_kwargs)
564565

565566
context.use_privatekey(key)
566567
context.use_certificate(cert.x509)
@@ -585,7 +586,7 @@ def save_cert(conn, cert, errno, depth, preverify_ok):
585586

586587
return context
587588

588-
def convert_to_ssl(self, cert, key, alpn_select=None, **sslctx_kwargs):
589+
def convert_to_ssl(self, cert, key, **sslctx_kwargs):
589590
"""
590591
Convert connection to SSL.
591592
For a list of parameters, see BaseHandler._create_ssl_context(...)
@@ -594,7 +595,6 @@ def convert_to_ssl(self, cert, key, alpn_select=None, **sslctx_kwargs):
594595
context = self.create_ssl_context(
595596
cert,
596597
key,
597-
alpn_select=alpn_select,
598598
**sslctx_kwargs)
599599
self.connection = SSL.Connection(context, self.connection)
600600
self.connection.set_accept_state()
@@ -612,6 +612,12 @@ def handle(self): # pragma: no cover
612612
def settimeout(self, n):
613613
self.connection.settimeout(n)
614614

615+
def get_alpn_proto_negotiated(self):
616+
if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established:
617+
return self.connection.get_alpn_proto_negotiated()
618+
else:
619+
return None
620+
615621

616622
class TCPServer(object):
617623
request_queue_size = 20

0 commit comments

Comments
 (0)