diff --git a/pdns/dnsdistdist/doh3.cc b/pdns/dnsdistdist/doh3.cc index 640d0738fdfa..1a6150911071 100644 --- a/pdns/dnsdistdist/doh3.cc +++ b/pdns/dnsdistdist/doh3.cc @@ -298,7 +298,11 @@ static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const }, }; quiche_h3_send_response(conn, quic_conn, - streamID, headers, 2, false); + streamID, headers, 2, len == 0); + + if (len == 0) { + return; + } size_t pos = 0; while (pos < len) { @@ -330,7 +334,12 @@ static void handleResponse(DOH3Frontend& frontend, H3Connection& conn, const uin else { ++frontend.d_errorResponses; } - h3_send_response(conn, streamID, statusCode, &response.at(0), response.size()); + if (response.empty()) { + quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast(DOQ_Error_Codes::DOQ_UNSPECIFIED_ERROR)); + } + else { + h3_send_response(conn, streamID, statusCode, &response.at(0), response.size()); + } } static void fillRandom(PacketBuffer& buffer, size_t size) diff --git a/pdns/dnsdistdist/doq.cc b/pdns/dnsdistdist/doq.cc index e2fc597d132e..521314f6e507 100644 --- a/pdns/dnsdistdist/doq.cc +++ b/pdns/dnsdistdist/doq.cc @@ -271,17 +271,6 @@ class DOQCrossProtocolQuery : public CrossProtocolQuery std::shared_ptr DOQCrossProtocolQuery::s_sender = std::make_shared(); -/* from rfc9250 section-4.3 */ -enum class DOQ_Error_Codes : uint64_t -{ - DOQ_NO_ERROR = 0, - DOQ_INTERNAL_ERROR = 1, - DOQ_PROTOCOL_ERROR = 2, - DOQ_REQUEST_CANCELLED = 3, - DOQ_EXCESSIVE_LOAD = 4, - DOQ_UNSPECIFIED_ERROR = 5 -}; - static void handleResponse(DOQFrontend& frontend, Connection& conn, const uint64_t streamID, const PacketBuffer& response) { if (response.empty()) { diff --git a/pdns/dnsdistdist/doq.hh b/pdns/dnsdistdist/doq.hh index 64d080bfd113..efc50ef218f6 100644 --- a/pdns/dnsdistdist/doq.hh +++ b/pdns/dnsdistdist/doq.hh @@ -28,6 +28,7 @@ #include "iputils.hh" #include "libssl.hh" #include "noinitvector.hh" +#include "doq.hh" #include "stat_t.hh" #include "dnsdist-idstate.hh" @@ -36,6 +37,17 @@ struct DownstreamState; #ifdef HAVE_DNS_OVER_QUIC +/* from rfc9250 section-4.3 */ +enum class DOQ_Error_Codes : uint64_t +{ + DOQ_NO_ERROR = 0, + DOQ_INTERNAL_ERROR = 1, + DOQ_PROTOCOL_ERROR = 2, + DOQ_REQUEST_CANCELLED = 3, + DOQ_EXCESSIVE_LOAD = 4, + DOQ_UNSPECIFIED_ERROR = 5 +}; + struct DOQFrontend { DOQFrontend(); diff --git a/regression-tests.dnsdist/doh3client.py b/regression-tests.dnsdist/doh3client.py index eeebb4c3b6fb..c04b06b152a4 100644 --- a/regression-tests.dnsdist/doh3client.py +++ b/regression-tests.dnsdist/doh3client.py @@ -23,9 +23,11 @@ PushPromiseReceived, ) from aioquic.quic.configuration import QuicConfiguration -from aioquic.quic.events import QuicEvent +from aioquic.quic.events import QuicEvent, StreamDataReceived, StreamReset #from aioquic.quic.logger import QuicFileLogger from aioquic.tls import CipherSuite, SessionTicket + +from doqclient import StreamResetError # #class DnsClientProtocol(QuicConnectionProtocol): # def __init__(self, *args, **kwargs): @@ -155,6 +157,10 @@ def http_event_received(self, event: H3Event) -> None: self.pushes[event.push_id].append(event) def quic_event_received(self, event: QuicEvent) -> None: + if isinstance(event, StreamReset): + waiter = self._request_waiter.pop(event.stream_id) + waiter.set_result([event]) + #  pass event to the HTTP layer if self._http is not None: for http_event in self._http.handle_event(event): @@ -215,9 +221,11 @@ async def perform_http_request( for http_event in http_events: if isinstance(http_event, DataReceived): result += http_event.data + if isinstance(http_event, StreamReset): + result = http_event return result - - + + async def async_h3_query( configuration: QuicConfiguration, baseurl: str, @@ -228,7 +236,6 @@ async def async_h3_query( ) -> None: url = "{}?dns={}".format(baseurl, base64.urlsafe_b64encode(query.to_wire()).decode('UTF8').rstrip('=')) - print("Querying for {}".format(url)) async with connect( "127.0.0.1", port, @@ -237,7 +244,6 @@ async def async_h3_query( ) as client: client = cast(HttpClient, client) - print("Sending DNS query") try: async with async_timeout.timeout(timeout): @@ -253,11 +259,6 @@ async def async_h3_query( except asyncio.TimeoutError as e: return e -class StreamResetError(Exception): - def __init__(self, error, message="Stream reset by peer"): - self.error = error - super().__init__(message) - def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None): configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True) if verify: @@ -272,9 +273,9 @@ def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname create_protocol=HttpClient ) ) - # if (isinstance(result, StreamReset)): - # raise StreamResetError(result.error_code) + + if (isinstance(result, StreamReset)): + raise StreamResetError(result.error_code) if (isinstance(result, asyncio.TimeoutError)): raise TimeoutError() - return result - + return dns.message.from_wire(result) diff --git a/regression-tests.dnsdist/quictests.py b/regression-tests.dnsdist/quictests.py new file mode 100644 index 000000000000..dfca492bb6a1 --- /dev/null +++ b/regression-tests.dnsdist/quictests.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python + +import dns +from doqclient import StreamResetError + +class QUICTests(object): + + def testQUICSimple(self): + """ + QUIC: Simple query + """ + name = 'simple.doq.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN', use_edns=False) + query.id = 0 + expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096) + expectedQuery.id = 0 + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + (receivedQuery, receivedResponse) = self.sendQUICQuery(query, response=response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.assertEqual(expectedQuery, receivedQuery) + self.assertEqual(receivedResponse, response) + + def testQUICMultipleStreams(self): + """ + QUIC: Test multiple queries using the same connection + """ + name = 'simple.doq.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN', use_edns=False) + query.id = 0 + expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096) + expectedQuery.id = 0 + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + + connection = self.getQUICConnection() + + (receivedQuery, receivedResponse) = self.sendQUICQuery(query, response=response, connection=connection) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.assertEqual(expectedQuery, receivedQuery) + + (receivedQuery, receivedResponse) = self.sendQUICQuery(query, response=response, connection=connection) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.assertEqual(expectedQuery, receivedQuery) + + def testDropped(self): + """ + QUIC: Dropped query + """ + name = 'drop.doq.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + dropped = False + try: + (_, receivedResponse) = self.sendQUICQuery(query, response=None, useQueue=False) + self.assertTrue(False) + except StreamResetError as e: + self.assertEqual(e.error, 5); + + def testRefused(self): + """ + QUIC: Refused + """ + name = 'refused.doq.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + query.id = 0 + query.flags &= ~dns.flags.RD + expectedResponse = dns.message.make_response(query) + expectedResponse.set_rcode(dns.rcode.REFUSED) + + (_, receivedResponse) = self.sendQUICQuery(query, response=None, useQueue=False) + self.assertEqual(receivedResponse, expectedResponse) + + def testSpoof(self): + """ + QUIC: Spoofed + """ + name = 'spoof.doq.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + query.id = 0 + query.flags &= ~dns.flags.RD + expectedResponse = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '1.2.3.4') + expectedResponse.answer.append(rrset) + + (_, receivedResponse) = self.sendQUICQuery(query, response=None, useQueue=False) + self.assertEqual(receivedResponse, expectedResponse) + + def testQUICNoBackend(self): + """ + QUIC: No backend + """ + name = 'no-backend.doq.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN', use_edns=False) + dropped = False + try: + (_, receivedResponse) = self.sendQUICQuery(query, response=None, useQueue=False) + self.assertTrue(False) + except StreamResetError as e : + self.assertEqual(e.error, 5); + +class QUICWithCacheTests(object): + def testCached(self): + """ + QUIC Cache: Served from cache + """ + numberOfQueries = 10 + name = 'cached.quic.tests.powerdns.com.' + query = dns.message.make_query(name, 'AAAA', 'IN') + query.id = 0 + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.AAAA, + '::1') + response.answer.append(rrset) + + # first query to fill the cache + (receivedQuery, receivedResponse) = self.sendQUICQuery(query, response=response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(receivedResponse, response) + + for _ in range(numberOfQueries): + (_, receivedResponse) = self.sendQUICQuery(query, response=None, useQueue=False) + self.assertEqual(receivedResponse, response) + + total = 0 + for key in self._responsesCounter: + total += self._responsesCounter[key] + + self.assertEqual(total, 1) diff --git a/regression-tests.dnsdist/test_DOH3.py b/regression-tests.dnsdist/test_DOH3.py index 74e4bb15a0d7..dcff35e09690 100644 --- a/regression-tests.dnsdist/test_DOH3.py +++ b/regression-tests.dnsdist/test_DOH3.py @@ -4,10 +4,10 @@ from dnsdisttests import DNSDistTest from dnsdisttests import pickAvailablePort - +from quictests import QUICTests, QUICWithCacheTests import doh3client -class TestDOH3(DNSDistTest): +class TestDOH3(QUICTests, DNSDistTest): _serverKey = 'server.key' _serverCert = 'server.chain' _serverName = 'tls.tests.dnsdist.org' @@ -22,29 +22,13 @@ class TestDOH3(DNSDistTest): addAction("spoof.doq.tests.powerdns.com.", SpoofAction("1.2.3.4")) addAction("no-backend.doq.tests.powerdns.com.", PoolAction('this-pool-has-no-backend')) - addDOH3Local("127.0.0.1:%d", "%s", "%s") + addDOH3Local("127.0.0.1:%d", "%s", "%s", {keyLogFile='/tmp/keys'}) """ _config_params = ['_testServerPort', '_doqServerPort','_serverCert', '_serverKey'] _verboseMode = True - def testDOH3Simple(self): - """ - DOH3: Simple query - """ - name = 'simple.doq.tests.powerdns.com.' - query = dns.message.make_query(name, 'A', 'IN', use_edns=False) - query.id = 0 - expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096) - expectedQuery.id = 0 - response = dns.message.make_response(query) - rrset = dns.rrset.from_text(name, - 3600, - dns.rdataclass.IN, - dns.rdatatype.A, - '127.0.0.1') - response.answer.append(rrset) - (receivedQuery, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, serverName=self._serverName) - self.assertTrue(receivedQuery) - self.assertTrue(receivedResponse) - receivedQuery.id = expectedQuery.id - self.assertEqual(expectedQuery, receivedQuery) + def getQUICConnection(self): + return self.getDOQConnection(self._doqServerPort, self._caCert) + + def sendQUICQuery(self, query, response=None, useQueue=True, connection=None): + return self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName, connection=connection) diff --git a/regression-tests.dnsdist/test_DOQ.py b/regression-tests.dnsdist/test_DOQ.py index 9a87b62255fc..965e9d6dd32a 100644 --- a/regression-tests.dnsdist/test_DOQ.py +++ b/regression-tests.dnsdist/test_DOQ.py @@ -5,6 +5,7 @@ from dnsdisttests import DNSDistTest from dnsdisttests import pickAvailablePort from doqclient import quic_bogus_query +from quictests import QUICTests, QUICWithCacheTests import doqclient class TestDOQBogus(DNSDistTest): @@ -37,7 +38,7 @@ def testDOQBogus(self): except doqclient.StreamResetError as e : self.assertEqual(e.error, 2); -class TestDOQ(DNSDistTest): +class TestDOQ(QUICTests, DNSDistTest): _serverKey = 'server.key' _serverCert = 'server.chain' _serverName = 'tls.tests.dnsdist.org' @@ -56,120 +57,13 @@ class TestDOQ(DNSDistTest): _config_params = ['_testServerPort', '_doqServerPort','_serverCert', '_serverKey'] _verboseMode = True - def testDOQSimple(self): - """ - DOQ: Simple query - """ - name = 'simple.doq.tests.powerdns.com.' - query = dns.message.make_query(name, 'A', 'IN', use_edns=False) - query.id = 0 - expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096) - expectedQuery.id = 0 - response = dns.message.make_response(query) - rrset = dns.rrset.from_text(name, - 3600, - dns.rdataclass.IN, - dns.rdatatype.A, - '127.0.0.1') - response.answer.append(rrset) - (receivedQuery, receivedResponse) = self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, serverName=self._serverName) - self.assertTrue(receivedQuery) - self.assertTrue(receivedResponse) - receivedQuery.id = expectedQuery.id - self.assertEqual(expectedQuery, receivedQuery) - - def testDOQMultipleStreams(self): - """ - DOQ: Test multiple queries using the same connection - """ - - name = 'simple.doq.tests.powerdns.com.' - query = dns.message.make_query(name, 'A', 'IN', use_edns=False) - query.id = 0 - expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096) - expectedQuery.id = 0 - response = dns.message.make_response(query) - rrset = dns.rrset.from_text(name, - 3600, - dns.rdataclass.IN, - dns.rdatatype.A, - '127.0.0.1') - response.answer.append(rrset) - - connection = self.getDOQConnection(self._doqServerPort, self._caCert) - - (receivedQuery, receivedResponse) = self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, serverName=self._serverName, connection=connection) - self.assertTrue(receivedQuery) - self.assertTrue(receivedResponse) - receivedQuery.id = expectedQuery.id - self.assertEqual(expectedQuery, receivedQuery) - - (receivedQuery, receivedResponse) = self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, serverName=self._serverName, connection=connection) - self.assertTrue(receivedQuery) - self.assertTrue(receivedResponse) - receivedQuery.id = expectedQuery.id - self.assertEqual(expectedQuery, receivedQuery) - - def testDropped(self): - """ - DOQ: Dropped query - """ - name = 'drop.doq.tests.powerdns.com.' - query = dns.message.make_query(name, 'A', 'IN') - dropped = False - try: - (_, receivedResponse) = self.sendDOQQuery(self._doqServerPort, query, response=None, caFile=self._caCert, useQueue=False, serverName=self._serverName) - self.assertTrue(False) - except doqclient.StreamResetError as e : - self.assertEqual(e.error, 5); + def getQUICConnection(self): + return self.getDOQConnection(self._doqServerPort, self._caCert) - def testRefused(self): - """ - DOQ: Refused - """ - name = 'refused.doq.tests.powerdns.com.' - query = dns.message.make_query(name, 'A', 'IN') - query.id = 0 - query.flags &= ~dns.flags.RD - expectedResponse = dns.message.make_response(query) - expectedResponse.set_rcode(dns.rcode.REFUSED) - - (_, receivedResponse) = self.sendDOQQuery(self._doqServerPort, query, response=None, caFile=self._caCert, useQueue=False, serverName=self._serverName) - self.assertEqual(receivedResponse, expectedResponse) - - def testSpoof(self): - """ - DOQ: Spoofed - """ - name = 'spoof.doq.tests.powerdns.com.' - query = dns.message.make_query(name, 'A', 'IN') - query.id = 0 - query.flags &= ~dns.flags.RD - expectedResponse = dns.message.make_response(query) - rrset = dns.rrset.from_text(name, - 3600, - dns.rdataclass.IN, - dns.rdatatype.A, - '1.2.3.4') - expectedResponse.answer.append(rrset) - - (_, receivedResponse) = self.sendDOQQuery(self._doqServerPort, query, response=None, caFile=self._caCert, useQueue=False, serverName=self._serverName) - self.assertEqual(receivedResponse, expectedResponse) - - def testDOQNoBackend(self): - """ - DOQ: No backend - """ - name = 'no-backend.doq.tests.powerdns.com.' - query = dns.message.make_query(name, 'A', 'IN', use_edns=False) - dropped = False - try: - (_, receivedResponse) = self.sendDOQQuery(self._doqServerPort, query, response=None, caFile=self._caCert, useQueue=False, serverName=self._serverName) - self.assertTrue(False) - except doqclient.StreamResetError as e : - self.assertEqual(e.error, 5); + def sendQUICQuery(self, query, response=None, useQueue=True, connection=None): + return self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName, connection=connection) -class TestDOQWithCache(DNSDistTest): +class TestDOQWithCache(QUICWithCacheTests, DNSDistTest): _serverKey = 'server.key' _serverCert = 'server.chain' _serverName = 'tls.tests.dnsdist.org' @@ -186,41 +80,8 @@ class TestDOQWithCache(DNSDistTest): _config_params = ['_testServerPort', '_doqServerPort','_serverCert', '_serverKey'] _verboseMode = True - def testCached(self): - """ - Cache: Served from cache - - dnsdist is configured to cache entries, we are sending several - identical requests and checking that the backend only receive - the first one. - """ - numberOfQueries = 10 - name = 'cached.cache.tests.powerdns.com.' - query = dns.message.make_query(name, 'AAAA', 'IN') - query.id = 0 - response = dns.message.make_response(query) - rrset = dns.rrset.from_text(name, - 3600, - dns.rdataclass.IN, - dns.rdatatype.AAAA, - '::1') - response.answer.append(rrset) - - # first query to fill the cache - (receivedQuery, receivedResponse) = self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, serverName=self._serverName) - self.assertTrue(receivedQuery) - self.assertTrue(receivedResponse) - receivedQuery.id = query.id - self.assertEqual(query, receivedQuery) - self.assertEqual(receivedResponse, response) - - for _ in range(numberOfQueries): - (_, receivedResponse) = self.sendDOQQuery(self._doqServerPort, query, response=None, caFile=self._caCert, useQueue=False, serverName=self._serverName) - self.assertEqual(receivedResponse, response) - - total = 0 - for key in self._responsesCounter: - total += self._responsesCounter[key] - TestDOQWithCache._responsesCounter[key] = 0 + def getQUICConnection(self): + return self.getDOQConnection(self._doqServerPort, self._caCert) - self.assertEqual(total, 1) + def sendQUICQuery(self, query, response=None, useQueue=True, connection=None): + return self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName, connection=connection)