Skip to content

Commit

Permalink
dnsdist: Get rid of TCPCrossProtocolQuerySender
Browse files Browse the repository at this point in the history
We need this construct to deal with cross-protocol queries, like
queries received over TCP or DoT, but forwarded over DoH, because
the thread dealing with the client and the one dealing with the
backend will not be the same in that case, and we do not want to
have different threads touching the same TCP connections.
So we pass the query and response to the correct thread via pipes.
Until now we were allocating an additional object, TCPCrossProtocolQuerySender,
to deal with that case, but I noticed that the existing IncomingTCPConnectionState
object already does everything we need, except that it needs to
know that the response is a cross-protocol one in order to pass it
via the pipe instead of treating it in a different way. This can be
done by looking if the current thread ID differs from the one that
created this object: if it does, we are dealing with a cross-protocol
response and should pass it via the pipe, and if it does not we
can deal with it directly.
This change saves the need to allocate a new object wrapped in a
shared pointer for each cross-protocol query, which is quite nice.
  • Loading branch information
rgacogne committed Dec 6, 2022
1 parent c222a56 commit 516a000
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 60 deletions.
89 changes: 30 additions & 59 deletions pdns/dnsdist-tcp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,11 @@ void IncomingTCPConnectionState::updateIO(std::shared_ptr<IncomingTCPConnectionS
/* called from the backend code when a new response has been received */
void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPResponse&& response)
{
if (std::this_thread::get_id() != d_mainThreadID) {
handleCrossProtocolResponse(now, std::move(response));
return;
}

std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();

if (response.d_connection && response.d_connection->getDS() && response.d_connection->getDS()->d_config.useProxyProtocol) {
Expand Down Expand Up @@ -566,66 +571,11 @@ struct TCPCrossProtocolResponse
struct timeval d_now;
};

class TCPCrossProtocolQuerySender : public TCPQuerySender
{
public:
TCPCrossProtocolQuerySender(std::shared_ptr<IncomingTCPConnectionState>& state): d_state(state)
{
}

bool active() const override
{
return d_state->active();
}

const ClientState* getClientState() const override
{
return d_state->getClientState();
}

void handleResponse(const struct timeval& now, TCPResponse&& response) override
{
if (d_state->d_threadData.crossProtocolResponsesPipe == -1) {
throw std::runtime_error("Invalid pipe descriptor in TCP Cross Protocol Query Sender");
}

auto ptr = new TCPCrossProtocolResponse(std::move(response), d_state, now);
static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail");
ssize_t sent = write(d_state->d_threadData.crossProtocolResponsesPipe, &ptr, sizeof(ptr));
if (sent != sizeof(ptr)) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
++g_stats.tcpCrossProtocolResponsePipeFull;
vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because the pipe is full");
}
else {
vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror());
}
delete ptr;
}
}

void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override
{
handleResponse(now, std::move(response));
}

void notifyIOError(IDState&& query, const struct timeval& now) override
{
TCPResponse response(PacketBuffer(), std::move(query), nullptr);
handleResponse(now, std::move(response));
}

private:
std::shared_ptr<IncomingTCPConnectionState> d_state;
};

class TCPCrossProtocolQuery : public CrossProtocolQuery
{
public:
TCPCrossProtocolQuery(PacketBuffer&& buffer, IDState&& ids, std::shared_ptr<DownstreamState>& ds, std::shared_ptr<TCPCrossProtocolQuerySender>& sender): d_sender(sender)
TCPCrossProtocolQuery(PacketBuffer&& buffer, IDState&& ids, std::shared_ptr<DownstreamState> ds, std::shared_ptr<IncomingTCPConnectionState> sender): CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), ds), d_sender(std::move(sender))
{
query = InternalQuery(std::move(buffer), std::move(ids));
downstream = ds;
proxyProtocolPayloadSize = 0;
}

Expand All @@ -639,9 +589,31 @@ class TCPCrossProtocolQuery : public CrossProtocolQuery
}

private:
std::shared_ptr<TCPCrossProtocolQuerySender> d_sender;
std::shared_ptr<IncomingTCPConnectionState> d_sender;
};

void IncomingTCPConnectionState::handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response)
{
if (d_threadData.crossProtocolResponsesPipe == -1) {
throw std::runtime_error("Invalid pipe descriptor in TCP Cross Protocol Query Sender");
}

std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
auto ptr = new TCPCrossProtocolResponse(std::move(response), state, now);
static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail");
ssize_t sent = write(d_threadData.crossProtocolResponsesPipe, &ptr, sizeof(ptr));
if (sent != sizeof(ptr)) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
++g_stats.tcpCrossProtocolResponsePipeFull;
vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because the pipe is full");
}
else {
vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror());
}
delete ptr;
}
}

static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
{
if (state->d_querySize < sizeof(dnsheader)) {
Expand Down Expand Up @@ -784,8 +756,7 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
proxyProtocolPayload = getProxyProtocolPayload(dq);
}

auto incoming = std::make_shared<TCPCrossProtocolQuerySender>(state);
auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(state->d_buffer), std::move(ids), ds, incoming);
auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(state->d_buffer), std::move(ids), ds, state);
cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);

ds->passCrossProtocolQuery(std::move(cpq));
Expand Down
5 changes: 4 additions & 1 deletion pdns/dnsdistdist/dnsdist-tcp-upstream.hh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public:
class IncomingTCPConnectionState : public TCPQuerySender, public std::enable_shared_from_this<IncomingTCPConnectionState>
{
public:
IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique<IOStateHandler>(*threadData.mplexer, d_ci.fd)), d_threadData(threadData)
IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique<IOStateHandler>(*threadData.mplexer, d_ci.fd)), d_threadData(threadData), d_mainThreadID(std::this_thread::get_id())
{
d_origDest.reset();
d_origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
Expand Down Expand Up @@ -125,6 +125,8 @@ static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bo
void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override;
void notifyIOError(IDState&& query, const struct timeval& now) override;

void handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response);

void terminateClientConnection();
void queueQuery(TCPQuery&& query);

Expand Down Expand Up @@ -170,6 +172,7 @@ static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bo
size_t d_proxyProtocolNeed{0};
size_t d_queriesCount{0};
size_t d_currentQueriesCount{0};
std::thread::id d_mainThreadID;
uint16_t d_querySize{0};
State d_state{State::doingHandshake};
bool d_isXFR{false};
Expand Down
5 changes: 5 additions & 0 deletions pdns/dnsdistdist/dnsdist-tcp.hh
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@ struct CrossProtocolQuery
{
}

CrossProtocolQuery(InternalQuery&& query_, std::shared_ptr<DownstreamState>& downstream_) :
query(std::move(query_)), downstream(downstream_)
{
}

CrossProtocolQuery(CrossProtocolQuery&& rhs) = delete;
virtual ~CrossProtocolQuery()
{
Expand Down

0 comments on commit 516a000

Please sign in to comment.