Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dnsdist: Gather Server Name Indication on QUIC (DoQ, DoH3) connections #15024

Merged
merged 1 commit into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions pdns/dnsdistdist/doh3.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ class H3Connection
H3Connection& operator=(H3Connection&&) = default;
~H3Connection() = default;

std::shared_ptr<const std::string> getSNI()
{
if (!d_sni) {
d_sni = std::make_shared<const std::string>(getSNIFromQuicheConnection(d_conn));
}
return d_sni;
}

ComboAddress d_peer;
ComboAddress d_localAddr;
QuicheConnection d_conn;
Expand All @@ -71,6 +79,7 @@ class H3Connection
std::unordered_map<uint64_t, dnsdist::doh3::h3_headers_t> d_headersBuffers;
std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
std::unordered_map<uint64_t, PacketBuffer> d_streamOutBuffers;
std::shared_ptr<const std::string> d_sni{nullptr};
};

static void sendBackDOH3Unit(DOH3UnitUniquePtr&& unit, const char* description);
Expand Down Expand Up @@ -566,6 +575,9 @@ static void processDOH3Query(DOH3UnitUniquePtr&& doh3Unit)
ids.origFlags = *flags;
return true;
});
if (unit->sni) {
dnsQuestion.sni = *unit->sni;
}
unit->ids.cs = &clientState;

auto result = processQuery(dnsQuestion, downstream);
Expand Down Expand Up @@ -640,7 +652,7 @@ static void processDOH3Query(DOH3UnitUniquePtr&& doh3Unit)
}
}

static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID, dnsdist::doh3::h3_headers_t&& headers)
static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID, const std::shared_ptr<const std::string>& sni, dnsdist::doh3::h3_headers_t&& headers)
{
try {
auto unit = std::make_unique<DOH3Unit>(std::move(query));
Expand All @@ -650,6 +662,7 @@ static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, con
unit->ids.protocol = dnsdist::Protocol::DoH3;
unit->serverConnID = serverConnID;
unit->streamID = streamID;
unit->sni = sni;
unit->headers = std::move(headers);

processDOH3Query(std::move(unit));
Expand Down Expand Up @@ -751,7 +764,7 @@ static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& fronten
return;
}
DEBUGLOG("Dispatching GET query");
doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), conn.d_localAddr, client, serverConnID, streamID, std::move(headers));
doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), conn.d_localAddr, client, serverConnID, streamID, conn.getSNI(), std::move(headers));
conn.d_streamBuffers.erase(streamID);
conn.d_headersBuffers.erase(streamID);
return;
Expand Down Expand Up @@ -816,7 +829,7 @@ static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend,
}

DEBUGLOG("Dispatching POST query");
doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID, std::move(headers));
doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID, conn.getSNI(), std::move(headers));
conn.d_headersBuffers.erase(streamID);
conn.d_streamBuffers.erase(streamID);
}
Expand Down
1 change: 1 addition & 0 deletions pdns/dnsdistdist/doh3.hh
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ struct DOH3Unit
PacketBuffer serverConnID;
dnsdist::doh3::h3_headers_t headers;
std::shared_ptr<DownstreamState> downstream{nullptr};
std::shared_ptr<const std::string> sni{nullptr};
std::string d_contentTypeOut;
DOH3ServerConfig* dsc{nullptr};
uint64_t streamID{0};
Expand Down
14 changes: 13 additions & 1 deletion pdns/dnsdistdist/doq-common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,18 @@ bool recvAsync(Socket& socket, PacketBuffer& buffer, ComboAddress& clientAddr, C
return !buffer.empty();
}

};
std::string getSNIFromQuicheConnection(const QuicheConnection& conn)
{
#if defined(HAVE_QUICHE_CONN_SERVER_NAME)
const uint8_t* sniPtr = nullptr;
size_t sniPtrSize = 0;
quiche_conn_server_name(conn.get(), &sniPtr, &sniPtrSize);
if (sniPtrSize > 0) {
return std::string(reinterpret_cast<const char*>(sniPtr), sniPtrSize);
}
#endif /* HAVE_QUICHE_CONN_SERVER_NAME */
return {};
}
}

