diff --git a/libraries/DNSServer/src/DNSServer.cpp b/libraries/DNSServer/src/DNSServer.cpp index 42b6a88515..25f382e354 100644 --- a/libraries/DNSServer/src/DNSServer.cpp +++ b/libraries/DNSServer/src/DNSServer.cpp @@ -1,6 +1,7 @@ #include "DNSServer.h" #include #include +#include #ifdef DEBUG_ESP_PORT #define DEBUG_OUTPUT DEBUG_ESP_PORT @@ -8,6 +9,8 @@ #define DEBUG_OUTPUT Serial #endif +#define DNS_HEADER_SIZE sizeof(DNSHeader) + DNSServer::DNSServer() { _ttl = lwip_htonl(60); @@ -50,108 +53,154 @@ void DNSServer::downcaseAndRemoveWwwPrefix(String &domainName) domainName.remove(0, 4); } -void DNSServer::processNextRequest() +void DNSServer::respondToRequest(uint8_t *buffer, size_t length) { - size_t packetSize = _udp.parsePacket(); + DNSHeader *dnsHeader; + uint8_t *query, *start; + const char *matchString; + size_t remaining, labelLength, queryLength; + uint16_t qtype, qclass; + + dnsHeader = (DNSHeader *)buffer; - if (packetSize >= sizeof(DNSHeader)) - { - uint8_t* buffer = reinterpret_cast(malloc(packetSize)); - if (buffer == NULL) return; + // Must be a query for us to do anything with it + if (dnsHeader->QR != DNS_QR_QUERY) + return; - _udp.read(buffer, packetSize); + // If operation is anything other than query, we don't do it + if (dnsHeader->OPCode != DNS_OPCODE_QUERY) + return replyWithError(dnsHeader, DNSReplyCode::NotImplemented); + + // Only support requests containing single queries - everything else + // is badly defined + if (dnsHeader->QDCount != lwip_htons(1)) + return replyWithError(dnsHeader, DNSReplyCode::FormError); + + // We must return a FormError in the case of a non-zero ARCount to + // be minimally compatible with EDNS resolvers + if (dnsHeader->ANCount != 0 || dnsHeader->NSCount != 0 + || dnsHeader->ARCount != 0) + return replyWithError(dnsHeader, DNSReplyCode::FormError); + + // Even if we're not going to use the query, we need to parse it + // so we can check the address type that's being queried + + query = start = buffer + DNS_HEADER_SIZE; + remaining = length - DNS_HEADER_SIZE; + while (remaining != 0 && *start != 0) { + labelLength = *start; + if (labelLength + 1 > remaining) + return replyWithError(dnsHeader, DNSReplyCode::FormError); + remaining -= (labelLength + 1); + start += (labelLength + 1); + } - DNSHeader* dnsHeader = reinterpret_cast(buffer); + // 1 octet labelLength, 2 octet qtype, 2 octet qclass + if (remaining < 5) + return replyWithError(dnsHeader, DNSReplyCode::FormError); - if (dnsHeader->QR == DNS_QR_QUERY && - dnsHeader->OPCode == DNS_OPCODE_QUERY && - requestIncludesOnlyOneQuestion(dnsHeader) && - (_domainName == "*" || getDomainNameWithoutWwwPrefix(buffer, packetSize) == _domainName) - ) - { - replyWithIP(buffer, packetSize); - } - else if (dnsHeader->QR == DNS_QR_QUERY) - { - replyWithCustomCode(buffer, packetSize); + start += 1; // Skip the 0 length label that we found above + + memcpy(&qtype, start, sizeof(qtype)); + start += 2; + memcpy(&qclass, start, sizeof(qclass)); + start += 2; + + queryLength = start - query; + + if (qclass != lwip_htons(DNS_QCLASS_ANY) + && qclass != lwip_htons(DNS_QCLASS_IN)) + return replyWithError(dnsHeader, DNSReplyCode::NonExistentDomain, + query, queryLength); + + if (qtype != lwip_htons(DNS_QTYPE_A) + && qtype != lwip_htons(DNS_QTYPE_ANY)) + return replyWithError(dnsHeader, DNSReplyCode::NonExistentDomain, + query, queryLength); + + // If we have no domain name configured, just return an error + if (_domainName == "") + return replyWithError(dnsHeader, _errorReplyCode, + query, queryLength); + + // If we're running with a wildcard we can just return a result now + if (_domainName == "*") + return replyWithIP(dnsHeader, query, queryLength); + + matchString = _domainName.c_str(); + + start = query; + + // If there's a leading 'www', skip it + if (*start == 3 && strncasecmp("www", (char *) start + 1, 3) == 0) + start += 4; + + while (*start != 0) { + labelLength = *start; + start += 1; + while (labelLength > 0) { + if (tolower(*start) != *matchString) + return replyWithError(dnsHeader, _errorReplyCode, + query, queryLength); + ++start; + ++matchString; + --labelLength; } + if (*start == 0 && *matchString == '\0') + return replyWithIP(dnsHeader, query, queryLength); - free(buffer); + if (*matchString != '.') + return replyWithError(dnsHeader, _errorReplyCode, + query, queryLength); + ++matchString; } -} -bool DNSServer::requestIncludesOnlyOneQuestion(const DNSHeader* dnsHeader) -{ - return lwip_ntohs(dnsHeader->QDCount) == 1 && - dnsHeader->ANCount == 0 && - dnsHeader->NSCount == 0 && - dnsHeader->ARCount == 0; + return replyWithError(dnsHeader, _errorReplyCode, + query, queryLength); } -String DNSServer::getDomainNameWithoutWwwPrefix(const uint8_t* buffer, size_t packetSize) +void DNSServer::processNextRequest() { - String parsedDomainName; - - const uint8_t* pos = buffer + sizeof(DNSHeader); - const uint8_t* end = buffer + packetSize; - - // to minimize reallocations due to concats below - // we reserve enough space that a median or average domain - // name size cold be easily contained without a reallocation - // - max size would be 253, in 2013, average is 11 and max was 42 - // - parsedDomainName.reserve(32); - - uint8_t labelLength = *pos; - - while (true) - { - if (labelLength == 0) - { - // no more labels - downcaseAndRemoveWwwPrefix(parsedDomainName); - return parsedDomainName; - } + size_t currentPacketSize; - // append next label - for (int i = 0; i < labelLength && pos < end; i++) - { - pos++; - parsedDomainName += static_cast(*pos); - } + currentPacketSize = _udp.parsePacket(); + if (currentPacketSize == 0) + return; - if (pos >= end) - { - // malformed packet, return an empty domain name - parsedDomainName = ""; - return parsedDomainName; - } - else - { - // next label - pos++; - labelLength = *pos; - - // if there is another label, add delimiter - if (labelLength != 0) - { - parsedDomainName += "."; - } - } - } + // The DNS RFC requires that DNS packets be less than 512 bytes in size, + // so just discard them if they are larger + if (currentPacketSize > MAX_DNS_PACKETSIZE) + return; + + // If the packet size is smaller than the DNS header, then someone is + // messing with us + if (currentPacketSize < DNS_HEADER_SIZE) + return; + + std::unique_ptr buffer(new (std::nothrow) uint8_t[currentPacketSize]); + + if (buffer == NULL) + return; + + _udp.read(buffer.get(), currentPacketSize); + respondToRequest(buffer.get(), currentPacketSize); } -void DNSServer::replyWithIP(uint8_t* buffer, size_t packetSize) +void DNSServer::replyWithIP(DNSHeader *dnsHeader, + unsigned char * query, + size_t queryLength) { - DNSHeader* dnsHeader = reinterpret_cast(buffer); - dnsHeader->QR = DNS_QR_RESPONSE; - dnsHeader->ANCount = dnsHeader->QDCount; - dnsHeader->QDCount = dnsHeader->QDCount; - //dnsHeader->RA = 1; + dnsHeader->QDCount = lwip_htons(1); + dnsHeader->ANCount = lwip_htons(1); + dnsHeader->NSCount = 0; + dnsHeader->ARCount = 0; + + //_dnsHeader->RA = 1; _udp.beginPacket(_udp.remoteIP(), _udp.remotePort()); - _udp.write(buffer, packetSize); + _udp.write((unsigned char *) dnsHeader, sizeof(DNSHeader)); + _udp.write(query, queryLength); _udp.write((uint8_t)192); // answer name is a pointer _udp.write((uint8_t)12); // pointer to offset at 0x00c @@ -169,27 +218,32 @@ void DNSServer::replyWithIP(uint8_t* buffer, size_t packetSize) _udp.write((uint8_t)4); _udp.write(_resolvedIP, sizeof(_resolvedIP)); _udp.endPacket(); - - #ifdef DEBUG_ESP_DNS - DEBUG_OUTPUT.printf("DNS responds: %s for %s\n", - IPAddress(_resolvedIP).toString().c_str(), getDomainNameWithoutWwwPrefix(buffer, packetSize).c_str() ); - #endif } -void DNSServer::replyWithCustomCode(uint8_t* buffer, size_t packetSize) +void DNSServer::replyWithError(DNSHeader *dnsHeader, + DNSReplyCode rcode, + unsigned char *query, + size_t queryLength) { - if (packetSize < sizeof(DNSHeader)) - { - return; - } - - DNSHeader* dnsHeader = reinterpret_cast(buffer); - dnsHeader->QR = DNS_QR_RESPONSE; - dnsHeader->RCode = (unsigned char)_errorReplyCode; - dnsHeader->QDCount = 0; + dnsHeader->RCode = (unsigned char) rcode; + if (query) + dnsHeader->QDCount = lwip_htons(1); + else + dnsHeader->QDCount = 0; + dnsHeader->ANCount = 0; + dnsHeader->NSCount = 0; + dnsHeader->ARCount = 0; _udp.beginPacket(_udp.remoteIP(), _udp.remotePort()); - _udp.write(buffer, sizeof(DNSHeader)); + _udp.write((unsigned char *)dnsHeader, sizeof(DNSHeader)); + if (query != NULL) + _udp.write(query, queryLength); _udp.endPacket(); } + +void DNSServer::replyWithError(DNSHeader *dnsHeader, + DNSReplyCode rcode) +{ + replyWithError(dnsHeader, rcode, NULL, 0); +} diff --git a/libraries/DNSServer/src/DNSServer.h b/libraries/DNSServer/src/DNSServer.h index d6e7de444d..95800e23af 100644 --- a/libraries/DNSServer/src/DNSServer.h +++ b/libraries/DNSServer/src/DNSServer.h @@ -6,7 +6,14 @@ #define DNS_QR_RESPONSE 1 #define DNS_OPCODE_QUERY 0 +#define DNS_QCLASS_IN 1 +#define DNS_QCLASS_ANY 255 + +#define DNS_QTYPE_A 1 +#define DNS_QTYPE_ANY 255 + #define MAX_DNSNAME_LENGTH 253 +#define MAX_DNS_PACKETSIZE 512 enum class DNSReplyCode { @@ -65,9 +72,15 @@ class DNSServer DNSReplyCode _errorReplyCode; void downcaseAndRemoveWwwPrefix(String &domainName); - String getDomainNameWithoutWwwPrefix(const uint8_t* buffer, size_t packetSize); - bool requestIncludesOnlyOneQuestion(const DNSHeader* dnsHeader); - void replyWithIP(uint8_t* buffer, size_t packetSize); - void replyWithCustomCode(uint8_t* buffer, size_t packetSize); + void replyWithIP(DNSHeader *dnsHeader, + unsigned char * query, + size_t queryLength); + void replyWithError(DNSHeader *dnsHeader, + DNSReplyCode rcode, + unsigned char *query, + size_t queryLength); + void replyWithError(DNSHeader *dnsHeader, + DNSReplyCode rcode); + void respondToRequest(uint8_t *buffer, size_t length); }; -#endif \ No newline at end of file +#endif