diff --git a/src/mdns/mdns_mdnssd.cpp b/src/mdns/mdns_mdnssd.cpp index edbd671c3bc..946ecf1cce6 100644 --- a/src/mdns/mdns_mdnssd.cpp +++ b/src/mdns/mdns_mdnssd.cpp @@ -505,6 +505,11 @@ void PublisherMDnsSd::DnssdServiceRegistration::HandleRegisterResult(DNSServiceR void PublisherMDnsSd::DnssdServiceRegistration::HandleRegisterResult(DNSServiceFlags aFlags, DNSServiceErrorType aError) { + if (mReleatedKeyReg != nullptr) + { + mReleatedKeyReg->HandleRegisterResult(aError); + } + if ((aError == kDNSServiceErr_NoError) && (aFlags & kDNSServiceFlagsAdd)) { otbrLogInfo("Successfully registered service %s.%s", mName.c_str(), mType.c_str()); @@ -590,8 +595,8 @@ void PublisherMDnsSd::DnssdHostRegistration::HandleRegisterResult(DNSServiceRef DNSServiceErrorType aError, void *aContext) { - OT_UNUSED_VARIABLE(aServiceRef); - OT_UNUSED_VARIABLE(aFlags); + OTBR_UNUSED_VARIABLE(aServiceRef); + OTBR_UNUSED_VARIABLE(aFlags); static_cast(aContext)->HandleRegisterResult(aRecordRef, aError); } @@ -630,6 +635,122 @@ void PublisherMDnsSd::DnssdHostRegistration::HandleRegisterResult(DNSRecordRef a } } +//~~~~~~~~~~~~~~~~~~ + +otbrError PublisherMDnsSd::DnssdKeyRegistration::Register(void) +{ + DNSServiceErrorType dnsError = kDNSServiceErr_NoError; + DnssdServiceRegistration *serviceReg; + + otbrLogInfo("Registering new key %s", mName.c_str()); + + mRelatedServiceReg = static_cast(GetPublisher().FindServiceRegisteration(mName)); + + if (mRelatedServiceReg != nullptr) + { + dnsError = DNSServiceAddRecord(mRelatedServiceReg->mServiceRef, mRecordRef, kDNSServiceFlagsUnique, kDNSServiceType_KEY, + mKeyData.size(), mKeyData.data(), /* ttl */ 0); + + VerifyOrExit(dnsError == kDNSServiceErr_NoError, mRelatedServiceReg = nullptr); + + mRelatedServiceReg->mReleatedKeyReg = this; + + if (mRelatedServiceReg.IsCompleted()) + { + HandleRegisterResult(kDNSServiceErr_NoError); + } + + // Otherwise we wait for service registration completion to signal + // key record registration as well. + } + else + { + dnsError = GetPublisher().CreateSharedHostsRef(); + VerifyOrExit(dnsError == kDNSServiceErr_NoError); + + dnsError = DNSServiceRegisterRecord(GetPublisher().mHostsRef, &mRecordRef, kDNSServiceFlagsUnique, + kDNSServiceInterfaceIndexAny, MakeFullKeyName(mName).c_str(), + kDNSServiceType_KEY, kDNSServiceClass_IN, mKeyData.size(), mKeyData.data(), + /* ttl */ 0, HandleRegisterResult, this); + VerifyOrExit(dnsError == kDNSServiceErr_NoError); + } + +exit: + if (dnsError != kDNSServiceErr_NoError) + { + HandleRegisterResult(dnsError); + } + + return GetPublisher().DnsErrorToOtbrError(dnsError); +} + +void PublisherMDnsSd::DnssdKeyRegistration::Unregister(void) +{ + DNSServiceErrorType dnsError; + DNSServiceRef serviceRef; + + VerifyOrExit(mRecordRef != nullptr); + + if (mRelatedServiceReg != nullptr) + { + serviceRef = mRelatedServiceReg->mServiceRef; + + mRelatedServiceReg->mReleatedKeyReg = nullptr; + mRelatedServiceReg = nullptr; + } + else + { + serviceRef = GetPublisher().mHostsRef; + } + + VerifyOrExit(serviceRef != nullptr); + + dnsError = DNSServiceRemoveRecord(serviceRef, mRecordRef, /* flags */ 0); + + otbrLogInfo("Unregistered key %s: error:%s", mName.c_str(), DNSE) + +exit: + return; +} + +void PublisherMDnsSd::DnssdKeyRegistration::HandleRegisterResult(DNSServiceRef aServiceRef, + DNSRecordRef aRecordRef, + DNSServiceFlags aFlags, + DNSServiceErrorType aError, + void *aContext) +{ + OTBR_UNUSED_VARIABLE(aServiceRef); + OTBR_UNUSED_VARIABLE(aRecordRef); + OTBR_UNUSED_VARIABLE(aFlags); + + static_cast(aContext)->HandleRegisterResult(aRecordRef, aError); +} + +void PublisherMDnsSd::DnssdKeyRegistration::HandleRegisterResult(DNSServiceErrorType aError) +{ + if (aError != kDNSServiceErr_NoError) + { + otbrLogErr("Failed to register key %s: %s", mName.c_str(), DNSErrorToString(aError)); + GetPublisher().RemoveKeyRegistration(mName, DNSErrorToOtbrError(aError)); + } + else + { + otbrLogInfo("Successfully registered key %s", mName.c_str()); + Complete(OTBR_ERROR_NONE); + } +} + + + + + + + + + + + + otbrError PublisherMDnsSd::PublishServiceImpl(const std::string &aHostName, const std::string &aName, const std::string &aType, @@ -718,6 +839,42 @@ void PublisherMDnsSd::UnpublishHost(const std::string &aName, ResultCallback &&a std::move(aCallback)(error); } +otbrError PublisherMDnsSd::PublishKeyImpl(const std::string &aName, const KeyData &aKeyData, ResultCallback &&aCallback) +{ + otbrError error = OTBR_ERROR_NONE; + DnssdKeyRegistration *keyReg; + + if (mState != State::kReady) + { + error = OTBR_ERROR_INVALID_STATE; + std::move(aCallback)(error); + ExitNow(); + } + + aCallback = HandleDuplicateKeyRegistration(aName, aKeyData, std::move(aCallback)); + VerifyOrExit(!aCallback.IsNull()); + + keyReg = new DnssdKeyRegistration(aName, aKeyData, std::move(aCallback), this); + AddKeyRegistration(std::unique_ptr(keyReg)); + + error = keyReg->Register(); + +exit: + return error; +} + +void PublisherMDnsSd::UnpublishKey(const std::string &aName, ResultCallback &&aCallback) +{ + otbrError error = OTBR_ERROR_NONE; + + VerifyOrExit(mState == Publisher::State::kReady, error = OTBR_ERROR_INVALID_STATE); + RemoveKeyRegistration(aName, OTBR_ERROR_ABORTED); + +exit: + std::move(aCallback)(error); +} + + // See `regtype` parameter of the DNSServiceRegister() function for more information. std::string PublisherMDnsSd::MakeRegType(const std::string &aType, SubTypeList aSubTypeList) { diff --git a/src/mdns/mdns_mdnssd.hpp b/src/mdns/mdns_mdnssd.hpp index a43d66d58dd..b256d6e1e2b 100644 --- a/src/mdns/mdns_mdnssd.hpp +++ b/src/mdns/mdns_mdnssd.hpp @@ -111,8 +111,12 @@ class PublisherMDnsSd : public MainloopProcessor, public Publisher kStopOnServiceNotRunningError, }; + class DnssdKeyRegistration; + class DnssdServiceRegistration : public ServiceRegistration { + friend class DnssdKeyRegistration; + public: using ServiceRegistration::ServiceRegistration; // Inherit base constructor @@ -134,10 +138,11 @@ class PublisherMDnsSd : public MainloopProcessor, public Publisher const char *aDomain, void *aContext); - DNSServiceRef mServiceRef = nullptr; + DNSServiceRef mServiceRef = nullptr; + DnssdKeyRegistration * mRelatedKeyReg = nullptr; }; - class DnssdKeyRegistration; + class DnssdHostRegistration : public HostRegistration { @@ -160,30 +165,32 @@ class PublisherMDnsSd : public MainloopProcessor, public Publisher std::vector mAddrRecordRefs; std::vector mAddrRegistered; - DnssdKeyRegistration * mKeyRegistration; }; class DnssdKeyRegistration : public KeyRegistration { + friend class DnssdServiceRegistration; + public: - DnssdKeyRegistration(const std::string &aName, - const KeyData &aKeyData, - ResultCallback &&aCallback, - DNSServiceRef aServiceRef, - DNSRecordRef aRecordRef, - Publisher *aPublisher) - : KeyRegistration(aName, aKeyData, std::move(aCallback), aPublisher) - , mServiceRegistration(nullptr) - , mServiceRef(aServiceRef) - , mRecordRef(aRecordRef) - { - } + using KeyRegistration::KeyRegistration; // Inherit base class constructor + + ~DnssdKeyRegistration(void) override { Unregister(); } + + otbrError Register(void); + + private: + void Unregister(void); + PublisherMDnsSd &GetPublisher(void) { return *static_cast(mPublisher); } + void HandleRegisterResult(DNSServiceErrorType aError); + static void HandleRegisterResult(DNSServiceRef aServiceRef, + DNSRecordRef aRecordRef, + DNSServiceFlags aFlags, + DNSServiceErrorType aErrorCode, + void *aContext); - ~DnssdKeyRegistration(void) override; + DNSRecordRef mRecordRef = nullptr; + DnssdServiceRegistration *mRelatedServiceReg = nullptr - DnssdServiceRegistration *mServiceRegistration; - DNSServiceRef mServiceRef; - DNSRecordRef mRecordRef; }; struct ServiceRef : private ::NonCopyable