Skip to content

Commit

Permalink
Merge pull request PowerDNS#11604 from rgacogne/ddist-fix-proxyprotoc…
Browse files Browse the repository at this point in the history
…ol-tc-doh

dnsdist: Fix invalid proxy protocol payload on a DoH TC to TCP retry
  • Loading branch information
rgacogne authored May 16, 2022
2 parents 280a738 + 1c9c001 commit 86ec2ab
Show file tree
Hide file tree
Showing 7 changed files with 227 additions and 150 deletions.
1 change: 1 addition & 0 deletions pdns/dnsdist-tcp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1204,6 +1204,7 @@ static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& par
auto downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::string());

prependSizeToTCPQuery(query.d_buffer, proxyProtocolPayloadSize);
query.d_proxyProtocolPayloadAddedSize = proxyProtocolPayloadSize;
downstream->queueQuery(tqs, std::move(query));
}
catch (...) {
Expand Down
9 changes: 5 additions & 4 deletions pdns/dnsdistdist/dnsdist-tcp-downstream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,19 +169,20 @@ static void prepareQueryForSending(TCPQuery& query, uint16_t id, QueryState quer
if (query.d_proxyProtocolPayload.size() > 0 && !query.d_proxyProtocolPayloadAdded) {
query.d_buffer.insert(query.d_buffer.begin(), query.d_proxyProtocolPayload.begin(), query.d_proxyProtocolPayload.end());
query.d_proxyProtocolPayloadAdded = true;
query.d_proxyProtocolPayloadAddedSize = query.d_proxyProtocolPayload.size();
}
}
else if (connectionState == ConnectionState::proxySent) {
if (query.d_proxyProtocolPayloadAdded) {
if (query.d_buffer.size() < query.d_proxyProtocolPayload.size()) {
if (query.d_buffer.size() < query.d_proxyProtocolPayloadAddedSize) {
throw std::runtime_error("Trying to remove a proxy protocol payload of size " + std::to_string(query.d_proxyProtocolPayload.size()) + " from a buffer of size " + std::to_string(query.d_buffer.size()));
}
query.d_buffer.erase(query.d_buffer.begin(), query.d_buffer.begin() + query.d_proxyProtocolPayload.size());
query.d_buffer.erase(query.d_buffer.begin(), query.d_buffer.begin() + query.d_proxyProtocolPayloadAddedSize);
query.d_proxyProtocolPayloadAdded = false;
query.d_proxyProtocolPayloadAddedSize = 0;
}
}

editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayload.size() : 0, true);
editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayloadAddedSize : 0, true);
}

IOState TCPConnectionToBackend::queueNextQuery(std::shared_ptr<TCPConnectionToBackend>& conn)
Expand Down
19 changes: 3 additions & 16 deletions pdns/dnsdistdist/dnsdist-tcp.hh
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,8 @@ struct InternalQuery
{
}

InternalQuery(InternalQuery&& rhs) :
d_idstate(std::move(rhs.d_idstate)), d_proxyProtocolPayload(std::move(rhs.d_proxyProtocolPayload)), d_buffer(std::move(rhs.d_buffer)), d_xfrMasterSerial(rhs.d_xfrMasterSerial), d_xfrSerialCount(rhs.d_xfrSerialCount), d_downstreamFailures(rhs.d_downstreamFailures), d_xfrMasterSerialCount(rhs.d_xfrMasterSerialCount), d_proxyProtocolPayloadAdded(rhs.d_proxyProtocolPayloadAdded)
{
}
InternalQuery& operator=(InternalQuery&& rhs)
{
d_idstate = std::move(rhs.d_idstate);
d_buffer = std::move(rhs.d_buffer);
d_proxyProtocolPayload = std::move(rhs.d_proxyProtocolPayload);
d_xfrMasterSerial = rhs.d_xfrMasterSerial;
d_xfrSerialCount = rhs.d_xfrSerialCount;
d_downstreamFailures = rhs.d_downstreamFailures;
d_xfrMasterSerialCount = rhs.d_xfrMasterSerialCount;
d_proxyProtocolPayloadAdded = rhs.d_proxyProtocolPayloadAdded;
return *this;
}
InternalQuery(InternalQuery&& rhs) = default;
InternalQuery& operator=(InternalQuery&& rhs) = default;

