Skip to content

Commit

Permalink
Add QUIC TLS session ticket support.
Browse files Browse the repository at this point in the history
  • Loading branch information
rthalley committed Oct 28, 2023
1 parent 277ee25 commit 7bd3df6
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 8 deletions.
8 changes: 6 additions & 2 deletions dns/quic/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,12 @@ class AsyncioQuicManager(AsyncQuicManager):
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name)

def connect(self, address, port=853, source=None, source_port=0):
(connection, start) = self._connect(address, port, source, source_port)
def connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
(connection, start) = self._connect(
address, port, source, source_port, want_session_ticket
)
if start:
connection.run()
return connection
Expand Down
44 changes: 42 additions & 2 deletions dns/quic/_common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

import copy
import functools
import socket
import struct
import time
Expand All @@ -11,6 +13,10 @@
import dns.inet

QUIC_MAX_DATAGRAM = 2048
MAX_SESSION_TICKETS = 8
# If we hit the max sessions limit we will delete this many of the oldest connections.
# The value must be a integer > 0 and <= MAX_SESSION_TICKETS.
SESSIONS_TO_DELETE = MAX_SESSION_TICKETS // 4


class UnexpectedEOF(Exception):
Expand Down Expand Up @@ -145,6 +151,7 @@ class BaseQuicManager:
def __init__(self, conf, verify_mode, connection_factory, server_name=None):
self._connections = {}
self._connection_factory = connection_factory
self._session_tickets = {}
if conf is None:
verify_path = None
if isinstance(verify_mode, str):
Expand All @@ -159,11 +166,33 @@ def __init__(self, conf, verify_mode, connection_factory, server_name=None):
conf.load_verify_locations(verify_path)
self._conf = conf

def _connect(self, address, port=853, source=None, source_port=0):
def _connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
connection = self._connections.get((address, port))
if connection is not None:
return (connection, False)
qconn = aioquic.quic.connection.QuicConnection(configuration=self._conf)
conf = self._conf
if want_session_ticket:
try:
session_ticket = self._session_tickets.pop((address, port))
# We found a session ticket, so make a configuration that uses it.
conf = copy.copy(conf)
conf.session_ticket = session_ticket
except KeyError:
# No session ticket.
pass
# Whether or not we found a session ticket, we want a handler to save
# one.
session_ticket_handler = functools.partial(
self.save_session_ticket, address, port
)
else:
session_ticket_handler = None
qconn = aioquic.quic.connection.QuicConnection(
configuration=conf,
session_ticket_handler=session_ticket_handler,
)
lladdress = dns.inet.low_level_address_tuple((address, port))
qconn.connect(lladdress, time.time())
connection = self._connection_factory(
Expand All @@ -178,6 +207,17 @@ def closed(self, address, port):
except KeyError:
pass

def save_session_ticket(self, address, port, ticket):
# We rely on dictionaries keys() being in insertion order here. We
# can't just popitem() as that would be LIFO which is the opposite of
# what we want.
l = len(self._session_tickets)
if l >= MAX_SESSION_TICKETS:
keys_to_delete = list(self._session_tickets.keys())[0:SESSIONS_TO_DELETE]
for key in keys_to_delete:
del self._session_tickets[key]
self._session_tickets[(address, port)] = ticket


class AsyncQuicManager(BaseQuicManager):
def connect(self, address, port=853, source=None, source_port=0):
Expand Down
12 changes: 10 additions & 2 deletions dns/quic/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,13 @@ def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
super().__init__(conf, verify_mode, SyncQuicConnection, server_name)
self._lock = threading.Lock()

def connect(self, address, port=853, source=None, source_port=0):
def connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
with self._lock:
(connection, start) = self._connect(address, port, source, source_port)
(connection, start) = self._connect(
address, port, source, source_port, want_session_ticket
)
if start:
connection.run()
return connection
Expand All @@ -218,6 +222,10 @@ def closed(self, address, port):
with self._lock:
super().closed(address, port)

def save_session_ticket(self, address, port, ticket):
with self._lock:
super().save_session_ticket(address, port, ticket)

def __enter__(self):
return self

Expand Down
8 changes: 6 additions & 2 deletions dns/quic/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,12 @@ def __init__(
super().__init__(conf, verify_mode, TrioQuicConnection, server_name)
self._nursery = nursery

def connect(self, address, port=853, source=None, source_port=0):
(connection, start) = self._connect(address, port, source, source_port)
def connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
(connection, start) = self._connect(
address, port, source, source_port, want_session_ticket
)
if start:
self._nursery.start_soon(connection.run)
return connection
Expand Down

0 comments on commit 7bd3df6

Please sign in to comment.