Skip to content

Commit

Permalink
host reg changes
Browse files Browse the repository at this point in the history
  • Loading branch information
abtink committed Sep 9, 2023
1 parent 3b42a7f commit 8a646ee
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 139 deletions.
204 changes: 93 additions & 111 deletions src/mdns/mdns_mdnssd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,59 +461,115 @@ void PublisherMDnsSd::DnssdServiceRegistration::HandleRegisterResult(DNSServiceF
}
}

PublisherMDnsSd::DnssdHostRegistration::~DnssdHostRegistration(void)
otbrError PublisherMDnsSd::DnssdHostRegistration::Register(void)
{
int dnsError;
DNSServiceErrorType dnsError = kDNSServiceErr_NoError;

VerifyOrExit(mServiceRef != nullptr);
otbrLogInfo("Registering new host %s", mName.c_str());

for (const auto &recordRefAndAddress : GetRecordRefMap())
for (const Ip6Address &address : mAddresses)
{
const DNSRecordRef &recordRef = recordRefAndAddress.first;
const Ip6Address &address = recordRefAndAddress.second;
if (IsCompleted())
DNSRecordRef recordRef = nullptr;

dnsError = DNSServiceRegisterRecord(GetPublisher().mHostsRef, &recordRef, kDNSServiceFlagsShared,
kDNSServiceInterfaceIndexAny, MakeFullHostName(mName).c_str(),
kDNSServiceType_AAAA, kDNSServiceClass_IN, sizeof(address.m8), address.m8,
/* ttl */ 0, HandleRegisterResult, this);

VerifyOrExit(dnsError == kDNSServiceErr_NoError);

mAddrRecordRefs.push_back(recordRef);
mAddrRegistered.push_back(false);
}

exit:
if ((dnsError != kDNSServiceErr_NoError) || mAddresses.empty())
{
HandleRegisterResult(/* aRecordRef */ nullptr, dnsError);
}

return GetPublisher().DnsErrorToOtbrError(dnsError);
}

void PublisherMDnsSd::DnssdHostRegistration::Unregister(void)
{
DNSServiceErrorType dnsError;

VerifyOrExit(GetPublisher().mHostsRef != nullptr);

for (size_t index = 0; index < mAddrRecordRefs.size(); index++)
{
const Ip6Address &address = mAddresses[index];

if (mAddrRegistered[index])
{
// The Bonjour mDNSResponder somehow doesn't send goodbye message for the AAAA record when it is
// removed by `DNSServiceRemoveRecord`. Per RFC 6762, a goodbye message of a record sets its TTL
// to zero but the receiver should record the TTL of 1 and flushes the cache 1 second later. Here
// we remove the AAAA record after updating its TTL to 1 second. This has the same effect as
// sending a goodbye message.
// TODO: resolve the goodbye issue with Bonjour mDNSResponder.
dnsError = DNSServiceUpdateRecord(mServiceRef, recordRef, kDNSServiceFlagsUnique, sizeof(address.m8),
address.m8, /* ttl */ 1);
dnsError = DNSServiceUpdateRecord(GetPublisher().mHostsRef, mAddrRecordRefs[index], kDNSServiceFlagsUnique,
sizeof(address.m8), address.m8, /* ttl */ 1);
otbrLogResult(DNSErrorToOtbrError(dnsError), "Send goodbye message for host %s address %s: %s",
MakeFullHostName(mName).c_str(), address.ToString().c_str(), DNSErrorToString(dnsError));
}
dnsError = DNSServiceRemoveRecord(mServiceRef, recordRef, /* flags */ 0);

dnsError = DNSServiceRemoveRecord(GetPublisher().mHostsRef, mAddrRecordRefs[index], /* flags */ 0);

otbrLogResult(DNSErrorToOtbrError(dnsError), "Remove record for host %s address %s: %s",
MakeFullHostName(mName).c_str(), address.ToString().c_str(), DNSErrorToString(dnsError));
// TODO: ?
// DNSRecordRefDeallocate(recordRef);
}

exit:
return;
mAddrRegistered.clear();
mAddrRecordRefs.clear();
}

PublisherMDnsSd::DnssdHostRegistration *PublisherMDnsSd::FindHostRegistration(const DNSServiceRef &aServiceRef,
const DNSRecordRef &aRecordRef)
void PublisherMDnsSd::DnssdHostRegistration::HandleRegisterResult(DNSServiceRef aServiceRef,
DNSRecordRef aRecordRef,
DNSServiceFlags aFlags,
DNSServiceErrorType aError,
void *aContext)
{
DnssdHostRegistration *hostReg;
OT_UNUSED_VARIABLE(aServiceRef);
OT_UNUSED_VARIABLE(aFlags);

static_cast<DnssdHostRegistration *>(aContext)->HandleRegisterResult(aRecordRef, aError);
}

