Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework DNSServer to be more robust #5573

Merged
merged 3 commits into from
Jan 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
276 changes: 169 additions & 107 deletions libraries/DNSServer/src/DNSServer.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
#include "DNSServer.h"
#include <lwip/def.h>
#include <Arduino.h>
#include <memory>

#ifdef DEBUG_ESP_PORT
#define DEBUG_OUTPUT DEBUG_ESP_PORT
#else
#define DEBUG_OUTPUT Serial
#endif

#define DNS_HEADER_SIZE sizeof(DNSHeader)

DNSServer::DNSServer()
{
_ttl = lwip_htonl(60);
Expand Down Expand Up @@ -46,149 +49,208 @@ void DNSServer::stop()
void DNSServer::downcaseAndRemoveWwwPrefix(String &domainName)
{
domainName.toLowerCase();
domainName.replace("www.", "");
if (domainName.startsWith("www."))
domainName.remove(0, 4);
}

void DNSServer::processNextRequest()
void DNSServer::respondToRequest(uint8_t *buffer, size_t length)
{
size_t packetSize = _udp.parsePacket();
DNSHeader *dnsHeader;
uint8_t *query, *start;
const char *matchString;
size_t remaining, labelLength, queryLength;
uint16_t qtype, qclass;

dnsHeader = (DNSHeader *)buffer;

if (packetSize >= sizeof(DNSHeader))
{
uint8_t* buffer = reinterpret_cast<uint8_t*>(malloc(packetSize));
if (buffer == NULL) return;
// Must be a query for us to do anything with it
if (dnsHeader->QR != DNS_QR_QUERY)
return;

_udp.read(buffer, packetSize);
// If operation is anything other than query, we don't do it
if (dnsHeader->OPCode != DNS_OPCODE_QUERY)
return replyWithError(dnsHeader, DNSReplyCode::NotImplemented);

// Only support requests containing single queries - everything else
// is badly defined
if (dnsHeader->QDCount != lwip_htons(1))
return replyWithError(dnsHeader, DNSReplyCode::FormError);

// We must return a FormError in the case of a non-zero ARCount to
// be minimally compatible with EDNS resolvers
if (dnsHeader->ANCount != 0 || dnsHeader->NSCount != 0
|| dnsHeader->ARCount != 0)
return replyWithError(dnsHeader, DNSReplyCode::FormError);

// Even if we're not going to use the query, we need to parse it
// so we can check the address type that's being queried

query = start = buffer + DNS_HEADER_SIZE;
remaining = length - DNS_HEADER_SIZE;
while (remaining != 0 && *start != 0) {
labelLength = *start;
if (labelLength + 1 > remaining)
return replyWithError(dnsHeader, DNSReplyCode::FormError);
remaining -= (labelLength + 1);
start += (labelLength + 1);
}

DNSHeader* dnsHeader = reinterpret_cast<DNSHeader*>(buffer);
// 1 octet labelLength, 2 octet qtype, 2 octet qclass
if (remaining < 5)
return replyWithError(dnsHeader, DNSReplyCode::FormError);

if (dnsHeader->QR == DNS_QR_QUERY &&
dnsHeader->OPCode == DNS_OPCODE_QUERY &&
requestIncludesOnlyOneQuestion(dnsHeader) &&
(_domainName == "*" || getDomainNameWithoutWwwPrefix(buffer, packetSize) == _domainName)
)
{
replyWithIP(buffer, packetSize);
}
else if (dnsHeader->QR == DNS_QR_QUERY)
{
replyWithCustomCode(buffer, packetSize);
start += 1; // Skip the 0 length label that we found above

memcpy(&qtype, start, sizeof(qtype));
start += 2;
memcpy(&qclass, start, sizeof(qclass));
start += 2;

queryLength = start - query;

if (qclass != lwip_htons(DNS_QCLASS_ANY)
&& qclass != lwip_htons(DNS_QCLASS_IN))
return replyWithError(dnsHeader, DNSReplyCode::NonExistentDomain,
query, queryLength);

if (qtype != lwip_htons(DNS_QTYPE_A)
&& qtype != lwip_htons(DNS_QTYPE_ANY))
return replyWithError(dnsHeader, DNSReplyCode::NonExistentDomain,
query, queryLength);

// If we have no domain name configured, just return an error
if (_domainName == "")
return replyWithError(dnsHeader, _errorReplyCode,
query, queryLength);

// If we're running with a wildcard we can just return a result now
if (_domainName == "*")
devyte marked this conversation as resolved.
Show resolved Hide resolved
return replyWithIP(dnsHeader, query, queryLength);

matchString = _domainName.c_str();

start = query;

// If there's a leading 'www', skip it
if (*start == 3 && strncasecmp("www", (char *) start + 1, 3) == 0)
start += 4;

while (*start != 0) {
labelLength = *start;
start += 1;
while (labelLength > 0) {
if (tolower(*start) != *matchString)
return replyWithError(dnsHeader, _errorReplyCode,
query, queryLength);
++start;
++matchString;
--labelLength;
}
if (*start == 0 && *matchString == '\0')
return replyWithIP(dnsHeader, query, queryLength);

free(buffer);
if (*matchString != '.')
return replyWithError(dnsHeader, _errorReplyCode,
query, queryLength);
++matchString;
}
}

bool DNSServer::requestIncludesOnlyOneQuestion(const DNSHeader* dnsHeader)
{
return lwip_ntohs(dnsHeader->QDCount) == 1 &&
dnsHeader->ANCount == 0 &&
dnsHeader->NSCount == 0 &&
dnsHeader->ARCount == 0;
return replyWithError(dnsHeader, _errorReplyCode,
query, queryLength);
}

String DNSServer::getDomainNameWithoutWwwPrefix(const uint8_t* buffer, size_t packetSize)
void DNSServer::processNextRequest()
{
String parsedDomainName;

const uint8_t* pos = buffer + sizeof(DNSHeader);
const uint8_t* end = buffer + packetSize;

// to minimize reallocations due to concats below
// we reserve enough space that a median or average domain
// name size cold be easily contained without a reallocation
// - max size would be 253, in 2013, average is 11 and max was 42
//
parsedDomainName.reserve(32);

uint8_t labelLength = *pos;

while (true)
{
if (labelLength == 0)
{
// no more labels
downcaseAndRemoveWwwPrefix(parsedDomainName);
return parsedDomainName;
}
size_t currentPacketSize;

// append next label
for (int i = 0; i < labelLength && pos < end; i++)
{
pos++;
parsedDomainName += static_cast<char>(*pos);
}
currentPacketSize = _udp.parsePacket();
if (currentPacketSize == 0)
return;

if (pos >= end)
{
// malformed packet, return an empty domain name
parsedDomainName = "";
return parsedDomainName;
}
else
{
// next label
pos++;
labelLength = *pos;

// if there is another label, add delimiter
if (labelLength != 0)
{
parsedDomainName += ".";
}
}
}
// The DNS RFC requires that DNS packets be less than 512 bytes in size,
// so just discard them if they are larger
if (currentPacketSize > MAX_DNS_PACKETSIZE)
return;

// If the packet size is smaller than the DNS header, then someone is
// messing with us
if (currentPacketSize < DNS_HEADER_SIZE)
return;

std::unique_ptr<uint8_t[]> buffer(new (std::nothrow) uint8_t[currentPacketSize]);

if (buffer == NULL)
return;

_udp.read(buffer.get(), currentPacketSize);
respondToRequest(buffer.get(), currentPacketSize);
}

void DNSServer::writeNBOShort(uint16_t value)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this factoring out, but my previous comment about the literal 2 in the write() below is still not addressed.
I suggest the following:
_udp.write((usigned char *)&value, sizeof(value));

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll change this. Writing 2 octets of data here is a property of the protocol, though, and not of the C++ representation that's feeding it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand, and that's why I like the factoring out into a function: it makes it clear that it's a protocol property. My comment is more for maintainability and readability. If the function ever gets copied for e.g.: 4 octets or whatever, or templated for the arg type, with sizeof() there's no need to change anything => less error-prone. And when glancing at the code, sizeof() makes what is going on evident.
It's just good habits and programming guidelines, not that there's anything with your original code.

Also, I'll be porting these changes to my async dnsserver 😆

{
_udp.write((unsigned char *)&value, 2);
}

void DNSServer::replyWithIP(uint8_t* buffer, size_t packetSize)
void DNSServer::replyWithIP(DNSHeader *dnsHeader,
unsigned char * query,
size_t queryLength)
{
DNSHeader* dnsHeader = reinterpret_cast<DNSHeader*>(buffer);
uint16_t value;

dnsHeader->QR = DNS_QR_RESPONSE;
dnsHeader->ANCount = dnsHeader->QDCount;
dnsHeader->QDCount = dnsHeader->QDCount;
//dnsHeader->RA = 1;
dnsHeader->QDCount = lwip_htons(1);
dnsHeader->ANCount = lwip_htons(1);
dnsHeader->NSCount = 0;
dnsHeader->ARCount = 0;

_udp.beginPacket(_udp.remoteIP(), _udp.remotePort());
_udp.write(buffer, packetSize);
_udp.write((unsigned char *) dnsHeader, sizeof(DNSHeader));
_udp.write(query, queryLength);

// Rather than restate the name here, we use a pointer to the name contained
// in the query section. Pointers have the top two bits set.
value = 0xC000 | DNS_HEADER_SIZE;
writeNBOShort(lwip_htons(value));

_udp.write((uint8_t)192); // answer name is a pointer
_udp.write((uint8_t)12); // pointer to offset at 0x00c
// Answer is type A (an IPv4 address)
writeNBOShort(lwip_htons(DNS_QTYPE_A));

_udp.write((uint8_t)0); // 0x0001 answer is type A query (host address)
_udp.write((uint8_t)1);
// Answer is in the Internet Class
writeNBOShort(lwip_htons(DNS_QCLASS_IN));

_udp.write((uint8_t)0); //0x0001 answer is class IN (internet address)
_udp.write((uint8_t)1);

// Output TTL (already NBO)
_udp.write((unsigned char*)&_ttl, 4);

// Length of RData is 4 bytes (because, in this case, RData is IPv4)
_udp.write((uint8_t)0);
_udp.write((uint8_t)4);
writeNBOShort(lwip_htons(sizeof(_resolvedIP)));
_udp.write(_resolvedIP, sizeof(_resolvedIP));
_udp.endPacket();

#ifdef DEBUG_ESP_DNS
DEBUG_OUTPUT.printf("DNS responds: %s for %s\n",
IPAddress(_resolvedIP).toString().c_str(), getDomainNameWithoutWwwPrefix(buffer, packetSize).c_str() );
#endif
}

void DNSServer::replyWithCustomCode(uint8_t* buffer, size_t packetSize)
void DNSServer::replyWithError(DNSHeader *dnsHeader,
DNSReplyCode rcode,
unsigned char *query,
size_t queryLength)
{
if (packetSize < sizeof(DNSHeader))
{
return;
}

DNSHeader* dnsHeader = reinterpret_cast<DNSHeader*>(buffer);

dnsHeader->QR = DNS_QR_RESPONSE;
dnsHeader->RCode = (unsigned char)_errorReplyCode;
dnsHeader->QDCount = 0;
dnsHeader->RCode = (unsigned char) rcode;
if (query)
dnsHeader->QDCount = lwip_htons(1);
else
dnsHeader->QDCount = 0;
dnsHeader->ANCount = 0;
dnsHeader->NSCount = 0;
dnsHeader->ARCount = 0;

_udp.beginPacket(_udp.remoteIP(), _udp.remotePort());
_udp.write(buffer, sizeof(DNSHeader));
_udp.write((unsigned char *)dnsHeader, sizeof(DNSHeader));
if (query != NULL)
_udp.write(query, queryLength);
_udp.endPacket();
}

void DNSServer::replyWithError(DNSHeader *dnsHeader,
DNSReplyCode rcode)
{
replyWithError(dnsHeader, rcode, NULL, 0);
}
24 changes: 19 additions & 5 deletions libraries/DNSServer/src/DNSServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@
#define DNS_QR_RESPONSE 1
#define DNS_OPCODE_QUERY 0

#define DNS_QCLASS_IN 1
#define DNS_QCLASS_ANY 255

#define DNS_QTYPE_A 1
#define DNS_QTYPE_ANY 255

#define MAX_DNSNAME_LENGTH 253
#define MAX_DNS_PACKETSIZE 512

enum class DNSReplyCode
{
Expand Down Expand Up @@ -65,9 +72,16 @@ class DNSServer
DNSReplyCode _errorReplyCode;

void downcaseAndRemoveWwwPrefix(String &domainName);
String getDomainNameWithoutWwwPrefix(const uint8_t* buffer, size_t packetSize);
bool requestIncludesOnlyOneQuestion(const DNSHeader* dnsHeader);
void replyWithIP(uint8_t* buffer, size_t packetSize);
void replyWithCustomCode(uint8_t* buffer, size_t packetSize);
void replyWithIP(DNSHeader *dnsHeader,
unsigned char * query,
size_t queryLength);
void replyWithError(DNSHeader *dnsHeader,
DNSReplyCode rcode,
unsigned char *query,
devyte marked this conversation as resolved.
Show resolved Hide resolved
size_t queryLength);
void replyWithError(DNSHeader *dnsHeader,
DNSReplyCode rcode);
void respondToRequest(uint8_t *buffer, size_t length);
void writeNBOShort(uint16_t value);
};
#endif
#endif