Skip to content

Commit

Permalink
[mdns] handle list change from DiscoverCallback (host or service)
Browse files Browse the repository at this point in the history
This commit updates how the `DiscoverCallback` are tracked in
`Mdns::Publisher` and how the callbacks are invoked. In particular,
the `mDiscoverCallbacks` list itself can get updated as the callbacks
are invoked. The new code ensures that we can handle such a situation
from both `OnServiceResolved()` or `OnHostResolved()`.
  • Loading branch information
abtink committed Sep 26, 2023
1 parent 4814a08 commit e8f07ad
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 29 deletions.
85 changes: 57 additions & 28 deletions src/mdns/mdns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,31 +174,24 @@ otbrError Publisher::DecodeTxtData(Publisher::TxtList &aTxtList, const uint8_t *

void Publisher::RemoveSubscriptionCallbacks(uint64_t aSubscriberId)
{
size_t erased;

OTBR_UNUSED_VARIABLE(erased);

assert(aSubscriberId > 0);

erased = mDiscoveredCallbacks.erase(aSubscriberId);

assert(erased == 1);
mDiscoverCallbacks.remove_if(
[aSubscriberId](DiscoverCallback &aCallback) { return (aCallback.mId == aSubscriberId); });
}

uint64_t Publisher::AddSubscriptionCallbacks(Publisher::DiscoveredServiceInstanceCallback aInstanceCallback,
Publisher::DiscoveredHostCallback aHostCallback)
{
uint64_t subscriberId = mNextSubscriberId++;
uint64_t id = mNextSubscriberId++;

assert(subscriberId > 0);
assert(id > 0);
mDiscoverCallbacks.emplace_back(id, aInstanceCallback, aHostCallback);

mDiscoveredCallbacks.emplace(subscriberId, std::make_pair(std::move(aInstanceCallback), std::move(aHostCallback)));
return subscriberId;
return id;
}

void Publisher::OnServiceResolved(std::string aType, DiscoveredInstanceInfo aInstanceInfo)
{
std::vector<uint64_t> subscriberIds;
bool checkToInvoke = false;

otbrLogInfo("Service %s is resolved successfully: %s %s host %s addresses %zu", aType.c_str(),
aInstanceInfo.mRemoved ? "remove" : "add", aInstanceInfo.mName.c_str(), aInstanceInfo.mHostName.c_str(),
Expand All @@ -216,22 +209,33 @@ void Publisher::OnServiceResolved(std::string aType, DiscoveredInstanceInfo aIns
UpdateMdnsResponseCounters(mTelemetryInfo.mServiceResolutions, OTBR_ERROR_NONE);
UpdateServiceInstanceResolutionEmaLatency(aInstanceInfo.mName, aType, OTBR_ERROR_NONE);

// In a callback, the mDiscoveredCallbacks may get changed which invalidates the running iterator. We need to refer
// to the callbacks by subscriberId to avoid invalid memory access.
subscriberIds.reserve(mDiscoveredCallbacks.size());
for (const auto &subCallback : mDiscoveredCallbacks)
// The `mDiscoverCallbacks` list can get updated as the callbacks
// are invoked. We first mark `mShouldInvoke` on all non-null
// service callbacks. We clear it before invoking the callback
// and restart the iteration over the `mDiscoverCallbacks` list
// to find the next one to signal, since the list may have changed.

for (DiscoverCallback &callback : mDiscoverCallbacks)
{
subscriberIds.push_back(subCallback.first);
if (callback.mServiceCallback != nullptr)
{
callback.mShouldInvoke = true;
checkToInvoke = true;
}
}
for (const auto &subscriberId : subscriberIds)

while (checkToInvoke)
{
auto it = mDiscoveredCallbacks.find(subscriberId);
if (it != mDiscoveredCallbacks.end())
checkToInvoke = false;

for (DiscoverCallback &callback : mDiscoverCallbacks)
{
const auto &subCallback = *it;
if (subCallback.second.first != nullptr)
if (callback.mShouldInvoke)
{
subCallback.second.first(aType, aInstanceInfo);
callback.mShouldInvoke = false;
checkToInvoke = true;
callback.mServiceCallback(aType, aInstanceInfo);
break;
}
}
}
Expand All @@ -252,6 +256,8 @@ void Publisher::OnServiceRemoved(uint32_t aNetifIndex, std::string aType, std::s

void Publisher::OnHostResolved(std::string aHostName, Publisher::DiscoveredHostInfo aHostInfo)
{
bool checkToInvoke = false;

otbrLogInfo("Host %s is resolved successfully: host %s addresses %zu ttl %u", aHostName.c_str(),
aHostInfo.mHostName.c_str(), aHostInfo.mAddresses.size(), aHostInfo.mTtl);

Expand All @@ -263,11 +269,34 @@ void Publisher::OnHostResolved(std::string aHostName, Publisher::DiscoveredHostI
UpdateMdnsResponseCounters(mTelemetryInfo.mHostResolutions, OTBR_ERROR_NONE);
UpdateHostResolutionEmaLatency(aHostName, OTBR_ERROR_NONE);

for (const auto &subCallback : mDiscoveredCallbacks)
// The `mDiscoverCallbacks` list can get updated as the callbacks
// are invoked. We first mark `mShouldInvoke` on all non-null
// host callbacks. We clear it before invoking the callback
// and restart the iteration over the `mDiscoverCallbacks` list
// to find the next one to signal, since the list may have changed.

for (DiscoverCallback &callback : mDiscoverCallbacks)
{
if (subCallback.second.second != nullptr)
if (callback.mHostCallback != nullptr)
{
subCallback.second.second(aHostName, aHostInfo);
callback.mShouldInvoke = true;
checkToInvoke = true;
}
}

while (checkToInvoke)
{
checkToInvoke = false;

for (DiscoverCallback &callback : mDiscoverCallbacks)
{
if (callback.mShouldInvoke)
{
callback.mShouldInvoke = false;
checkToInvoke = true;
callback.mHostCallback(aHostName, aHostInfo);
break;
}
}
}
}
Expand Down
22 changes: 21 additions & 1 deletion src/mdns/mdns.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "openthread-br/config.h"

#include <functional>
#include <list>
#include <map>
#include <memory>
#include <string>
Expand Down Expand Up @@ -575,9 +576,28 @@ class Publisher : private NonCopyable
ServiceRegistrationMap mServiceRegistrations;
HostRegistrationMap mHostRegistrations;

struct DiscoverCallback
{
DiscoverCallback(uint64_t aId,
DiscoveredServiceInstanceCallback aServiceCallback,
DiscoveredHostCallback aHostCallback)
: mId(aId)
, mServiceCallback(aServiceCallback)
, mHostCallback(aHostCallback)
, mShouldInvoke(false)
{
}

uint64_t mId;
DiscoveredServiceInstanceCallback mServiceCallback;
DiscoveredHostCallback mHostCallback;
bool mShouldInvoke;
};

uint64_t mNextSubscriberId = 1;

std::map<uint64_t, std::pair<DiscoveredServiceInstanceCallback, DiscoveredHostCallback>> mDiscoveredCallbacks;
std::list<DiscoverCallback> mDiscoverCallbacks;

// {instance name, service type} -> the timepoint to begin service registration
std::map<std::pair<std::string, std::string>, Timepoint> mServiceRegistrationBeginTime;
// host name -> the timepoint to begin host registration
Expand Down

0 comments on commit e8f07ad

Please sign in to comment.