#endif
3 changes: 2 additions & 1 deletion pdns/dnsdistdist/doq-common.hh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include <map>
#include <memory>
#include <string>

#include "config.h"

Expand Down Expand Up @@ -97,7 +98,7 @@ void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, co
void flushEgress(Socket& sock, QuicheConnection& conn, const ComboAddress& peer, const ComboAddress& localAddr, PacketBuffer& buffer);
void configureQuiche(QuicheConfig& config, const QuicheParams& params, bool isHTTP);
bool recvAsync(Socket& socket, PacketBuffer& buffer, ComboAddress& clientAddr, ComboAddress& localAddr);

std::string getSNIFromQuicheConnection(const QuicheConnection& conn);
};

#endif
17 changes: 15 additions & 2 deletions pdns/dnsdistdist/doq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,22 @@ class Connection
Connection& operator=(Connection&&) = default;
~Connection() = default;

std::shared_ptr<const std::string> getSNI()
{
if (!d_sni) {
d_sni = std::make_shared<const std::string>(getSNIFromQuicheConnection(d_conn));
}
return d_sni;
}

ComboAddress d_peer;
ComboAddress d_localAddr;
QuicheConnection d_conn;
QuicheConfig d_config;

std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
std::unordered_map<uint64_t, PacketBuffer> d_streamOutBuffers;
std::shared_ptr<const std::string> d_sni{nullptr};
};

static void sendBackDOQUnit(DOQUnitUniquePtr&& unit, const char* description);
Expand Down Expand Up @@ -472,6 +481,9 @@ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit)
ids.origFlags = *flags;
return true;
});
if (unit->sni) {
dnsQuestion.sni = *unit->sni;
}
unit->ids.cs = &clientState;

auto result = processQuery(dnsQuestion, downstream);
Expand Down Expand Up @@ -541,7 +553,7 @@ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit)
}
}

static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID)
static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID, const std::shared_ptr<const std::string>& sni)
{
try {
auto unit = std::make_unique<DOQUnit>(std::move(query));
Expand All @@ -551,6 +563,7 @@ static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const
unit->ids.protocol = dnsdist::Protocol::DoQ;
unit->serverConnID = serverConnID;
unit->streamID = streamID;
unit->sni = sni;

processDOQQuery(std::move(unit));
}
Expand Down Expand Up @@ -649,7 +662,7 @@ static void handleReadableStream(DOQFrontend& frontend, ClientState& clientState
return;
}
DEBUGLOG("Dispatching query");
doq_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID);
doq_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), conn.d_localAddr, client, serverConnID, streamID, conn.getSNI());
conn.d_streamBuffers.erase(streamID);
}

Expand Down
1 change: 1 addition & 0 deletions pdns/dnsdistdist/doq.hh
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ struct DOQUnit
PacketBuffer response;
PacketBuffer serverConnID;
std::shared_ptr<DownstreamState> downstream{nullptr};
std::shared_ptr<const std::string> sni{nullptr};
DOQServerConfig* dsc{nullptr};
uint64_t streamID{0};
size_t proxyProtocolPayloadSize{0};
Expand Down
10 changes: 10 additions & 0 deletions pdns/dnsdistdist/m4/pdns_with_quiche.m4
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ AC_DEFUN([PDNS_WITH_QUICHE], [
AC_DEFINE([HAVE_QUICHE], [1], [Define to 1 if you have quiche])
], [ : ])
])
AS_IF([test "x$HAVE_QUICHE" = "x1"], [
save_CFLAGS=$CFLAGS
save_LIBS=$LIBS
CFLAGS="$QUICHE_CFLAGS $CFLAGS"
LIBS="$QUICHE_LIBS $LIBS"
AC_CHECK_FUNCS([quiche_conn_server_name])
CFLAGS=$save_CFLAGS
LIBS=$save_LIBS

])
])
])
AM_CONDITIONAL([HAVE_QUICHE], [test "x$QUICHE_LIBS" != "x"])
Expand Down
2 changes: 1 addition & 1 deletion regression-tests.dnsdist/doh3client.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ async def async_h3_query(


def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False, additional_headers=None, raw_response=False):
configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True)
configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True, server_name=server_hostname)
if verify:
configuration.load_verify_locations(verify)

