Skip to content

Commit

Permalink
Fix the dnssd code that browses on both the local and srp domains (#3…
Browse files Browse the repository at this point in the history
…2675)

* Fix the dnssd code that browses on both the local and srp domains

- Fixes the UAF issue with the timer

- Resolves critical comments from PR #32631

* Fix GetDomainFromHostName to get the domain name correctly

* Restyled by clang-format

---------

Co-authored-by: Restyled.io <commits@restyled.io>
  • Loading branch information
nivi-apple and restyled-commits authored Mar 21, 2024
1 parent 53273d1 commit 88afa33
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 110 deletions.
105 changes: 60 additions & 45 deletions src/platform/Darwin/DnssdContexts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,27 +458,35 @@ ResolveContext::ResolveContext(void * cbContext, DnssdResolveCallback cb, chip::
std::shared_ptr<uint32_t> && consumerCounterToUse) :
browseThatCausedResolve(browseCausingResolve)
{
type = ContextType::Resolve;
context = cbContext;
callback = cb;
protocol = GetProtocol(cbAddressType);
instanceName = instanceNameToResolve;
consumerCounter = std::move(consumerCounterToUse);
type = ContextType::Resolve;
context = cbContext;
callback = cb;
protocol = GetProtocol(cbAddressType);
instanceName = instanceNameToResolve;
consumerCounter = std::move(consumerCounterToUse);
hasSrpTimerStarted = false;
}

ResolveContext::ResolveContext(CommissioningResolveDelegate * delegate, chip::Inet::IPAddressType cbAddressType,
const char * instanceNameToResolve, std::shared_ptr<uint32_t> && consumerCounterToUse) :
browseThatCausedResolve(nullptr)
{
type = ContextType::Resolve;
context = delegate;
callback = nullptr;
protocol = GetProtocol(cbAddressType);
instanceName = instanceNameToResolve;
consumerCounter = std::move(consumerCounterToUse);
type = ContextType::Resolve;
context = delegate;
callback = nullptr;
protocol = GetProtocol(cbAddressType);
instanceName = instanceNameToResolve;
consumerCounter = std::move(consumerCounterToUse);
hasSrpTimerStarted = false;
}

ResolveContext::~ResolveContext() {}
ResolveContext::~ResolveContext()
{
if (this->hasSrpTimerStarted)
{
CancelSrpTimer(this);
}
}

