Skip to content

Commit

Permalink
PR for Duplicated UDP responses from DNS confuse LookupClient #140 (#142
Browse files Browse the repository at this point in the history
)

Xid mismatch validation added.

Co-authored-by: Boris <boris.dogadov@mongodb.com>
  • Loading branch information
JamesKovacs and BorisDog authored Jan 26, 2022
1 parent 2b38644 commit d56cb12
Show file tree
Hide file tree
Showing 8 changed files with 389 additions and 2 deletions.
8 changes: 8 additions & 0 deletions src/DnsClient/DnsMessageHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ public static bool IsTransientException<T>(T exception) where T : Exception
return false;
}

protected static void ValidateResponse(DnsRequestMessage request, DnsResponseMessage response)
{
if (request != null && response != null && request.Header.Id != response.Header.Id)
{
throw new DnsXidMismatchException(request.Header.Id, response.Header.Id);
}
}

public virtual void GetRequestData(DnsRequestMessage request, DnsDatagramWriter writer)
{
var question = request.Question;
Expand Down
2 changes: 2 additions & 0 deletions src/DnsClient/DnsTcpMessageHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ public override async Task<DnsResponseMessage> QueryAsync(
}
}

ValidateResponse(request, response);

return response;
}

Expand Down
5 changes: 4 additions & 1 deletion src/DnsClient/DnsUdpMessageHandler.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System;
using System.Collections.Concurrent;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Threading;
Expand Down Expand Up @@ -52,6 +51,8 @@ public override DnsResponseMessage Query(

var response = GetResponseMessage(new ArraySegment<byte>(memory.Buffer, 0, received));

ValidateResponse(request, response);

Enqueue(server.AddressFamily, udpClient);

return response;
Expand Down Expand Up @@ -121,6 +122,8 @@ public override async Task<DnsResponseMessage> QueryAsync(
var response = GetResponseMessage(new ArraySegment<byte>(result.Buffer, 0, result.Buffer.Length));
#endif

ValidateResponse(request, response);

Enqueue(endpoint.AddressFamily, udpClient);

return response;
Expand Down
23 changes: 23 additions & 0 deletions src/DnsClient/DnsXidMismatchException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using System;

#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
namespace DnsClient
{
#if !NETSTANDARD1_3
[Serializable]
#endif

public class DnsXidMismatchException : Exception
{
public int RequestXid { get; }
public int ResponseXid { get; }

public DnsXidMismatchException(int requestXid, int responseXid)
: base()
{
RequestXid = requestXid;
ResponseXid = responseXid;
}
}
}
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
77 changes: 77 additions & 0 deletions src/DnsClient/LookupClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,21 @@ private IDnsQueryResponse ResolveQuery(

return lastQueryResponse;
}
catch (DnsXidMismatchException ex)
{
var handle = HandleDnsXidMismatchException(ex, request, settings, handler.Type, isLastServer, isLastTry, tries);

if (handle == HandleError.RetryCurrentServer)
{
continue;
}
else if (handle == HandleError.RetryNextServer)
{
break;
}

throw;
}
catch (DnsResponseParseException ex)
{
var handle = HandleDnsResponeParseException(ex, request, handler.Type, isLastServer: isLastServer);
Expand Down Expand Up @@ -1167,6 +1182,21 @@ private async Task<IDnsQueryResponse> ResolveQueryAsync(

return lastQueryResponse;
}
catch (DnsXidMismatchException ex)
{
var handle = HandleDnsXidMismatchException(ex, request, settings, handler.Type, isLastServer: isLastServer, isLastTry: isLastTry, currentTry: tries);

if (handle == HandleError.RetryCurrentServer)
{
continue;
}
else if (handle == HandleError.RetryNextServer)
{
break;
}

throw;
}
catch (DnsResponseParseException ex)
{
var handle = HandleDnsResponeParseException(ex, request, handler.Type, isLastServer: isLastServer);
Expand Down Expand Up @@ -1373,6 +1403,53 @@ private HandleError HandleDnsResponseException(DnsResponseException ex, DnsReque
return handle;
}

private HandleError HandleDnsXidMismatchException(DnsXidMismatchException ex, DnsRequestMessage request, DnsQuerySettings settings, DnsMessageHandleType handleType, bool isLastServer, bool isLastTry, int currentTry)
{
// No more retries
if (isLastServer && isLastTry)
{
_logger.LogError(
LogEventQueryFail,
ex,
"Query {0} via {1} => {2} xid mismatch {3}. Throwing the error.",
ex.RequestXid,
handleType,
request.Question,
ex.ResponseXid);

return HandleError.Throw;
}

// Last try on the current server, try the nextServer
if (isLastTry)
{
_logger.LogError(
LogEventQueryRetryErrorNextServer,
ex,
"Query {0} via {1} => {2} xid mismatch {3}. Trying next server.",
ex.RequestXid,
handleType,
request.Question,
ex.ResponseXid);

return HandleError.RetryNextServer;
}

// Next try
_logger.LogWarning(
LogEventQueryRetryErrorNextServer,
ex,
"Query {0} via {1} => {2} xid mismatch {3}. Re-trying {4}/{5}...",
ex.RequestXid,
handleType,
request.Question,
ex.ResponseXid,
currentTry,
settings.Retries + 1);

return HandleError.RetryCurrentServer;
}

private HandleError HandleDnsResponeParseException(DnsResponseParseException ex, DnsRequestMessage request, DnsMessageHandleType handleType, bool isLastServer)
{
// Don't try to fallback to TCP if we already are on TCP
Expand Down
82 changes: 82 additions & 0 deletions test/DnsClient.Tests/LookupClientRetryTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,88 @@ public async Task DnsResponseParseException_ShouldTryNextServer_ThenThrow_Async(
Assert.Equal(3, calledIps.Count);
}

[Theory]
[InlineData(1, 1)]
[InlineData(1, 3)]
[InlineData(2, 1)]
[InlineData(3, 1)]
[InlineData(4, 4)]
public void DnsDnsXidMismatchException_ShouldRetry_ThenThrow(int serversCount, int retriesCount)
{
var nameServers = Enumerable.Range(1, serversCount)
.Select(i => new NameServer(IPAddress.Parse($"127.0.10.{i}")))
.ToArray();

var options = new LookupClientOptions(nameServers)
{
EnableAuditTrail = true,
ContinueOnDnsError = false,
ThrowDnsErrors = false,
UseCache = true,
Retries = retriesCount,
UseRandomNameServer = false,
UseTcpFallback = false
};

var calledIps = new List<IPAddress>();
var udpMessageHandler = new TestMessageHandler(DnsMessageHandleType.UDP, (ip, req) =>
{
calledIps.Add(ip.Address);
throw new DnsXidMismatchException(req.Header.Id, req.Header.Id + 1);
});

var lookup = new LookupClient(options, udpHandler: udpMessageHandler);
var result = Assert.ThrowsAny<DnsXidMismatchException>(() => lookup.Query(new DnsQuestion("test.com", QueryType.SRV, QueryClass.IN)));

var expectedIps = nameServers
.SelectMany(ns => Enumerable.Repeat(ns.IPEndPoint.Address, retriesCount + 1))
.ToArray();

Assert.Equal(expectedIps, calledIps);
}

[Theory]
[InlineData(1, 1)]
[InlineData(1, 3)]
[InlineData(2, 1)]
[InlineData(3, 1)]
[InlineData(4, 4)]
public async Task DnsDnsXidMismatchException_ShouldRetry_ThenThrow_Async(int serversCount, int retriesCount)
{
var nameServers = Enumerable.Range(1, serversCount)
.Select(i => new NameServer(IPAddress.Parse($"127.0.10.{i}")))
.ToArray();

var options = new LookupClientOptions(nameServers)
{
EnableAuditTrail = true,
ContinueOnDnsError = false,
ThrowDnsErrors = false,
UseCache = true,
Retries = retriesCount,
UseRandomNameServer = false,
UseTcpFallback = false
};

var calledIps = new List<IPAddress>();
var udpMessageHandler = new TestMessageHandler(DnsMessageHandleType.UDP, (ip, req) =>
{
calledIps.Add(ip.Address);
throw new DnsXidMismatchException(req.Header.Id, req.Header.Id + 1);
});

var lookup = new LookupClient(options, udpHandler: udpMessageHandler);
var result = await Assert.ThrowsAnyAsync<DnsXidMismatchException>(() => lookup.QueryAsync(new DnsQuestion("test.com", QueryType.SRV, QueryClass.IN)));

var expectedIps = nameServers
.SelectMany(ns => Enumerable.Repeat(ns.IPEndPoint.Address, retriesCount + 1))
.ToArray();

Assert.Equal(expectedIps, calledIps);
}

/* Normal truncated response (TC flag) */

[Fact]
Expand Down
67 changes: 66 additions & 1 deletion test/DnsClient.Tests/LookupTest.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using System;
using System.Linq;
using System.Net;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using DnsClient.Protocol;
Expand Down Expand Up @@ -1113,5 +1112,71 @@ public void Lookup_SettingsFallback_KeepProvidedServers3()
Assert.NotEqual(client.Settings, settings);
Assert.NotEqual(client.NameServers, settings.NameServers);
}

[Theory]
[InlineData(0, true)]
[InlineData(1, true)]
[InlineData(3, true)]
[InlineData(0, false)]
[InlineData(1, false)]
[InlineData(3, false)]
public async Task Lookup_XidMismatch(int mismatchResponses, bool sync)
{
var serverEndpoint = new IPEndPoint(IPAddress.Parse("127.0.0.1"), 54321);
var options = new LookupClientOptions(new NameServer(serverEndpoint))
{
Retries = 20,
UseCache = false
};

using var server = new UdpServerMistmatchXid(serverEndpoint, mismatchResponses);
var client = new LookupClient(options);

var dnsQuestion = new DnsQuestion("someservice", QueryType.TXT, QueryClass.IN);
var response = sync ? client.Query(dnsQuestion) : await client.QueryAsync(dnsQuestion);

Assert.Equal(2, response.Answers.TxtRecords().Count());
Assert.Equal("example.com.", response.Answers.TxtRecords().First().DomainName.Value);
Assert.Equal(mismatchResponses, server.MistmatchedResponsesCount);
Assert.Equal(mismatchResponses + 1, server.RequestsCount);
}

[Theory]
[InlineData(0, true)]
[InlineData(1, true)]
[InlineData(3, true)]
[InlineData(5, true)]
[InlineData(0, false)]
[InlineData(1, false)]
[InlineData(3, false)]
[InlineData(5, false)]
public async Task Lookup_DuplicateUDPResponses(int duplicatesCount, bool sync)
{
var serverEndpoint = new IPEndPoint(IPAddress.Parse("127.0.0.1"), 54321);
var options = new LookupClientOptions(new NameServer(serverEndpoint))
{
Retries = 20,
UseCache = false,
Timeout = TimeSpan.FromSeconds(5)
};

using var server = new UdpServerDuplicateResponses(serverEndpoint, duplicatesCount);
var client = new LookupClient(options);

var dnsQuestion = new DnsQuestion("someservice", QueryType.TXT, QueryClass.IN);
var response1 = sync ? client.Query(dnsQuestion) : await client.QueryAsync(dnsQuestion);
var response2 = sync ? client.Query(dnsQuestion) : await client.QueryAsync(dnsQuestion);

Assert.Equal(2, response1.Answers.TxtRecords().Count());
Assert.Equal("example.com.", response1.Answers.TxtRecords().First().DomainName.Value);

Assert.Equal(2, response2.Answers.TxtRecords().Count());
Assert.Equal("example.com.", response2.Answers.TxtRecords().First().DomainName.Value);

Assert.True(server.RequestsCount >= 2, "At least 2 requests are expected");

// Validate that duplicate response was not picked up
Assert.NotEqual(response1.Header.Id, response2.Header.Id);
}
}
}
Loading

0 comments on commit d56cb12

Please sign in to comment.