InternalQuery(const InternalQuery& rhs) = delete;
InternalQuery& operator=(const InternalQuery& rhs) = delete;
Expand All @@ -111,6 +97,7 @@ struct InternalQuery
IDState d_idstate;
std::string d_proxyProtocolPayload;
PacketBuffer d_buffer;
uint32_t d_proxyProtocolPayloadAddedSize{0};
uint32_t d_xfrMasterSerial{0};
uint32_t d_xfrSerialCount{0};
uint32_t d_downstreamFailures{0};
Expand Down
2 changes: 1 addition & 1 deletion pdns/dnsdistdist/doh.cc
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ static void processDOHQuery(DOHUnitUniquePtr&& du)

if (du->downstream->d_config.useProxyProtocol) {
size_t payloadSize = 0;
if (addProxyProtocol(dq)) {
if (addProxyProtocol(dq, &payloadSize)) {
du->proxyProtocolPayloadSize = payloadSize;
}
}
Expand Down
145 changes: 145 additions & 0 deletions regression-tests.dnsdist/dnsdistdohtests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#!/usr/bin/env python
import base64
import dns
import os
import unittest

from dnsdisttests import DNSDistTest

import pycurl
from io import BytesIO

@unittest.skipIf('SKIP_DOH_TESTS' in os.environ, 'DNS over HTTPS tests are disabled')
class DNSDistDOHTest(DNSDistTest):

@classmethod
def getDOHGetURL(cls, baseurl, query, rawQuery=False):
if rawQuery:
wire = query
else:
wire = query.to_wire()
param = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
return baseurl + "?dns=" + param

@classmethod
def openDOHConnection(cls, port, caFile, timeout=2.0):
conn = pycurl.Curl()
conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)

conn.setopt(pycurl.HTTPHEADER, ["Content-type: application/dns-message",
"Accept: application/dns-message"])
return conn

@classmethod
def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None):
url = cls.getDOHGetURL(baseurl, query, rawQuery)
conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
response_headers = BytesIO()
#conn.setopt(pycurl.VERBOSE, True)
conn.setopt(pycurl.URL, url)
conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
if useHTTPS:
conn.setopt(pycurl.SSL_VERIFYPEER, 1)
conn.setopt(pycurl.SSL_VERIFYHOST, 2)
if caFile:
conn.setopt(pycurl.CAINFO, caFile)

conn.setopt(pycurl.HTTPHEADER, customHeaders)
conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)

if response:
if toQueue:
toQueue.put(response, True, timeout)
else:
cls._toResponderQueue.put(response, True, timeout)

receivedQuery = None
message = None
cls._response_headers = ''
data = conn.perform_rb()
cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
if cls._rcode == 200 and not rawResponse:
message = dns.message.from_wire(data)
elif rawResponse:
message = data

if useQueue:
if fromQueue:
if not fromQueue.empty():
receivedQuery = fromQueue.get(True, timeout)
else:
if not cls._fromResponderQueue.empty():
receivedQuery = cls._fromResponderQueue.get(True, timeout)

cls._response_headers = response_headers.getvalue()
return (receivedQuery, message)

@classmethod
def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True):
url = baseurl
conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
response_headers = BytesIO()
#conn.setopt(pycurl.VERBOSE, True)
conn.setopt(pycurl.URL, url)
conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
if useHTTPS:
conn.setopt(pycurl.SSL_VERIFYPEER, 1)
conn.setopt(pycurl.SSL_VERIFYHOST, 2)
if caFile:
conn.setopt(pycurl.CAINFO, caFile)

conn.setopt(pycurl.HTTPHEADER, customHeaders)
conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
conn.setopt(pycurl.POST, True)
data = query
if not rawQuery:
data = data.to_wire()

conn.setopt(pycurl.POSTFIELDS, data)

if response:
cls._toResponderQueue.put(response, True, timeout)

receivedQuery = None
message = None
cls._response_headers = ''
data = conn.perform_rb()
cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
if cls._rcode == 200 and not rawResponse:
message = dns.message.from_wire(data)
elif rawResponse:
message = data

if useQueue and not cls._fromResponderQueue.empty():
receivedQuery = cls._fromResponderQueue.get(True, timeout)

cls._response_headers = response_headers.getvalue()
return (receivedQuery, message)

def getHeaderValue(self, name):
for header in self._response_headers.decode().splitlines(False):
values = header.split(':')
key = values[0]
if key.lower() == name.lower():
return values[1].strip()
return None

def checkHasHeader(self, name, value):
got = self.getHeaderValue(name)
self.assertEqual(got, value)

def checkNoHeader(self, name):
self.checkHasHeader(name, None)