Expand Down
4 changes: 2 additions & 2 deletions regression-tests.dnsdist/doqclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(self, error, message="Stream reset by peer"):
super().__init__(message)

def quic_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True)
configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True, server_name=server_hostname)
if verify:
configuration.load_verify_locations(verify)
(result, serial) = asyncio.run(
Expand All @@ -108,7 +108,7 @@ def quic_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server
return (result, serial)

def quic_bogus_query(query, host='127.0.0.1', timeout=2, port=853, verify=None, server_hostname=None):
configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True)
configuration = QuicConfiguration(alpn_protocols=["doq"], is_client=True, server_name=server_hostname)
if verify:
configuration.load_verify_locations(verify)
(result, _) = asyncio.run(
Expand Down
82 changes: 82 additions & 0 deletions regression-tests.dnsdist/test_SNI.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#!/usr/bin/env python
import base64
import dns
import os
import unittest
import pycurl

from dnsdisttests import DNSDistTest, pickAvailablePort

class TestSNI(DNSDistTest):
_serverKey = 'server.key'
_serverCert = 'server.chain'
_serverName = 'tls.tests.dnsdist.org'
_caCert = 'ca.pem'
_tlsServerPort = pickAvailablePort()
_dohWithNGHTTP2ServerPort = pickAvailablePort()
_doqServerPort = pickAvailablePort()
_doh3ServerPort = pickAvailablePort()
_dohWithNGHTTP2BaseURL = ("https://%s:%d/" % (_serverName, _dohWithNGHTTP2ServerPort))
_dohBaseURL = ("https://%s:%d/" % (_serverName, _doh3ServerPort))

_config_template = """
newServer{address="127.0.0.1:%d"}

addTLSLocal("127.0.0.1:%d", "%s", "%s", { provider="openssl" })
addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library="nghttp2"})
addDOQLocal("127.0.0.1:%d", "%s", "%s")
addDOH3Local("127.0.0.1:%d", "%s", "%s")

function displaySNI(dq)
local sni = dq:getServerNameIndication()
if sni ~= '%s' then
return DNSAction.Spoof, '1.2.3.4'
end
return DNSAction.Allow
end
addAction(AllRule(), LuaAction(displaySNI))
"""
_config_params = ['_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey', '_dohWithNGHTTP2ServerPort', '_serverCert', '_serverKey', '_doqServerPort', '_serverCert', '_serverKey', '_doh3ServerPort', '_serverCert', '_serverKey', '_serverName']

# enable these once Quiche > 0.22 is available, including https://github.com/cloudflare/quiche/pull/1895
@unittest.skipUnless('ENABLE_SNI_TESTS_WITH_QUICHE' in os.environ, "SNI tests with Quicheare disabled")
def testServerNameIndicationWithQuiche(self):
name = 'simple.sni.tests.powerdns.com.'
query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
response = dns.message.make_response(query)
rrset = dns.rrset.from_text(name,
3600,
dns.rdataclass.IN,
dns.rdatatype.A,
'127.0.0.1')
response.answer.append(rrset)
for method in ["sendDOQQueryWrapper", "sendDOH3QueryWrapper"]:
sender = getattr(self, method)
(receivedQuery, receivedResponse) = sender(query, response, timeout=1)
self.assertTrue(receivedQuery)
receivedQuery.id = query.id
self.assertEqual(query, receivedQuery)
self.assertTrue(receivedResponse)
if method == 'sendDOQQueryWrapper':
# dnspython sets the ID to 0
receivedResponse.id = response.id
self.assertEqual(response, receivedResponse)

def testServerNameIndication(self):
name = 'simple.sni.tests.powerdns.com.'
query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
response = dns.message.make_response(query)
rrset = dns.rrset.from_text(name,
3600,
dns.rdataclass.IN,
dns.rdatatype.A,
'127.0.0.1')
response.answer.append(rrset)
for method in ["sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper"]:
sender = getattr(self, method)
(receivedQuery, receivedResponse) = sender(query, response, timeout=1)
self.assertTrue(receivedQuery)
receivedQuery.id = query.id
self.assertEqual(query, receivedQuery)
self.assertTrue(receivedResponse)
self.assertEqual(response, receivedResponse)
Loading