diff --git a/src/common/types.hpp b/src/common/types.hpp index 649f0b6371d..7ba0065ba33 100644 --- a/src/common/types.hpp +++ b/src/common/types.hpp @@ -402,11 +402,13 @@ struct MdnsTelemetryInfo "kEmaFactorDenominator must be greater than kEmaFactorNumerator"); MdnsResponseCounters mHostRegistrations; + MdnsResponseCounters mKeyRegistrations; MdnsResponseCounters mServiceRegistrations; MdnsResponseCounters mHostResolutions; MdnsResponseCounters mServiceResolutions; uint32_t mHostRegistrationEmaLatency; ///< The EMA latency of host registrations in milliseconds + uint32_t mKeyRegistrationEmaLatency; ///< The EMA latency of key registrations in milliseconds uint32_t mServiceRegistrationEmaLatency; ///< The EMA latency of service registrations in milliseconds uint32_t mHostResolutionEmaLatency; ///< The EMA latency of host resolutions in milliseconds uint32_t mServiceResolutionEmaLatency; ///< The EMA latency of service resolutions in milliseconds diff --git a/src/mdns/mdns.cpp b/src/mdns/mdns.cpp index fc2cf8a31fc..23e06af8a74 100644 --- a/src/mdns/mdns.cpp +++ b/src/mdns/mdns.cpp @@ -81,6 +81,19 @@ void Publisher::PublishHost(const std::string &aName, } } +void Publisher::PublishKey(const std::string &aName, const KeyData &aKeyData, ResultCallback &&aCallback) +{ + otbrError error; + + mKeyRegistrationBeginTime[aName] = Clock::now(); + + error = PublishKeyImpl(aName, aKeyData, std::move(aCallback)); + if (error != OTBR_ERROR_NONE) + { + UpdateMdnsResponseCounters(mTelemetryInfo.mKeyRegistrations, error); + } +} + void Publisher::OnServiceResolveFailed(std::string aType, std::string aInstanceName, int32_t aErrorCode) { UpdateMdnsResponseCounters(mTelemetryInfo.mServiceResolutions, DnsErrorToOtbrError(aErrorCode)); @@ -296,6 +309,11 @@ std::string Publisher::MakeFullHostName(const std::string &aName) return aName + ".local"; } +std::string Publisher::MakeFullKeyName(const std::string &aName) +{ + return aName + ".local"; +} + void Publisher::AddServiceRegistration(ServiceRegistrationPtr &&aServiceReg) { mServiceRegistrations.emplace(MakeFullServiceName(aServiceReg->mName, aServiceReg->mType), std::move(aServiceReg)); @@ -437,6 +455,75 @@ Publisher::HostRegistration *Publisher::FindHostRegistration(const std::string & return it != mHostRegistrations.end() ? it->second.get() : nullptr; } +Publisher::ResultCallback Publisher::HandleDuplicateKeyRegistration(const std::string &aName, + const KeyData &aKeyData, + ResultCallback &&aCallback) +{ + KeyRegistration *keyReg = FindKeyRegistration(aName); + + VerifyOrExit(keyReg != nullptr); + + if (keyReg->IsOutdated(aName, aKeyData)) + { + otbrLogInfo("Removing existing key %s: outdated", aName.c_str()); + RemoveKeyRegistration(keyReg->mName, OTBR_ERROR_ABORTED); + } + else if (keyReg->IsCompleted()) + { + // Returns success if the same key has already been + // registered with exactly the same parameters. + std::move(aCallback)(OTBR_ERROR_NONE); + } + else + { + // If the same key is being registered with the same parameters, + // let's join the waiting queue for the result. + keyReg->mCallback = std::bind( + [](std::shared_ptr aExistingCallback, std::shared_ptr aNewCallback, + otbrError aError) { + std::move (*aExistingCallback)(aError); + std::move (*aNewCallback)(aError); + }, + std::make_shared(std::move(keyReg->mCallback)), + std::make_shared(std::move(aCallback)), std::placeholders::_1); + } + +exit: + return std::move(aCallback); +} + +void Publisher::AddKeyRegistration(KeyRegistrationPtr &&aKeyReg) +{ + mKeyRegistrations.emplace(MakeFullKeyName(aKeyReg->mName), std::move(aKeyReg)); +} + +void Publisher::RemoveKeyRegistration(const std::string &aName, otbrError aError) +{ + auto it = mKeyRegistrations.find(MakeFullKeyName(aName)); + KeyRegistrationPtr keyReg; + + otbrLogInfo("Removing key %s", aName.c_str()); + VerifyOrExit(it != mKeyRegistrations.end()); + + // Keep the KeyRegistration around before calling `Complete` + // to invoke the callback. This is for avoiding invalid access + // to the KeyRegistration when it's freed from the callback. + keyReg = std::move(it->second); + mKeyRegistrations.erase(it); + keyReg->Complete(aError); + otbrLogInfo("Removed key %s", aName.c_str()); + +exit: + return; +} + +Publisher::KeyRegistration *Publisher::FindKeyRegistration(const std::string &aName) +{ + auto it = mKeyRegistrations.find(MakeFullKeyName(aName)); + + return it != mKeyRegistrations.end() ? it->second.get() : nullptr; +} + Publisher::Registration::~Registration(void) { TriggerCompleteCallback(OTBR_ERROR_ABORTED); @@ -488,6 +575,26 @@ void Publisher::HostRegistration::OnComplete(otbrError aError) } } +bool Publisher::KeyRegistration::IsOutdated(const std::string &aName, const KeyData &aKeyData) const +{ + return !(mName == aName && mKeyData == aKeyData); +} + +void Publisher::KeyRegistration::Complete(otbrError aError) +{ + OnComplete(aError); + Registration::TriggerCompleteCallback(aError); +} + +void Publisher::KeyRegistration::OnComplete(otbrError aError) +{ + if (!IsCompleted()) + { + mPublisher->UpdateMdnsResponseCounters(mPublisher->mTelemetryInfo.mKeyRegistrations, aError); + mPublisher->UpdateKeyRegistrationEmaLatency(mName, aError); + } +} + void Publisher::UpdateMdnsResponseCounters(otbr::MdnsResponseCounters &aCounters, otbrError aError) { switch (aError) @@ -566,6 +673,18 @@ void Publisher::UpdateHostRegistrationEmaLatency(const std::string &aHostName, o } } +void Publisher::UpdateKeyRegistrationEmaLatency(const std::string &aKeyName, otbrError aError) +{ + auto it = mKeyRegistrationBeginTime.find(aKeyName); + + if (it != mKeyRegistrationBeginTime.end()) + { + uint32_t latency = std::chrono::duration_cast(Clock::now() - it->second).count(); + UpdateEmaLatency(mTelemetryInfo.mKeyRegistrationEmaLatency, latency, aError); + mKeyRegistrationBeginTime.erase(it); + } +} + void Publisher::UpdateServiceInstanceResolutionEmaLatency(const std::string &aInstanceName, const std::string &aType, otbrError aError) diff --git a/src/mdns/mdns.hpp b/src/mdns/mdns.hpp index a4de9181dde..b90d3596dd4 100644 --- a/src/mdns/mdns.hpp +++ b/src/mdns/mdns.hpp @@ -118,6 +118,7 @@ class Publisher : private NonCopyable typedef std::vector TxtList; typedef std::vector SubTypeList; typedef std::vector AddressList; + typedef std::vector KeyData; /** * This structure represents information of a discovered service instance. @@ -266,6 +267,29 @@ class Publisher : private NonCopyable */ virtual void UnpublishHost(const std::string &aName, ResultCallback &&aCallback) = 0; + /** + * This method publishes or updates a key record for a name. + * + * @param[in] aName The name associated with key record (can be host name or service instance name). + * @param[in] aKeyData The key data to publish. + * @param[in] aCallback The callback for receiving the publishing result.`OTBR_ERROR_NONE` will be + * returned if the operation is successful and all other values indicate a + * failure. Specifically, `OTBR_ERROR_DUPLICATED` indicates that the name has + * already been published and the caller can re-publish with a new name if an + * alternative name is available/acceptable. + * + */ + void PublishKey(const std::string &aName, const KeyData &aKeyData, ResultCallback &&aCallback); + + /** + * This method un-publishes a key record + * + * @param[in] aName The name associated with key record. + * @param[in] aCallback The callback for receiving the publishing result. + * + */ + virtual void UnpublishKey(const std::string &aName, ResultCallback &&aCallback) = 0; + /** * This method subscribes a given service or service instance. * @@ -499,15 +523,42 @@ class Publisher : private NonCopyable bool IsOutdated(const std::string &aName, const std::vector &aAddresses) const; }; + class KeyRegistration : public Registration + { + public: + std::string mName; + KeyData mKeyData; + + KeyRegistration(std::string aName, KeyData aKeyData, ResultCallback &&aCallback, Publisher *aPublisher) + : Registration(std::move(aCallback), aPublisher) + , mName(std::move(aName)) + , mKeyData(std::move(aKeyData)) + { + } + + ~KeyRegistration(void) { OnComplete(OTBR_ERROR_ABORTED); } + + void Complete(otbrError aError); + + // Tells whether this `KeyRegistration` object is outdated comparing to the given parameters. + bool IsOutdated(const std::string &aName, const KeyData &aKeyData) const; + + private: + void OnComplete(otbrError aError); + }; + using ServiceRegistrationPtr = std::unique_ptr; using ServiceRegistrationMap = std::map; using HostRegistrationPtr = std::unique_ptr; using HostRegistrationMap = std::map; + using KeyRegistrationPtr = std::unique_ptr; + using KeyRegistrationMap = std::map; static SubTypeList SortSubTypeList(SubTypeList aSubTypeList); static AddressList SortAddressList(AddressList aAddressList); static std::string MakeFullServiceName(const std::string &aName, const std::string &aType); static std::string MakeFullHostName(const std::string &aName); + static std::string MakeFullKeyName(const std::string &aName); virtual otbrError PublishServiceImpl(const std::string &aHostName, const std::string &aName, @@ -515,14 +566,17 @@ class Publisher : private NonCopyable const SubTypeList &aSubTypeList, uint16_t aPort, const TxtData &aTxtData, - ResultCallback &&aCallback) = 0; + ResultCallback &&aCallback) = 0; virtual otbrError PublishHostImpl(const std::string &aName, const std::vector &aAddresses, - ResultCallback &&aCallback) = 0; - virtual void OnServiceResolveFailedImpl(const std::string &aType, - const std::string &aInstanceName, - int32_t aErrorCode) = 0; - virtual void OnHostResolveFailedImpl(const std::string &aHostName, int32_t aErrorCode) = 0; + ResultCallback &&aCallback) = 0; + + virtual otbrError PublishKeyImpl(const std::string &aName, const KeyData &aKeyData, ResultCallback &&aCallback) = 0; + + virtual void OnServiceResolveFailedImpl(const std::string &aType, + const std::string &aInstanceName, + int32_t aErrorCode) = 0; + virtual void OnHostResolveFailedImpl(const std::string &aHostName, int32_t aErrorCode) = 0; virtual otbrError DnsErrorToOtbrError(int32_t aError) = 0; @@ -551,10 +605,18 @@ class Publisher : private NonCopyable const std::vector &aAddresses, ResultCallback &&aCallback); + ResultCallback HandleDuplicateKeyRegistration(const std::string &aName, + const KeyData &aKeyData, + ResultCallback &&aCallback); + void AddHostRegistration(HostRegistrationPtr &&aHostReg); void RemoveHostRegistration(const std::string &aName, otbrError aError); HostRegistration *FindHostRegistration(const std::string &aName); + void AddKeyRegistration(KeyRegistrationPtr &&aKeyReg); + void RemoveKeyRegistration(const std::string &aName, otbrError aError); + KeyRegistration *FindKeyRegistration(const std::string &aName); + static void UpdateMdnsResponseCounters(otbr::MdnsResponseCounters &aCounters, otbrError aError); static void UpdateEmaLatency(uint32_t &aEmaLatency, uint32_t aLatency, otbrError aError); @@ -562,6 +624,7 @@ class Publisher : private NonCopyable const std::string &aType, otbrError aError); void UpdateHostRegistrationEmaLatency(const std::string &aHostName, otbrError aError); + void UpdateKeyRegistrationEmaLatency(const std::string &aKeyName, otbrError aError); void UpdateServiceInstanceResolutionEmaLatency(const std::string &aInstanceName, const std::string &aType, otbrError aError); @@ -569,6 +632,7 @@ class Publisher : private NonCopyable ServiceRegistrationMap mServiceRegistrations; HostRegistrationMap mHostRegistrations; + KeyRegistrationMap mKeyRegistrations; uint64_t mNextSubscriberId = 1; @@ -577,6 +641,8 @@ class Publisher : private NonCopyable std::map, Timepoint> mServiceRegistrationBeginTime; // host name -> the timepoint to begin host registration std::map mHostRegistrationBeginTime; + // key name -> the timepoint to begin key registration + std::map mKeyRegistrationBeginTime; // {instance name, service type} -> the timepoint to begin service resolution std::map, Timepoint> mServiceInstanceResolutionBeginTime; // host name -> the timepoint to begin host resolution diff --git a/src/mdns/mdns_avahi.cpp b/src/mdns/mdns_avahi.cpp index b501b8bd663..2fbb6fc261a 100644 --- a/src/mdns/mdns_avahi.cpp +++ b/src/mdns/mdns_avahi.cpp @@ -524,6 +524,7 @@ void PublisherAvahi::CallHostOrServiceCallback(AvahiEntryGroup *aGroup, otbrErro { ServiceRegistration *serviceReg; HostRegistration *hostReg; + KeyRegistration *keyReg; if ((serviceReg = FindServiceRegistration(aGroup)) != nullptr) { @@ -547,6 +548,17 @@ void PublisherAvahi::CallHostOrServiceCallback(AvahiEntryGroup *aGroup, otbrErro RemoveHostRegistration(hostReg->mName, aError); } } + else if ((keyReg = FindKeyRegistration(aGroup)) != nullptr) + { + if (aError == OTBR_ERROR_NONE) + { + keyReg->Complete(aError); + } + else + { + RemoveKeyRegistration(keyReg->mName, aError); + } + } else { otbrLogWarning("No registered service or host matches avahi group @%p", aGroup); @@ -789,6 +801,64 @@ void PublisherAvahi::UnpublishHost(const std::string &aName, ResultCallback &&aC std::move(aCallback)(error); } +otbrError PublisherAvahi::PublishKeyImpl(const std::string &aName, const KeyData &aKeyData, ResultCallback &&aCallback) +{ + otbrError error = OTBR_ERROR_NONE; + int avahiError = AVAHI_OK; + std::string fullKeyName; + AvahiEntryGroup *group = nullptr; + + VerifyOrExit(mState == State::kReady, error = OTBR_ERROR_INVALID_STATE); + VerifyOrExit(mClient != nullptr, error = OTBR_ERROR_INVALID_STATE); + + aCallback = HandleDuplicateKeyRegistration(aName, aKeyData, std::move(aCallback)); + VerifyOrExit(!aCallback.IsNull()); + + VerifyOrExit((group = CreateGroup(mClient)) != nullptr, error = OTBR_ERROR_MDNS); + + fullKeyName = MakeFullKeyName(aName); + + avahiError = avahi_entry_group_add_record(group, AVAHI_IF_UNSPEC, AVAHI_PROTO_UNSPEC, AvahiPublishFlags{}, + fullKeyName.c_str(), AVAHI_DNS_CLASS_IN, kDnsKeyRecordType, kDefaultTtl, + aKeyData.data(), aKeyData.size()); + VerifyOrExit(avahiError == AVAHI_OK); + + otbrLogInfo("Commit avahi key record for %s", aName.c_str()); + avahiError = avahi_entry_group_commit(group); + VerifyOrExit(avahiError == AVAHI_OK); + + AddKeyRegistration(std::unique_ptr( + new AvahiKeyRegistration(aName, aKeyData, std::move(aCallback), group, this))); + +exit: + if (avahiError != AVAHI_OK || error != OTBR_ERROR_NONE) + { + if (avahiError != AVAHI_OK) + { + error = OTBR_ERROR_MDNS; + otbrLogErr("Failed to publish key record - avahi error: %s!", avahi_strerror(avahiError)); + } + + if (group != nullptr) + { + ReleaseGroup(group); + } + std::move(aCallback)(error); + } + return error; +} + +void PublisherAvahi::UnpublishKey(const std::string &aName, ResultCallback &&aCallback) +{ + otbrError error = OTBR_ERROR_NONE; + + VerifyOrExit(mState == Publisher::State::kReady, error = OTBR_ERROR_INVALID_STATE); + RemoveKeyRegistration(aName, OTBR_ERROR_ABORTED); + +exit: + std::move(aCallback)(error); +} + otbrError PublisherAvahi::TxtDataToAvahiStringList(const TxtData &aTxtData, AvahiStringList *aBuffer, size_t aBufferSize, @@ -870,6 +940,23 @@ Publisher::HostRegistration *PublisherAvahi::FindHostRegistration(const AvahiEnt return result; } +Publisher::KeyRegistration *PublisherAvahi::FindKeyRegistration(const AvahiEntryGroup *aEntryGroup) +{ + KeyRegistration *result = nullptr; + + for (const auto &entry : mKeyRegistrations) + { + const auto &keyReg = static_cast(*entry.second); + if (keyReg.GetEntryGroup() == aEntryGroup) + { + result = entry.second.get(); + break; + } + } + + return result; +} + void PublisherAvahi::SubscribeService(const std::string &aType, const std::string &aInstanceName) { auto service = MakeUnique(*this, aType, aInstanceName); diff --git a/src/mdns/mdns_avahi.hpp b/src/mdns/mdns_avahi.hpp index 29b5ba6730b..f358099a851 100644 --- a/src/mdns/mdns_avahi.hpp +++ b/src/mdns/mdns_avahi.hpp @@ -78,6 +78,7 @@ class PublisherAvahi : public Publisher void UnpublishService(const std::string &aName, const std::string &aType, ResultCallback &&aCallback) override; void UnpublishHost(const std::string &aName, ResultCallback &&aCallback) override; + void UnpublishKey(const std::string &aName, ResultCallback &&aCallback) override; void SubscribeService(const std::string &aType, const std::string &aInstanceName) override; void UnsubscribeService(const std::string &aType, const std::string &aInstanceName) override; void SubscribeHost(const std::string &aHostName) override; @@ -97,6 +98,7 @@ class PublisherAvahi : public Publisher otbrError PublishHostImpl(const std::string &aName, const std::vector &aAddresses, ResultCallback &&aCallback) override; + otbrError PublishKeyImpl(const std::string &aName, const KeyData &aKeyData, ResultCallback &&aCallback) override; void OnServiceResolveFailedImpl(const std::string &aType, const std::string &aInstanceName, int32_t aErrorCode) override; @@ -106,6 +108,7 @@ class PublisherAvahi : public Publisher private: static constexpr size_t kMaxSizeOfTxtRecord = 1024; static constexpr uint32_t kDefaultTtl = 10; // In seconds. + static constexpr uint16_t kDnsKeyRecordType = 25; class AvahiServiceRegistration : public ServiceRegistration { @@ -158,6 +161,25 @@ class PublisherAvahi : public Publisher AvahiEntryGroup *mEntryGroup; }; + class AvahiKeyRegistration : public KeyRegistration + { + public: + AvahiKeyRegistration(const std::string &aName, + const KeyData &aKeyData ResultCallback &&aCallback, + AvahiEntryGroup *aEntryGroup, + PublisherAvahi *aPublisher) + : KeyRegistration(aName, aKeyData, std::move(aCallback), aPublisher) + , mEntryGroup(aEntryGroup) + { + } + + ~AvahiKeyRegistration(void) override; + const AvahiEntryGroup *GetEntryGroup(void) const { return mEntryGroup; } + + private: + AvahiEntryGroup *mEntryGroup; + }; + struct Subscription : private ::NonCopyable { PublisherAvahi *mPublisherAvahi; @@ -347,6 +369,7 @@ class PublisherAvahi : public Publisher ServiceRegistration *FindServiceRegistration(const AvahiEntryGroup *aEntryGroup); HostRegistration *FindHostRegistration(const AvahiEntryGroup *aEntryGroup); + KeyRegistration *FindKeyRegistration(const AvahiEntryGroup *aEntryGroup); AvahiClient *mClient; std::unique_ptr mPoller; diff --git a/src/mdns/mdns_mdnssd.cpp b/src/mdns/mdns_mdnssd.cpp index af3d5468a42..928e3bbe73f 100644 --- a/src/mdns/mdns_mdnssd.cpp +++ b/src/mdns/mdns_mdnssd.cpp @@ -214,7 +214,7 @@ static const char *DNSErrorToString(DNSServiceErrorType aError) } PublisherMDnsSd::PublisherMDnsSd(StateCallback aCallback) - : mHostsRef(nullptr) + : mHostsAndKeysRef(nullptr) , mState(State::kIdle) , mStateCallback(std::move(aCallback)) { @@ -241,17 +241,19 @@ void PublisherMDnsSd::Stop(void) { ServiceRegistrationMap serviceRegistrations; HostRegistrationMap hostRegistrations; + KeyRegistrationMap keyRegistrations; VerifyOrExit(mState == State::kReady); std::swap(mServiceRegistrations, serviceRegistrations); std::swap(mHostRegistrations, hostRegistrations); + std::swap(mKeyRegistrations, keyRegistrations); - if (mHostsRef != nullptr) + if (mHostsAndKeysRef != nullptr) { - DNSServiceRefDeallocate(mHostsRef); - otbrLogDebug("Deallocated DNSServiceRef for hosts: %p", mHostsRef); - mHostsRef = nullptr; + DNSServiceRefDeallocate(mHostsAndKeysRef); + otbrLogDebug("Deallocated DNSServiceRef for hosts and keys: %p", mHostsAndKeysRef); + mHostsAndKeysRef = nullptr; } mSubscribedServices.clear(); @@ -281,9 +283,9 @@ void PublisherMDnsSd::Update(MainloopContext &aMainloop) } } - if (mHostsRef != nullptr) + if (mHostsAndKeysRef != nullptr) { - int fd = DNSServiceRefSockFD(mHostsRef); + int fd = DNSServiceRefSockFD(mHostsAndKeysRef); assert(fd != -1); @@ -318,13 +320,13 @@ void PublisherMDnsSd::Process(const MainloopContext &aMainloop) } } - if (mHostsRef != nullptr) + if (mHostsAndKeysRef != nullptr) { - int fd = DNSServiceRefSockFD(mHostsRef); + int fd = DNSServiceRefSockFD(mHostsAndKeysRef); if (FD_ISSET(fd, &aMainloop.mReadFdSet)) { - readyServices.push_back(mHostsRef); + readyServices.push_back(mHostsAndKeysRef); } } @@ -402,6 +404,30 @@ PublisherMDnsSd::DnssdHostRegistration::~DnssdHostRegistration(void) return; } +PublisherMDnsSd::DnssdKeyRegistration::~DnssdKeyRegistration(void) +{ + int dnsError; + + VerifyOrExit(mServiceRef != nullptr); + + if (IsCompleted()) + { + // Send goodbye message (see comment in `~DnssdHostRegistration`) + dnsError = DNSServiceUpdateRecord(mServiceRef, mRecordRef, kDNSServiceFlagsUnique, mKeyData.size(), + mKeyData.data(), /* ttl */ 1); + otbrLogResult(DNSErrorToOtbrError(dnsError), "Send goodbye message for key %s: %s", mName.c_str(), + DNSErrorToString(dnsError)); + } + + dnsError = DNSServiceRemoveRecord(mServiceRef, mRecordRef, /* flags */ 0); + + otbrLogResult(DNSErrorToOtbrError(dnsError), "Remove key record for %s: %s", mName.c_str(), + DNSErrorToString(dnsError)); + +exit: + return; +} + Publisher::ServiceRegistration *PublisherMDnsSd::FindServiceRegistration(const DNSServiceRef &aServiceRef) { ServiceRegistration *result = nullptr; @@ -441,6 +467,25 @@ Publisher::HostRegistration *PublisherMDnsSd::FindHostRegistration(const DNSServ return result; } +Publisher::KeyRegistration *PublisherMDnsSd::FindKeyRegistration(const DNSServiceRef &aServiceRef, + const DNSRecordRef &aRecordRef) +{ + KeyRegistration *result = nullptr; + + for (auto &entry : mKeyRegistrations) + { + auto &keyReg = static_cast(*entry.second); + + if (keyReg.GetServiceRef() == aServiceRef && keyReg.GetRecordRef() == aRecordRef) + { + result = entry.second.get(); + break; + } + } + + return result; +} + void PublisherMDnsSd::HandleServiceRegisterResult(DNSServiceRef aService, const DNSServiceFlags aFlags, DNSServiceErrorType aError, @@ -556,6 +601,19 @@ void PublisherMDnsSd::UnpublishService(const std::string &aName, const std::stri std::move(aCallback)(error); } +int PublisherMDnsSd::AllocateHostsAndKeysRefIfUnallocated(void) +{ + int dnsError = kDNSServiceErr_NoError; + + VerifyOrExit(mHostsAndKeysRef == nullptr); + + SuccessOrExit(dnsError = DNSServiceCreateConnection(&mHostsAndKeysRef)); + otbrLogDebug("Created new DNSServiceRef for hosts and keys: %p", mHostsAndKeysRef); + +exit: + return dnsError; +} + otbrError PublisherMDnsSd::PublishHostImpl(const std::string &aName, const std::vector &aAddresses, ResultCallback &&aCallback) @@ -573,20 +631,16 @@ otbrError PublisherMDnsSd::PublishHostImpl(const std::string &aName, VerifyOrExit(!aCallback.IsNull()); VerifyOrExit(!aAddresses.empty(), std::move(aCallback)(OTBR_ERROR_NONE)); - if (mHostsRef == nullptr) - { - SuccessOrExit(error = DNSServiceCreateConnection(&mHostsRef)); - otbrLogDebug("Created new DNSServiceRef for hosts: %p", mHostsRef); - } + SuccessOrExit(error = AllocateHostsAndKeysRefIfUnallocated()); - registration = new DnssdHostRegistration(aName, aAddresses, std::move(aCallback), mHostsRef, this); + registration = new DnssdHostRegistration(aName, aAddresses, std::move(aCallback), mHostsAndKeysRef, this); otbrLogInfo("Registering new host %s", aName.c_str()); for (const auto &address : aAddresses) { DNSRecordRef recordRef = nullptr; // Supports only IPv6 for now, may support IPv4 in the future. - SuccessOrExit(error = DNSServiceRegisterRecord(mHostsRef, &recordRef, kDNSServiceFlagsShared, + SuccessOrExit(error = DNSServiceRegisterRecord(mHostsAndKeysRef, &recordRef, kDNSServiceFlagsShared, kDNSServiceInterfaceIndexAny, fullName.c_str(), kDNSServiceType_AAAA, kDNSServiceClass_IN, sizeof(address.m8), address.m8, /* ttl */ 0, HandleRegisterHostResult, this)); @@ -670,6 +724,96 @@ void PublisherMDnsSd::HandleRegisterHostResult(DNSServiceRef aServiceRef, return; } +otbrError PublisherMDnsSd::PublishKeyImpl(const std::string &aName, const KeyData &aKeyData, ResultCallback &&aCallback) +{ + otbrError ret = OTBR_ERROR_NONE; + int error = 0; + std::string fullName; + DNSRecordRef recordRef = nullptr; + + VerifyOrExit(mState == Publisher::State::kReady, ret = OTBR_ERROR_INVALID_STATE); + + fullName = MakeFullKeyName(aName); + + aCallback = HandleDuplicateKeyRegistration(aName, aKeyData, std::move(aCallback)); + VerifyOrExit(!aCallback.IsNull()); + + otbrLogInfo("Registering new key %s", aName.c_str()); + + SuccessOrExit(error = AllocateHostsAndKeysRefIfUnallocated()); + + SuccessOrExit(error = DNSServiceRegisterRecord(mHostsAndKeysRef, &recordRef, kDNSServiceFlagsShared, + kDNSServiceInterfaceIndexAny, fullName.c_str(), kDNSServiceType_KEY, + kDNSServiceClass_IN, aKeyData.size(), aKeyData.data(), /* ttl */ 0, + HandleRegisterKeyResult, this)); + + AddKeyRegistration(std::unique_ptr( + new DnssdKeyRegistration(aName, aKeyData, std::move(aCallback), mHostsAndKeysRef, recordRef, this))); + +exit: + if (error != kDNSServiceErr_NoError || ret != OTBR_ERROR_NONE) + { + if (error != kDNSServiceErr_NoError) + { + ret = DNSErrorToOtbrError(error); + otbrLogErr("Failed to publish/update key for %s mdnssd error: %s!", aName.c_str(), DNSErrorToString(error)); + } + + std::move(aCallback)(ret); + } + return ret; +} + +void PublisherMDnsSd::UnpublishKey(const std::string &aName, ResultCallback &&aCallback) +{ + otbrError error = OTBR_ERROR_NONE; + + VerifyOrExit(mState == Publisher::State::kReady, error = OTBR_ERROR_INVALID_STATE); + RemoveKeyRegistration(aName, OTBR_ERROR_ABORTED); + +exit: + std::move(aCallback)(error); +} + +void PublisherMDnsSd::HandleRegisterKeyResult(DNSServiceRef aServiceRef, + DNSRecordRef aRecordRef, + DNSServiceFlags aFlags, + DNSServiceErrorType aError, + void *aContext) +{ + static_cast(aContext)->HandleRegisterKeyResult(aServiceRef, aRecordRef, aFlags, aError); +} + +void PublisherMDnsSd::HandleRegisterKeyResult(DNSServiceRef aServiceRef, + DNSRecordRef aRecordRef, + DNSServiceFlags aFlags, + DNSServiceErrorType aError) +{ + OTBR_UNUSED_VARIABLE(aFlags); + + otbrError error = DNSErrorToOtbrError(aError); + auto *keyReg = static_cast(FindKeyRegistration(aServiceRef, aRecordRef)); + std::string keyName; + + VerifyOrExit(keyReg != nullptr); + + keyName = MakeFullKeyName(keyReg->mName); + + if (error == OTBR_ERROR_NONE) + { + otbrLogInfo("Successfully registered key for %s", keyName.c_str()); + keyReg->Complete(OTBR_ERROR_NONE); + } + else + { + otbrLogWarning("Failed to register key for %s - mdnssd error: %s", keyName.c_str(), DNSErrorToString(aError)); + RemoveKeyRegistration(keyReg->mName, error); + } + +exit: + return; +} + // See `regtype` parameter of the DNSServiceRegister() function for more information. std::string PublisherMDnsSd::MakeRegType(const std::string &aType, SubTypeList aSubTypeList) { diff --git a/src/mdns/mdns_mdnssd.hpp b/src/mdns/mdns_mdnssd.hpp index 76d13f66004..f781d8cadf3 100644 --- a/src/mdns/mdns_mdnssd.hpp +++ b/src/mdns/mdns_mdnssd.hpp @@ -70,6 +70,7 @@ class PublisherMDnsSd : public MainloopProcessor, public Publisher void UnpublishService(const std::string &aName, const std::string &aType, ResultCallback &&aCallback) override; void UnpublishHost(const std::string &aName, ResultCallback &&aCallback) override; + void UnpublishKey(const std::string &aName, ResultCallback &&aCallback) override; void SubscribeService(const std::string &aType, const std::string &aInstanceName) override; void UnsubscribeService(const std::string &aType, const std::string &aInstanceName) override; void SubscribeHost(const std::string &aHostName) override; @@ -94,6 +95,7 @@ class PublisherMDnsSd : public MainloopProcessor, public Publisher otbrError PublishHostImpl(const std::string &aName, const std::vector &aAddress, ResultCallback &&aCallback) override; + otbrError PublishKeyImpl(const std::string &aName, const KeyData &aKeyData, ResultCallback &&aCallback) override; void OnServiceResolveFailedImpl(const std::string &aType, const std::string &aInstanceName, int32_t aErrorCode) override; @@ -162,6 +164,30 @@ class PublisherMDnsSd : public MainloopProcessor, public Publisher uint32_t mCallbackCount; }; + class DnssdKeyRegistration : public KeyRegistration + { + public: + DnssdKeyRegistration(const std::string &aName, + const KeyData &aKeyData, + ResultCallback &&aCallback, + DNSServiceRef aServiceRef, + DNSRecordRef aRecordRef, + Publisher *aPublisher) + : KeyRegistration(aName, aKeyData, std::move(aCallback), aPublisher) + , mServiceRef(aServiceRef) + , mRecordRef(aRecordRef) + { + } + + ~DnssdKeyRegistration(void) override; + const DNSServiceRef &GetServiceRef(void) const { return mServiceRef; } + const DNSRecordRef &GetRecordRef(void) const { return mRecordRef; } + + private: + DNSServiceRef mServiceRef; + DNSRecordRef mRecordRef; + }; + struct ServiceRef : private ::NonCopyable { DNSServiceRef mServiceRef; @@ -342,13 +368,24 @@ class PublisherMDnsSd : public MainloopProcessor, public Publisher DNSRecordRef aHostRecord, DNSServiceFlags aFlags, DNSServiceErrorType aErrorCode); + static void HandleRegisterKeyResult(DNSServiceRef aServiceRef, + DNSRecordRef aRecordRef, + DNSServiceFlags aFlags, + DNSServiceErrorType aErrorCode, + void *aContext); + void HandleRegisterKeyResult(DNSServiceRef aServiceRef, + DNSRecordRef aRecordRef, + DNSServiceFlags aFlags, + DNSServiceErrorType aErrorCode); static std::string MakeRegType(const std::string &aType, SubTypeList aSubTypeList); ServiceRegistration *FindServiceRegistration(const DNSServiceRef &aServiceRef); HostRegistration *FindHostRegistration(const DNSServiceRef &aServiceRef, const DNSRecordRef &aRecordRef); + KeyRegistration *FindKeyRegistration(const DNSServiceRef &aServiceRef, const DNSRecordRef &aRecordRef); + int AllocateHostsAndKeysRefIfUnallocated(void); - DNSServiceRef mHostsRef; + DNSServiceRef mHostsAndKeysRef; State mState; StateCallback mStateCallback;