diff --git a/src/vt/vrt/collection/manager.h b/src/vt/vrt/collection/manager.h index 718112013e..2872ba411d 100644 --- a/src/vt/vrt/collection/manager.h +++ b/src/vt/vrt/collection/manager.h @@ -304,6 +304,20 @@ struct CollectionManager VirtualProxyType proxy, typename ColT::IndexType idx, Args&&... args ); + /** + * \internal \brief Insert into a collection on this node with a pointer to + * the collection element to insert + * + * \param[in] proxy the collection proxy + * \param[in] idx the index to insert + * \param[in] ptr unique ptr to insert for the collection + */ + template + void staticInsertColPtr( + VirtualProxyType proxy, typename ColT::IndexType idx, + std::unique_ptr ptr + ); + public: /** * \brief Collectively construct a virtual context collection with a staged diff --git a/src/vt/vrt/collection/manager.impl.h b/src/vt/vrt/collection/manager.impl.h index 770d80f766..d31ed92413 100644 --- a/src/vt/vrt/collection/manager.impl.h +++ b/src/vt/vrt/collection/manager.impl.h @@ -1679,20 +1679,46 @@ VirtualProxyType CollectionManager::makeDistProxy(TagType const& tag) { /* end SPMD distributed collection support */ +template +void CollectionManager::staticInsertColPtr( + VirtualProxyType proxy, typename ColT::IndexType idx, + std::unique_ptr ptr +) { + using IndexT = typename ColT::IndexType; + using BaseIdxType = vt::index::BaseIndex; + + auto map_han = UniversalIndexHolder<>::getMap(proxy); + auto holder = findColHolder(proxy); + auto range = holder->max_idx; + auto const num_elms = range.getSize(); + auto fn = auto_registry::getHandlerMap(map_han); + auto const num_nodes = theContext()->getNumNodes(); + auto const cur = static_cast(&idx); + auto const max = static_cast(&range); + auto const home_node = fn(cur, max, num_nodes); + + // Through the attorney, setup all the properties on the newly constructed + // collection element: index, proxy, number of elements. Note: because of + // how the constructor works, the index is not currently available through + // "getIndex" + CollectionTypeAttorney::setup(ptr.get(), num_elms, idx, proxy); + + VirtualPtrType col_ptr( + static_cast*>(ptr.release()) + ); + + // Insert the element into the managed holder for elements + insertCollectionElement( + std::move(col_ptr), idx, range, map_han, proxy, true, home_node + ); +} + template void CollectionManager::staticInsert( VirtualProxyType proxy, typename ColT::IndexType idx, Args&&... args ) { using IndexT = typename ColT::IndexType; using IdxContextHolder = InsertContextHolder; - using BaseIdxType = vt::index::BaseIndex; - - auto const& num_nodes = theContext()->getNumNodes(); - - auto map_han = UniversalIndexHolder<>::getMap(proxy); - - // Set the current context index to `idx` - IdxContextHolder::set(&idx,proxy); auto tuple = std::make_tuple(std::forward(args)...); @@ -1701,12 +1727,8 @@ void CollectionManager::staticInsert( auto range = holder->max_idx; auto const num_elms = range.getSize(); - // Get the handler function - auto fn = auto_registry::getHandlerMap(map_han); - - auto const cur = static_cast(&idx); - auto const max = static_cast(&range); - auto const& home_node = fn(cur, max, num_nodes); + // Set the current context index to `idx` + IdxContextHolder::set(&idx,proxy); #if vt_check_enabled(detector) && vt_check_enabled(cons_multi_idx) auto elm_ptr = DerefCons::derefTuple( @@ -1719,25 +1741,17 @@ void CollectionManager::staticInsert( ); #endif + // Clear the current index context + IdxContextHolder::clear(); + vt_debug_print_verbose( vrt_coll, node, "construct (staticInsert): ptr={}\n", print_ptr(elm_ptr.get()) ); - // Through the attorney, setup all the properties on the newly constructed - // collection element: index, proxy, number of elements. Note: because of - // how the constructor works, the index is not currently available through - // "getIndex" - CollectionTypeAttorney::setup(elm_ptr.get(), num_elms, idx, proxy); - - // Insert the element into the managed holder for elements - insertCollectionElement( - std::move(elm_ptr), idx, range, map_han, proxy, true, home_node - ); - - // Clear the current index context - IdxContextHolder::clear(); + std::unique_ptr col_ptr(static_cast(elm_ptr.release())); + staticInsertColPtr(proxy, idx, std::move(col_ptr)); } template < @@ -1788,6 +1802,9 @@ InsertToken CollectionManager::constructInsertMap( // Insert the meta-data for this new collection insertMetaCollection(proxy, map_han, range, is_static); + // Insert action on cleanup for this collection + theCollection()->addCleanupFn(proxy); + return InsertToken{proxy}; } @@ -3182,7 +3199,7 @@ CollectionManager::restoreFromFile( // @todo: error check the file read with bytes in directory auto col_ptr = checkpoint::deserializeFromFile(file_name); - token[idx].insert(std::move(*col_ptr)); + token[idx].insertPtr(std::move(col_ptr)); } return finishedInsert(std::move(token)); diff --git a/src/vt/vrt/collection/staged_token/token.h b/src/vt/vrt/collection/staged_token/token.h index 384b1d17b4..328d01bb92 100644 --- a/src/vt/vrt/collection/staged_token/token.h +++ b/src/vt/vrt/collection/staged_token/token.h @@ -62,6 +62,8 @@ struct InsertTokenRval { template void insert(Args&&... args); + void insertPtr(std::unique_ptr ptr); + friend CollectionManager; private: diff --git a/src/vt/vrt/collection/staged_token/token.impl.h b/src/vt/vrt/collection/staged_token/token.impl.h index e847813862..b2a7b9bee8 100644 --- a/src/vt/vrt/collection/staged_token/token.impl.h +++ b/src/vt/vrt/collection/staged_token/token.impl.h @@ -59,6 +59,11 @@ void InsertTokenRval::insert(Args&&... args) { return theCollection()->staticInsert(proxy_,idx_,args...); } +template +void InsertTokenRval::insertPtr(std::unique_ptr ptr) { + return theCollection()->staticInsertColPtr(proxy_,idx_,std::move(ptr)); +} + // /*virtual*/ InsertToken::~InsertToken() { // theCollection()->finishedStaticInsert(proxy_); // } diff --git a/tests/unit/collection/test_checkpoint.extended.cc b/tests/unit/collection/test_checkpoint.extended.cc index 86ea5cda9b..52d4b75389 100644 --- a/tests/unit/collection/test_checkpoint.extended.cc +++ b/tests/unit/collection/test_checkpoint.extended.cc @@ -54,9 +54,37 @@ namespace vt { namespace tests { namespace unit { static constexpr std::size_t data1_len = 1024; static constexpr std::size_t data2_len = 64; +static std::size_t counter = 0; + struct TestCol : vt::Collection { - TestCol() = default; + TestCol() { + // fmt::print("{} ctor\n", theContext()->getNode()); + counter++; + } + TestCol(TestCol&& other) + : iter(other.iter), + data1(std::move(other.data1)), + data2(std::move(other.data2)), + token(other.token) + { + // fmt::print("{} move ctor\n", theContext()->getNode()); + counter++; + } + TestCol(TestCol const& other) + : iter(other.iter), + data1(other.data1), + data2(other.data2), + token(other.token) + { + // fmt::print("{} copy ctor\n", theContext()->getNode()); + counter++; + } + + virtual ~TestCol() { + // fmt::print("{} destroying\n", theContext()->getNode()); + counter--; + } struct NullMsg : vt::CollectionMessage {}; @@ -194,11 +222,20 @@ TEST_F(TestCheckpoint, test_checkpoint_1) { // Restoration should be done now vt::theCollective()->barrier(); - runInEpoch([&]{ + runInEpochCollective([&]{ if (this_node == 0) { proxy.broadcast(); } }); + + runInEpochCollective([&]{ + if (this_node == 0) { + proxy.destroy(); + } + }); + + // Ensure that all elements were properly destroyed + EXPECT_EQ(counter, 0); } }