for (auto &kv : mHostRegistrations)
void PublisherMDnsSd::DnssdHostRegistration::HandleRegisterResult(DNSRecordRef aRecordRef, DNSServiceErrorType aError)
{
if (aError != kDNSServiceErr_NoError)
{
hostReg = static_cast<DnssdHostRegistration *>(kv.second.get());
otbrLogErr("Failed to register host %s: %s", mName.c_str(), DNSErrorToString(aError));
GetPublisher().RemoveHostRegistration(mName, DNSErrorToOtbrError(aError));
}
else
{
bool shouldComplete = !IsCompleted();

if ((hostReg->mServiceRef == aServiceRef) && hostReg->mRecordRefMap.count(aRecordRef))
for (size_t index = 0; index < mAddrRecordRefs.size(); index++)
{
ExitNow();
if ((mAddrRecordRefs[index] == aRecordRef) && !mAddrRegistered[index])
{
mAddrRegistered[index] = true;
otbrLogInfo("Successfully registered host %s address %s", mName.c_str(),
mAddresses[index].ToString().c_str());
}

if (!mAddrRegistered[index])
{
shouldComplete = false;
}
}
}

hostReg = nullptr;

exit:
return hostReg;
if (shouldComplete)
{
otbrLogInfo("Successfully registered all host %s addresses", mName.c_str());
Complete(OTBR_ERROR_NONE);
}
}
}

