Skip to content
This repository was archived by the owner on Jun 1, 2018. It is now read-only.

HTTP/2: preparations for pathod #67

Merged
merged 7 commits into from
Jun 14, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 75 additions & 22 deletions netlib/http2/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,55 +26,80 @@ class HTTP2Protocol(object):
)

# "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"
CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'
CLIENT_CONNECTION_PREFACE =\
'505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex')

ALPN_PROTO_H2 = 'h2'

def __init__(self, tcp_client):
self.tcp_client = tcp_client
def __init__(self, tcp_handler, is_server=False):
self.tcp_handler = tcp_handler
self.is_server = is_server

self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy()
self.current_stream_id = None
self.encoder = Encoder()
self.decoder = Decoder()
self.connection_preface_performed = False

def check_alpn(self):
alp = self.tcp_client.get_alpn_proto_negotiated()
alp = self.tcp_handler.get_alpn_proto_negotiated()
if alp != self.ALPN_PROTO_H2:
raise NotImplementedError(
"HTTP2Protocol can not handle unknown ALP: %s" % alp)
return True

def perform_connection_preface(self):
self.tcp_client.wfile.write(
bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex')))
self.send_frame(frame.SettingsFrame(state=self))

# read server settings frame
frm = frame.Frame.from_file(self.tcp_client.rfile, self)
def _receive_settings(self):
frm = frame.Frame.from_file(self.tcp_handler.rfile, self)
assert isinstance(frm, frame.SettingsFrame)
self._apply_settings(frm.settings)

# read setting ACK frame
def _read_settings_ack(self):
settings_ack_frame = self.read_frame()
assert isinstance(settings_ack_frame, frame.SettingsFrame)
assert settings_ack_frame.flags & frame.Frame.FLAG_ACK
assert len(settings_ack_frame.settings) == 0

def perform_server_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
self.connection_preface_performed = True

magic_length = len(self.CLIENT_CONNECTION_PREFACE)
magic = self.tcp_handler.rfile.safe_read(magic_length)
assert magic == self.CLIENT_CONNECTION_PREFACE

self.send_frame(frame.SettingsFrame(state=self))
self._receive_settings()
self._read_settings_ack()

def perform_client_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
self.connection_preface_performed = True

self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE)

self.send_frame(frame.SettingsFrame(state=self))
self._receive_settings()
self._read_settings_ack()

def next_stream_id(self):
if self.current_stream_id is None:
self.current_stream_id = 1
if self.is_server:
# servers must use even stream ids
self.current_stream_id = 2
else:
# clients must use odd stream ids
self.current_stream_id = 1
else:
self.current_stream_id += 2
return self.current_stream_id

def send_frame(self, frame):
raw_bytes = frame.to_bytes()
self.tcp_client.wfile.write(raw_bytes)
self.tcp_client.wfile.flush()
self.tcp_handler.wfile.write(raw_bytes)
self.tcp_handler.wfile.flush()

def read_frame(self):
frm = frame.Frame.from_file(self.tcp_client.rfile, self)
frm = frame.Frame.from_file(self.tcp_handler.rfile, self)
if isinstance(frm, frame.SettingsFrame):
self._apply_settings(frm.settings)

Expand Down Expand Up @@ -127,10 +152,13 @@ def create_request(self, method, path, headers=None, body=None):
if headers is None:
headers = []

authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
headers = [
(b':method', bytes(method)),
(b':path', bytes(path)),
(b':scheme', b'https')] + headers
(b':scheme', b'https'),
(b':authority', authority),
] + headers

stream_id = self.next_stream_id()

Expand All @@ -139,25 +167,50 @@ def create_request(self, method, path, headers=None, body=None):
self._create_body(body, stream_id)))

def read_response(self):
headers, body = self._receive_transmission()
return headers[':status'], headers, body

def read_request(self):
return self._receive_transmission()

def _receive_transmission(self):
body_expected = True

header_block_fragment = b''
body = b''

while True:
frm = self.read_frame()
if isinstance(frm, frame.HeadersFrame):
if isinstance(frm, frame.HeadersFrame)\
or isinstance(frm, frame.ContinuationFrame):
header_block_fragment += frm.header_block_fragment
if frm.flags | frame.Frame.FLAG_END_HEADERS:
if frm.flags & frame.Frame.FLAG_END_HEADERS:
if frm.flags & frame.Frame.FLAG_END_STREAM:
body_expected = False
break

