diff --git a/autowiring/CoreContext.h b/autowiring/CoreContext.h index f0810b88f..25f6cf98e 100644 --- a/autowiring/CoreContext.h +++ b/autowiring/CoreContext.h @@ -7,6 +7,7 @@ #include "AutowiringEvents.h" #include "autowiring_error.h" #include "Bolt.h" +#include "CoreContextStateBlock.h" #include "CoreRunnable.h" #include "ContextMember.h" #include "CreationRules.h" @@ -24,12 +25,12 @@ #include "TypeUnifier.h" #include -#include TYPE_INDEX_HEADER #include MEMORY_HEADER +#include FUNCTIONAL_HEADER +#include TYPE_INDEX_HEADER #include STL_UNORDERED_MAP #include STL_UNORDERED_SET -struct CoreContextStateBlock; class AutoInjectable; class AutoPacketFactory; class DeferrableAutowiring; @@ -130,7 +131,7 @@ class CoreContext: // This is a list of concrete types, indexed by the true type of each element. std::vector m_concreteTypes; - // This is a memoization map used to memoize any already-detected interfaces. The map + // This is a memoization map used to memoize any already-detected interfaces. mutable std::unordered_map m_typeMemos; // All known context members, exception filters: @@ -367,6 +368,14 @@ class CoreContext: /// void FindByTypeUnsafe(AnySharedPointer& reference) const; + /// + /// Recursive locking for Autowire satisfaction search + /// + /// + /// The argument &&reference enables implicit type from AnySharedPointerT. + /// + void FindByTypeRecursiveUnsafe(AnySharedPointer&& reference, const std::function& terminal) const; + /// /// Returns or constructs a new AutoPacketFactory instance /// @@ -375,7 +384,7 @@ class CoreContext: /// /// Adds the specified deferrable autowiring as a general recipient of autowiring events /// - void AddDeferred(const AnySharedPointer& reference, DeferrableAutowiring* deferrable); + void AddDeferredUnsafe(const AnySharedPointer& reference, DeferrableAutowiring* deferrable); /// /// Adds a snooper to the snoopers set @@ -846,9 +855,9 @@ class CoreContext: /// template void FindByType(std::shared_ptr& slot) const { - AnySharedPointerT ptr; - FindByType(ptr); - slot = ptr.slot()->template as(); + AnySharedPointerT reference; + FindByType(reference); + slot = reference.slot()->template as(); } /// @@ -856,14 +865,14 @@ class CoreContext: /// template bool FindByTypeRecursive(std::shared_ptr& slot) { - // First-chance resolution in this context and ancestor contexts: - for(CoreContext* pCur = this; pCur; pCur = pCur->m_pParent.get()) { - pCur->FindByType(slot); - if(slot) - return true; + { + std::lock_guard guard(m_stateBlock->m_lock); + FindByTypeRecursiveUnsafe(AnySharedPointerT(), + [&slot](AnySharedPointer& reference){ + slot = reference.slot()->template as(); + }); } - - return false; + return static_cast(slot); } /// @@ -871,19 +880,25 @@ class CoreContext: /// template bool Autowire(AutowirableSlot& slot) { - if(FindByTypeRecursive(slot)) - return true; - - // Failed, defer - AddDeferred(AnySharedPointerT(), &slot); - return false; + { + std::lock_guard lk(m_stateBlock->m_lock); + FindByTypeRecursiveUnsafe(AnySharedPointerT(), + [this, &slot](AnySharedPointer& reference){ + slot = reference.slot()->template as(); + if (!slot) { + AddDeferredUnsafe(AnySharedPointerT(), &slot); + } + }); + } + return static_cast(slot); } /// /// Adds a post-attachment listener in this context for a particular autowired member /// /// - /// A pointer to a deferrable autowiring function which the caller may safely ignore if it's not needed + /// A pointer to a deferrable autowiring function which the caller may safely ignore if it's not needed. + /// Returns nullptr if the call was made immediately. /// /// /// This method will succeed if slot was constructed in this context or any parent context. If the @@ -902,12 +917,22 @@ class CoreContext: /// template const AutowirableSlotFn* NotifyWhenAutowired(Fn&& listener) { - auto retVal = MakeAutowirableSlotFn( - shared_from_this(), - std::forward(listener) - ); - - AddDeferred(AnySharedPointerT(), retVal); + AutowirableSlotFn* retVal = nullptr; + { + std::lock_guard lk(m_stateBlock->m_lock); + FindByTypeRecursiveUnsafe(AnySharedPointerT(), + [this, &listener, &retVal](AnySharedPointer& reference) { + if (reference) { + listener(); + } else { + retVal = MakeAutowirableSlotFn( + shared_from_this(), + std::forward(listener) + ); + AddDeferredUnsafe(reference, retVal); + } + }); + } return retVal; } diff --git a/src/autowiring/CoreContext.cpp b/src/autowiring/CoreContext.cpp index 011de19c1..50acfa938 100644 --- a/src/autowiring/CoreContext.cpp +++ b/src/autowiring/CoreContext.cpp @@ -4,7 +4,6 @@ #include "AutoInjectable.h" #include "AutoPacketFactory.h" #include "BoltBase.h" -#include "CoreContextStateBlock.h" #include "CoreThread.h" #include "GlobalCoreContext.h" #include "JunctionBox.h" @@ -285,6 +284,26 @@ void CoreContext::FindByTypeUnsafe(AnySharedPointer& reference) const { m_typeMemos[type].m_value = reference; } +void CoreContext::FindByTypeRecursiveUnsafe(AnySharedPointer&& reference, const std::function& terminal) const { + FindByTypeUnsafe(reference); + if (reference) { + // Type satisfied in current context + terminal(reference); + return; + } + + if (m_pParent) { + std::lock_guard guard(m_pParent->m_stateBlock->m_lock); + // Recurse while holding lock on this context + // NOTE: Deadlock is only possible if there is a simultaneous descending locked chain, + // but by definition of contexts this is forbidden. + m_pParent->FindByTypeRecursiveUnsafe(std::move(reference), terminal); + } else { + // Call function while holding all locks through global scope. + terminal(reference); + } +} + std::shared_ptr CoreContext::GetGlobal(void) { return std::static_pointer_cast(GlobalCoreContext::Get()); } @@ -481,6 +500,9 @@ void CoreContext::BuildCurrentState(void) { } void CoreContext::CancelAutowiringNotification(DeferrableAutowiring* pDeferrable) { + if (!pDeferrable) + return; + std::lock_guard lk(m_stateBlock->m_lock); auto q = m_typeMemos.find(pDeferrable->GetType()); if(q == m_typeMemos.end()) @@ -764,10 +786,7 @@ std::shared_ptr CoreContext::GetPacketFactory(void) { return pf; } -void CoreContext::AddDeferred(const AnySharedPointer& reference, DeferrableAutowiring* deferrable) -{ - std::lock_guard lk(m_stateBlock->m_lock); - +void CoreContext::AddDeferredUnsafe(const AnySharedPointer& reference, DeferrableAutowiring* deferrable) { // Determine whether a type memo exists right now for the thing we're trying to defer. If it doesn't // exist, we need to inject one in order to allow deferred satisfaction to know what kind of type we // are trying to satisfy at this point. diff --git a/src/autowiring/test/PostConstructTest.cpp b/src/autowiring/test/PostConstructTest.cpp index d20a9692d..5c948a38a 100644 --- a/src/autowiring/test/PostConstructTest.cpp +++ b/src/autowiring/test/PostConstructTest.cpp @@ -264,10 +264,9 @@ TEST_F(PostConstructTest, ContextNotifyWhenAutowired) { // Now we'd like to be notified when SimpleObject gets added: ctxt->NotifyWhenAutowired( - [called] { + [called] { *called = true; - } - ); + }); // Should only be two uses, at this point, of the capture of the above lambda: EXPECT_EQ(2L, called.use_count()) << "Unexpected number of references held in a capture lambda"; @@ -283,3 +282,23 @@ TEST_F(PostConstructTest, ContextNotifyWhenAutowired) { ASSERT_TRUE(called.unique()) << "Autowiring notification lambda was not properly cleaned up"; } +TEST_F(PostConstructTest, ContextNotifyWhenAutowiredPostConstruct) { + auto called = std::make_shared(false); + AutoCurrentContext ctxt; + + // Create an object that will satisfy subsequent notification call: + AutoRequired sobj; + + // Notification should be immediate: + ctxt->NotifyWhenAutowired( + [called] { + *called = true; + }); + + // Insert the SimpleObject, see if the lambda got hit: + ASSERT_TRUE(*called) << "Context-wide autowiring notification was not hit as expected when a matching type was injected into a context"; + + // Our shared pointer should be unique by this point, because the lambda should have been destroyed + ASSERT_TRUE(called.unique()) << "Autowiring notification lambda was not properly cleaned up"; +} +