otbrError PublisherMDnsSd::PublishServiceImpl(const std::string &aHostName,
Expand Down Expand Up @@ -565,54 +621,26 @@ otbrError PublisherMDnsSd::PublishHostImpl(const std::string &aName,
const AddressList &aAddresses,
ResultCallback &&aCallback)
{
otbrError ret = OTBR_ERROR_NONE;
int error = 0;
std::string fullName;
std::unique_ptr<DnssdHostRegistration> registration;

VerifyOrExit(mState == Publisher::State::kReady, ret = OTBR_ERROR_INVALID_STATE);

fullName = MakeFullHostName(aName);

aCallback = HandleDuplicateHostRegistration(aName, aAddresses, std::move(aCallback));
VerifyOrExit(!aCallback.IsNull());
VerifyOrExit(!aAddresses.empty(), std::move(aCallback)(OTBR_ERROR_NONE));
otbrError error = OTBR_ERROR_NONE;
DnssdHostRegistration *hostReg;

if (mHostsRef == nullptr)
if (mState != State::kReady)
{
SuccessOrExit(error = DNSServiceCreateConnection(&mHostsRef));
otbrLogDebug("Created new DNSServiceRef for hosts: %p", mHostsRef);
error = OTBR_ERROR_INVALID_STATE;
std::move(aCallback)(error);
ExitNow();
}

registration.reset(new DnssdHostRegistration(aName, aAddresses, std::move(aCallback), mHostsRef, this));
aCallback = HandleDuplicateHostRegistration(aName, aAddresses, std::move(aCallback));
VerifyOrExit(!aCallback.IsNull());

otbrLogInfo("Registering new host %s", aName.c_str());
for (const auto &address : aAddresses)
{
DNSRecordRef recordRef = nullptr;
// Supports only IPv6 for now, may support IPv4 in the future.
SuccessOrExit(error = DNSServiceRegisterRecord(mHostsRef, &recordRef, kDNSServiceFlagsShared,
kDNSServiceInterfaceIndexAny, fullName.c_str(),
kDNSServiceType_AAAA, kDNSServiceClass_IN, sizeof(address.m8),
address.m8, /* ttl */ 0, HandleRegisterHostResult, this));
registration->GetRecordRefMap()[recordRef] = address;
}
hostReg = new DnssdHostRegistration(aName, aAddresses, std::move(aCallback), this);
AddHostRegistration(std::unique_ptr<DnssdHostRegistration>(hostReg));

AddHostRegistration(std::move(registration));
error = hostReg->Register();

exit:
if (error != kDNSServiceErr_NoError || ret != OTBR_ERROR_NONE)
{
if (error != kDNSServiceErr_NoError)
{
ret = DNSErrorToOtbrError(error);
otbrLogErr("Failed to publish/update host %s for mdnssd error: %s!", aName.c_str(),
DNSErrorToString(error));
}

std::move(aCallback)(ret);
}
return ret;
return error;
}

void PublisherMDnsSd::UnpublishHost(const std::string &aName, ResultCallback &&aCallback)
Expand All @@ -629,52 +657,6 @@ void PublisherMDnsSd::UnpublishHost(const std::string &aName, ResultCallback &&a
std::move(aCallback)(error);
}

void PublisherMDnsSd::HandleRegisterHostResult(DNSServiceRef aServiceRef,
DNSRecordRef aRecordRef,
DNSServiceFlags aFlags,
DNSServiceErrorType aError,
void *aContext)
{
static_cast<PublisherMDnsSd *>(aContext)->HandleRegisterHostResult(aServiceRef, aRecordRef, aFlags, aError);
}

void PublisherMDnsSd::HandleRegisterHostResult(DNSServiceRef aServiceRef,
DNSRecordRef aRecordRef,
DNSServiceFlags aFlags,
DNSServiceErrorType aError)
{
OTBR_UNUSED_VARIABLE(aFlags);

otbrError error = DNSErrorToOtbrError(aError);
auto *hostReg = static_cast<DnssdHostRegistration *>(FindHostRegistration(aServiceRef, aRecordRef));

std::string hostName;

VerifyOrExit(hostReg != nullptr);

hostName = MakeFullHostName(hostReg->mName);

otbrLogInfo("Received reply for host %s: %s", hostName.c_str(), DNSErrorToString(aError));

if (error == OTBR_ERROR_NONE)
{
--hostReg->mCallbackCount;
if (!hostReg->mCallbackCount)
{
otbrLogInfo("Successfully registered host %s", hostName.c_str());
hostReg->Complete(OTBR_ERROR_NONE);
}
}
else
{
otbrLogWarning("Failed to register host %s for mdnssd error: %s", hostName.c_str(), DNSErrorToString(aError));
RemoveHostRegistration(hostReg->mName, error);
}

exit:
return;
}

// See `regtype` parameter of the DNSServiceRegister() function for more information.
std::string PublisherMDnsSd::MakeRegType(const std::string &aType, SubTypeList aSubTypeList)
{
Expand Down
45 changes: 17 additions & 28 deletions src/mdns/mdns_mdnssd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,26 +132,25 @@ class PublisherMDnsSd : public MainloopProcessor, public Publisher
class DnssdHostRegistration : public HostRegistration
{
public:
DnssdHostRegistration(const std::string &aName,
const AddressList &aAddresses,
ResultCallback &&aCallback,
DNSServiceRef aServiceRef,
Publisher *aPublisher)
: HostRegistration(aName, aAddresses, std::move(aCallback), aPublisher)
, mServiceRef(aServiceRef)
, mRecordRefMap()
, mCallbackCount(aAddresses.size())
{
}
using HostRegistration::HostRegistration; // Inherit base class constructor

~DnssdHostRegistration(void) override { Unregister(); }

~DnssdHostRegistration(void) override;
const DNSServiceRef &GetServiceRef() const { return mServiceRef; }
const std::map<DNSRecordRef, Ip6Address> &GetRecordRefMap() const { return mRecordRefMap; }
std::map<DNSRecordRef, Ip6Address> &GetRecordRefMap() { return mRecordRefMap; }
otbrError Register(void);

private:
void Unregister(void);
PublisherMDnsSd &GetPublisher(void) { return *static_cast<PublisherMDnsSd *>(mPublisher); }
void HandleRegisterResult(DNSRecordRef aRecordRef, DNSServiceErrorType aError);
static void HandleRegisterResult(DNSServiceRef aServiceRef,
DNSRecordRef aRecordRef,
DNSServiceFlags aFlags,
DNSServiceErrorType aErrorCode,
void *aContext);

DNSServiceRef mServiceRef;
std::map<DNSRecordRef, Ip6Address> mRecordRefMap;
uint32_t mCallbackCount;
std::vector<DNSRecordRef> mAddrRecordRefs;
std::vector<bool> mAddrRegistered;
uint32_t mRemainingAddrRegs;
};

struct ServiceRef : private ::NonCopyable
Expand Down Expand Up @@ -312,16 +311,6 @@ class PublisherMDnsSd : public MainloopProcessor, public Publisher
using ServiceSubscriptionList = std::vector<std::unique_ptr<ServiceSubscription>>;
using HostSubscriptionList = std::vector<std::unique_ptr<HostSubscription>>;

static void HandleRegisterHostResult(DNSServiceRef aHostsConnection,
DNSRecordRef aHostRecord,
DNSServiceFlags aFlags,
DNSServiceErrorType aErrorCode,
void *aContext);
void HandleRegisterHostResult(DNSServiceRef aHostsConnection,
DNSRecordRef aHostRecord,
DNSServiceFlags aFlags,
DNSServiceErrorType aErrorCode);

static std::string MakeRegType(const std::string &aType, SubTypeList aSubTypeList);

DnssdHostRegistration *FindHostRegistration(const DNSServiceRef &aServiceRef, const DNSRecordRef &aRecordRef);
Expand Down

0 comments on commit 8a646ee

Please sign in to comment.