while True:
while body_expected:
frm = self.read_frame()
if isinstance(frm, frame.DataFrame):
body += frm.payload
if frm.flags | frame.Frame.FLAG_END_STREAM:
if frm.flags & frame.Frame.FLAG_END_STREAM:
break
# TODO: implement window update & flow

headers = {}
for header, value in self.decoder.decode(header_block_fragment):
headers[header] = value

return headers[':status'], headers, body
return headers, body

def create_response(self, code, headers=None, body=None):
if headers is None:
headers = []

headers = [(b':status', bytes(str(code)))] + headers

stream_id = self.next_stream_id()

return list(itertools.chain(
self._create_headers(headers, stream_id, end_stream=(body is None)),
self._create_body(body, stream_id)))
40 changes: 23 additions & 17 deletions netlib/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
SSLv3_METHOD = SSL.SSLv3_METHOD
SSLv23_METHOD = SSL.SSLv23_METHOD
TLSv1_METHOD = SSL.TLSv1_METHOD
TLSv1_1_METHOD = SSL.TLSv1_1_METHOD
TLSv1_2_METHOD = SSL.TLSv1_2_METHOD

OP_NO_SSLv2 = SSL.OP_NO_SSLv2
OP_NO_SSLv3 = SSL.OP_NO_SSLv3

Expand Down Expand Up @@ -376,7 +379,7 @@ def _create_ssl_context(self,
alpn_select=None,
):
"""
:param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD
:param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, TLSv1_1_METHOD, or TLSv1_2_METHOD
:param options: A bit field consisting of OpenSSL.SSL.OP_* values
:param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html
:rtype : SSL.Context
Expand Down Expand Up @@ -404,16 +407,17 @@ def _create_ssl_context(self,
context.set_info_callback(log_ssl_key)

if OpenSSL._util.lib.Cryptography_HAS_ALPN:
# advertise application layer protocols
if alpn_protos is not None:
# advertise application layer protocols
context.set_alpn_protos(alpn_protos)

# select application layer protocol
if alpn_select is not None:
def alpn_select_f(conn, options):
return bytes(alpn_select)

context.set_alpn_select_callback(alpn_select_f)
elif alpn_select is not None:
# select application layer protocol
def alpn_select_callback(conn, options):
if alpn_select in options:
return bytes(alpn_select)
else: # pragma no cover
return options[0]
context.set_alpn_select_callback(alpn_select_callback)

return context

Expand Down Expand Up @@ -499,9 +503,9 @@ def gettimeout(self):
return self.connection.gettimeout()

def get_alpn_proto_negotiated(self):
if OpenSSL._util.lib.Cryptography_HAS_ALPN:
if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established:
return self.connection.get_alpn_proto_negotiated()
else: # pragma no cover
else:
return None


Expand Down Expand Up @@ -531,7 +535,6 @@ def create_ssl_context(self,
request_client_cert=None,
chain_file=None,
dhparams=None,
alpn_select=None,
**sslctx_kwargs):
"""
cert: A certutils.SSLCert object.
Expand All @@ -558,9 +561,7 @@ def create_ssl_context(self,
until then we're conservative.
"""

context = self._create_ssl_context(
alpn_select=alpn_select,
**sslctx_kwargs)
context = self._create_ssl_context(**sslctx_kwargs)

context.use_privatekey(key)
context.use_certificate(cert.x509)
Expand All @@ -585,7 +586,7 @@ def save_cert(conn, cert, errno, depth, preverify_ok):

return context

def convert_to_ssl(self, cert, key, alpn_select=None, **sslctx_kwargs):
def convert_to_ssl(self, cert, key, **sslctx_kwargs):
"""
Convert connection to SSL.
For a list of parameters, see BaseHandler._create_ssl_context(...)
Expand All @@ -594,7 +595,6 @@ def convert_to_ssl(self, cert, key, alpn_select=None, **sslctx_kwargs):
context = self.create_ssl_context(
cert,
key,
alpn_select=alpn_select,
**sslctx_kwargs)
self.connection = SSL.Connection(context, self.connection)
self.connection.set_accept_state()
Expand All @@ -612,6 +612,12 @@ def handle(self): # pragma: no cover
def settimeout(self, n):
self.connection.settimeout(n)

def get_alpn_proto_negotiated(self):
if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established:
return self.connection.get_alpn_proto_negotiated()
else:
return None


class TCPServer(object):
request_queue_size = 20
Expand Down
Loading