@classmethod
def setUpClass(cls):

# for some reason, @unittest.skipIf() is not applied to derived classes with some versions of Python
if 'SKIP_DOH_TESTS' in os.environ:
raise unittest.SkipTest('DNS over HTTPS tests are disabled')

cls.startResponders()
cls.startDNSDist()
cls.setUpSockets()

print("Launching tests..")
131 changes: 2 additions & 129 deletions regression-tests.dnsdist/test_DOH.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,142 +2,15 @@
import base64
import dns
import os
import re
import time
import unittest
import clientsubnetoption
from dnsdisttests import DNSDistTest

from dnsdistdohtests import DNSDistDOHTest

import pycurl
from io import BytesIO

@unittest.skipIf('SKIP_DOH_TESTS' in os.environ, 'DNS over HTTPS tests are disabled')
class DNSDistDOHTest(DNSDistTest):

@classmethod
def getDOHGetURL(cls, baseurl, query, rawQuery=False):
if rawQuery:
wire = query
else:
wire = query.to_wire()
param = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
return baseurl + "?dns=" + param

@classmethod
def openDOHConnection(cls, port, caFile, timeout=2.0):
conn = pycurl.Curl()
conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)

conn.setopt(pycurl.HTTPHEADER, ["Content-type: application/dns-message",
"Accept: application/dns-message"])
return conn

@classmethod
def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True):
url = cls.getDOHGetURL(baseurl, query, rawQuery)
conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
response_headers = BytesIO()
#conn.setopt(pycurl.VERBOSE, True)
conn.setopt(pycurl.URL, url)
conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
if useHTTPS:
conn.setopt(pycurl.SSL_VERIFYPEER, 1)
conn.setopt(pycurl.SSL_VERIFYHOST, 2)
if caFile:
conn.setopt(pycurl.CAINFO, caFile)

conn.setopt(pycurl.HTTPHEADER, customHeaders)
conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)

if response:
cls._toResponderQueue.put(response, True, timeout)

receivedQuery = None
message = None
cls._response_headers = ''
data = conn.perform_rb()
cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
if cls._rcode == 200 and not rawResponse:
message = dns.message.from_wire(data)
elif rawResponse:
message = data

if useQueue and not cls._fromResponderQueue.empty():
receivedQuery = cls._fromResponderQueue.get(True, timeout)

cls._response_headers = response_headers.getvalue()
return (receivedQuery, message)

@classmethod
def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True):
url = baseurl
conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
response_headers = BytesIO()
#conn.setopt(pycurl.VERBOSE, True)
conn.setopt(pycurl.URL, url)
conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
if useHTTPS:
conn.setopt(pycurl.SSL_VERIFYPEER, 1)
conn.setopt(pycurl.SSL_VERIFYHOST, 2)
if caFile:
conn.setopt(pycurl.CAINFO, caFile)

conn.setopt(pycurl.HTTPHEADER, customHeaders)
conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
conn.setopt(pycurl.POST, True)
data = query
if not rawQuery:
data = data.to_wire()

conn.setopt(pycurl.POSTFIELDS, data)

if response:
cls._toResponderQueue.put(response, True, timeout)

receivedQuery = None
message = None
cls._response_headers = ''
data = conn.perform_rb()
cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
if cls._rcode == 200 and not rawResponse:
message = dns.message.from_wire(data)
elif rawResponse:
message = data

if useQueue and not cls._fromResponderQueue.empty():
receivedQuery = cls._fromResponderQueue.get(True, timeout)

cls._response_headers = response_headers.getvalue()
return (receivedQuery, message)

def getHeaderValue(self, name):
for header in self._response_headers.decode().splitlines(False):
values = header.split(':')
key = values[0]
if key.lower() == name.lower():
return values[1].strip()
return None

def checkHasHeader(self, name, value):
got = self.getHeaderValue(name)
self.assertEqual(got, value)

def checkNoHeader(self, name):
self.checkHasHeader(name, None)

@classmethod
def setUpClass(cls):

# for some reason, @unittest.skipIf() is not applied to derived classes with some versions of Python
if 'SKIP_DOH_TESTS' in os.environ:
raise unittest.SkipTest('DNS over HTTPS tests are disabled')

cls.startResponders()
cls.startDNSDist()
cls.setUpSockets()

print("Launching tests..")

class TestDOH(DNSDistDOHTest):

_serverKey = 'server.key'
Expand Down
Loading

0 comments on commit 86ec2ab

Please sign in to comment.