void ResolveContext::DispatchFailure(const char * errorStr, CHIP_ERROR err)
{
Expand Down Expand Up @@ -526,8 +534,7 @@ void ResolveContext::DispatchSuccess()

for (auto interfaceIndex : priorityInterfaceIndices)
{
// Try finding interfaces for domains kLocalDot and kOpenThreadDot and delete them.
if (TryReportingResultsForInterfaceIndex(static_cast<uint32_t>(interfaceIndex), std::string(kLocalDot)))
if (TryReportingResultsForInterfaceIndex(interfaceIndex))
{
if (needDelete)
{
Expand All @@ -536,7 +543,7 @@ void ResolveContext::DispatchSuccess()
return;
}

if (TryReportingResultsForInterfaceIndex(static_cast<uint32_t>(interfaceIndex), std::string(kOpenThreadDot)))
if (TryReportingResultsForInterfaceIndex(interfaceIndex))
{
if (needDelete)
{
Expand All @@ -548,7 +555,8 @@ void ResolveContext::DispatchSuccess()

for (auto & interface : interfaces)
{
if (TryReportingResultsForInterfaceIndex(interface.first.first, interface.first.second))
auto interfaceId = interface.first.first;
if (TryReportingResultsForInterfaceIndex(interfaceId))
{
break;
}
Expand All @@ -560,52 +568,60 @@ void ResolveContext::DispatchSuccess()
}
}

bool ResolveContext::TryReportingResultsForInterfaceIndex(uint32_t interfaceIndex, std::string domainName)
bool ResolveContext::TryReportingResultsForInterfaceIndex(uint32_t interfaceIndex)
{
if (interfaceIndex == 0)
{
// Not actually an interface we have.
return false;
}

std::pair<uint32_t, std::string> interfaceKey = std::make_pair(interfaceIndex, domainName);
auto & interface = interfaces[interfaceKey];
auto & ips = interface.addresses;

// Some interface may not have any ips, just ignore them.
if (ips.size() == 0)
std::map<std::pair<uint32_t, std::string>, InterfaceInfo>::iterator iter = interfaces.begin();
while (iter != interfaces.end())
{
return false;
}
std::pair<uint32_t, std::string> key = iter->first;
if (key.first == interfaceIndex)
{
auto & interface = interfaces[key];
auto & ips = interface.addresses;

ChipLogProgress(Discovery, "Mdns: Resolve success on interface %" PRIu32, interfaceIndex);
// Some interface may not have any ips, just ignore them.
if (ips.size() == 0)
{
return false;
}

auto & service = interface.service;
auto addresses = Span<Inet::IPAddress>(ips.data(), ips.size());
if (nullptr == callback)
{
auto delegate = static_cast<CommissioningResolveDelegate *>(context);
DiscoveredNodeData nodeData;
service.ToDiscoveredNodeData(addresses, nodeData);
delegate->OnNodeDiscovered(nodeData);
}
else
{
callback(context, &service, addresses, CHIP_NO_ERROR);
}
ChipLogProgress(Discovery, "Mdns: Resolve success on interface %" PRIu32, interfaceIndex);

return true;
auto & service = interface.service;
auto addresses = Span<Inet::IPAddress>(ips.data(), ips.size());
if (nullptr == callback)
{
auto delegate = static_cast<CommissioningResolveDelegate *>(context);
DiscoveredNodeData nodeData;
service.ToDiscoveredNodeData(addresses, nodeData);
delegate->OnNodeDiscovered(nodeData);
}
else
{
callback(context, &service, addresses, CHIP_NO_ERROR);
}

return true;
}
}
return false;
}

CHIP_ERROR ResolveContext::OnNewAddress(const std::pair<uint32_t, std::string> interfaceKey, const struct sockaddr * address)
CHIP_ERROR ResolveContext::OnNewAddress(const std::pair<uint32_t, std::string> & interfaceKey, const struct sockaddr * address)
{
// If we don't have any information about this interfaceId, just ignore the
// address, since it won't be usable anyway without things like the port.
// This can happen if "local" is set up as a search domain in the DNS setup
// on the system, because the hostnames we are looking up all end in
// ".local". In other words, we can get regular DNS results in here, not
// just DNS-SD ones.
uint32_t interfaceId = interfaceKey.first;
auto interfaceId = interfaceKey.first;

if (interfaces.find(interfaceKey) == interfaces.end())
{
Expand Down Expand Up @@ -720,8 +736,7 @@ void ResolveContext::OnNewInterface(uint32_t interfaceId, const char * fullname,
}

std::pair<uint32_t, std::string> interfaceKey = std::make_pair(interfaceId, domainFromHostname);

interfaces.insert(std::make_pair(interfaceKey, std::move(interface)));
interfaces.insert(std::make_pair(std::move(interfaceKey), std::move(interface)));
}

bool ResolveContext::HasInterface()
Expand Down
121 changes: 67 additions & 54 deletions src/platform/Darwin/DnssdImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@ using namespace chip::Dnssd::Internal;

namespace {

// The extra time in milliseconds that we will wait for the resolution on the open thread domain to complete.
constexpr uint16_t kOpenThreadTimeoutInMsec = 250;
constexpr char kLocalDot[] = "local.";

constexpr char kSrpDot[] = "default.service.arpa.";

// The extra time in milliseconds that we will wait for the resolution on the srp domain to complete.
constexpr uint16_t kSrpTimeoutInMsec = 250;

constexpr DNSServiceFlags kRegisterFlags = kDNSServiceFlagsNoAutoRename;
constexpr DNSServiceFlags kBrowseFlags = kDNSServiceFlagsShareConnection;
Expand Down Expand Up @@ -144,59 +148,57 @@ std::string GetDomainFromHostName(const char * hostnameWithDomain)
{
std::string hostname = std::string(hostnameWithDomain);

// Find the last occurence of '.'
size_t last_pos = hostname.find_last_of(".");
if (last_pos != std::string::npos)
{
// Get a substring without last '.'
std::string substring = hostname.substr(0, last_pos);

// Find the last occurence of '.' in the substring created above.
size_t pos = substring.find_last_of(".");
if (pos != std::string::npos)
{
// Return the domain name between the last 2 occurences of '.' including the trailing dot'.'.
return std::string(hostname.substr(pos + 1, last_pos));
}
}
return std::string();
}
// Find the first occurence of '.'
size_t first_pos = hostname.find(".");

Global<MdnsContexts> MdnsContexts::sInstance;
// if not found, return empty string
VerifyOrReturnValue(first_pos != std::string::npos, std::string());

namespace {
// Get a substring after the first occurence of '.' to the end of the string
return hostname.substr(first_pos + 1, hostname.size());
}

/**
* @brief Callback that is called when the timeout for resolving on the kOpenThreadDot domain has expired.
* @brief Callback that is called when the timeout for resolving on the kSrpDot domain has expired.
*
* @param[in] systemLayer The system layer.
* @param[in] callbackContext The context passed to the timer callback.
*/
void OpenThreadTimerExpiredCallback(System::Layer * systemLayer, void * callbackContext)
void SrpTimerExpiredCallback(System::Layer * systemLayer, void * callbackContext)
{
ChipLogProgress(Discovery, "Mdns: Timer expired for resolve to complete on the open thread domain.");
ChipLogProgress(Discovery, "Mdns: Timer expired for resolve to complete on the srp domain.");
auto sdCtx = static_cast<ResolveContext *>(callbackContext);
VerifyOrDie(sdCtx != nullptr);

if (sdCtx->hasOpenThreadTimerStarted)
{
sdCtx->Finalize();
}
sdCtx->Finalize();
}

/**
* @brief Starts a timer to wait for the resolution on the kOpenThreadDot domain to happen.
* @brief Starts a timer to wait for the resolution on the kSrpDot domain to happen.
*
* @param[in] timeoutSeconds The timeout in seconds.
* @param[in] ResolveContext The resolve context.
*/
void StartOpenThreadTimer(uint16_t timeoutInMSecs, ResolveContext * ctx)
CHIP_ERROR StartSrpTimer(uint16_t timeoutInMSecs, ResolveContext * ctx)
{
VerifyOrReturn(ctx != nullptr, ChipLogError(Discovery, "Can't schedule open thread timer since context is null"));
DeviceLayer::SystemLayer().StartTimer(System::Clock::Milliseconds16(timeoutInMSecs), OpenThreadTimerExpiredCallback,
reinterpret_cast<void *>(ctx));
VerifyOrReturnValue(ctx != nullptr, CHIP_ERROR_INCORRECT_STATE);
return DeviceLayer::SystemLayer().StartTimer(System::Clock::Milliseconds16(timeoutInMSecs), SrpTimerExpiredCallback,
reinterpret_cast<void *>(ctx));
}

/**
* @brief Cancels the timer that was started to wait for the resolution on the kSrpDot domain to happen.
*
* @param[in] ResolveContext The resolve context.
*/
void CancelSrpTimer(ResolveContext * ctx)
{
DeviceLayer::SystemLayer().CancelTimer(SrpTimerExpiredCallback, reinterpret_cast<void *>(ctx));
}

Global<MdnsContexts> MdnsContexts::sInstance;

namespace {

static void OnRegister(DNSServiceRef sdRef, DNSServiceFlags flags, DNSServiceErrorType err, const char * name, const char * type,
const char * domain, void * context)
{
Expand Down Expand Up @@ -248,17 +250,17 @@ CHIP_ERROR Browse(BrowseHandler * sdCtx, uint32_t interfaceId, const char * type
auto err = DNSServiceCreateConnection(&sdCtx->serviceRef);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));

// We will browse on both the local domain and the open thread domain.
// We will browse on both the local domain and the srp domain.
ChipLogProgress(Discovery, "Browsing for: %s on domain %s", StringOrNullMarker(type), kLocalDot);

auto sdRefLocal = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
err = DNSServiceBrowse(&sdRefLocal, kBrowseFlags, interfaceId, type, kLocalDot, OnBrowse, sdCtx);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));

ChipLogProgress(Discovery, "Browsing for: %s on domain %s", StringOrNullMarker(type), kOpenThreadDot);
ChipLogProgress(Discovery, "Browsing for: %s on domain %s", StringOrNullMarker(type), kSrpDot);

DNSServiceRef sdRefOpenThread = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
err = DNSServiceBrowse(&sdRefOpenThread, kBrowseFlags, interfaceId, type, kOpenThreadDot, OnBrowse, sdCtx);
auto sdRefSrp = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
err = DNSServiceBrowse(&sdRefSrp, kBrowseFlags, interfaceId, type, kSrpDot, OnBrowse, sdCtx);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));

return MdnsContexts::GetInstance().Add(sdCtx, sdCtx->serviceRef);
Expand Down Expand Up @@ -307,25 +309,37 @@ static void OnGetAddrInfo(DNSServiceRef sdRef, DNSServiceFlags flags, uint32_t i
{
VerifyOrReturn(sdCtx->HasAddress(), sdCtx->Finalize(kDNSServiceErr_BadState));

if (domainName.compare(kOpenThreadDot) == 0)
if (domainName.compare(kSrpDot) == 0)
{
ChipLogProgress(Discovery, "Mdns: Resolve completed on the open thread domain.");
ChipLogProgress(Discovery, "Mdns: Resolve completed on the srp domain.");

// Cancel the timer if one has been started
if (sdCtx->hasSrpTimerStarted)
{
CancelSrpTimer(sdCtx);
}
sdCtx->Finalize();
}
else if (domainName.compare(kLocalDot) == 0)
{
ChipLogProgress(
Discovery,
"Mdns: Resolve completed on the local domain. Starting a timer for the open thread resolve to come back");
ChipLogProgress(Discovery,
"Mdns: Resolve completed on the local domain. Starting a timer for the srp resolve to come back");

// Usually the resolution on the local domain is quicker than on the open thread domain. We would like to give the
// resolution on the open thread domain around 250 millisecs more to give it a chance to resolve before finalizing
// Usually the resolution on the local domain is quicker than on the srp domain. We would like to give the
// resolution on the srp domain around 250 millisecs more to give it a chance to resolve before finalizing
// the resolution.
if (!sdCtx->hasOpenThreadTimerStarted)
if (!sdCtx->hasSrpTimerStarted)
{
// Schedule a timer to allow the resolve on OpenThread domain to complete.
StartOpenThreadTimer(kOpenThreadTimeoutInMsec, sdCtx);
sdCtx->hasOpenThreadTimerStarted = true;
// Schedule a timer to allow the resolve on Srp domain to complete.
CHIP_ERROR error = StartSrpTimer(kSrpTimeoutInMsec, sdCtx);

// If the timer fails to start, finalize the context and return.
if (error != CHIP_NO_ERROR)
{
sdCtx->Finalize();
return;
}
sdCtx->hasSrpTimerStarted = true;
}
}
}
Expand Down Expand Up @@ -367,8 +381,7 @@ static void OnResolve(DNSServiceRef sdRef, DNSServiceFlags flags, uint32_t inter
if (!sdCtx->isResolveRequested)
{
GetAddrInfo(sdCtx);
sdCtx->isResolveRequested = true;
sdCtx->hasOpenThreadTimerStarted = false;
sdCtx->isResolveRequested = true;
}
}
}
Expand All @@ -382,13 +395,13 @@ static CHIP_ERROR Resolve(ResolveContext * sdCtx, uint32_t interfaceId, chip::In
auto err = DNSServiceCreateConnection(&sdCtx->serviceRef);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));

// Similar to browse, will try to resolve using both the local domain and the open thread domain.
// Similar to browse, will try to resolve using both the local domain and the srp domain.
auto sdRefLocal = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
err = DNSServiceResolve(&sdRefLocal, kResolveFlags, interfaceId, name, type, kLocalDot, OnResolve, sdCtx);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));

auto sdRefOpenThread = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
err = DNSServiceResolve(&sdRefOpenThread, kResolveFlags, interfaceId, name, type, kOpenThreadDot, OnResolve, sdCtx);
auto sdRefSrp = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
err = DNSServiceResolve(&sdRefSrp, kResolveFlags, interfaceId, name, type, kSrpDot, OnResolve, sdCtx);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));

auto retval = MdnsContexts::GetInstance().Add(sdCtx, sdCtx->serviceRef);
Expand Down
Loading

0 comments on commit 88afa33

Please sign in to comment.