Skip to content

Commit

Permalink
New implementation for polymorphic chaining
Browse files Browse the repository at this point in the history
-This version has all of the debug messages left in, next commit will remove them
-Algorithm description provided in comments
-relates #354
  • Loading branch information
AzothAmmo committed Oct 27, 2016
1 parent fc57d93 commit bf0f9ae
Showing 1 changed file with 158 additions and 47 deletions.
205 changes: 158 additions & 47 deletions include/cereal/details/polymorphic_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@
#include <functional>
#include <typeindex>
#include <map>
#include <limits>
#include <set>
#include <stack>

#include "cereal/types/map.hpp"
#include "cereal/types/vector.hpp"
#include "cereal/types/string.hpp"
#include "cereal/archives/json.hpp" // DEBUG

//! Binds a polymorhic type to all registered archives
/*! This binds a polymorphic type to all compatible registered archives that
Expand All @@ -74,6 +82,26 @@

namespace cereal
{
template <class Archive> inline
std::string save_minimal( Archive const &, std::type_index const & t )
{
return util::demangle( t.name() );
}

template <class Archive> inline
void load_minimal( Archive const &, std::type_index & t, std::string const & s )
{
}

namespace detail
{
struct PolymorphicCaster;
}

template <class Archive> inline
void serialize( Archive &, detail::PolymorphicCaster const * )
{ }

/* Polymorphic casting support */
namespace detail
{
Expand Down Expand Up @@ -115,6 +143,8 @@ namespace cereal
//! Maps from base type index to a map from derived type index to caster
std::map<std::type_index, std::map<std::type_index, std::vector<PolymorphicCaster const*>>> map;

std::multimap<std::type_index, std::type_index> reverseMap;

//! Error message used for unregistered polymorphic casts
#define UNREGISTERED_POLYMORPHIC_CAST_EXCEPTION(LoadSave) \
throw cereal::Exception("Trying to " #LoadSave " a registered polymorphic type with an unregistered polymorphic cast.\n" \
Expand Down Expand Up @@ -171,6 +201,8 @@ namespace cereal
{
auto const & mapping = lookup( baseInfo, typeid(Derived), [&](){ UNREGISTERED_POLYMORPHIC_CAST_EXCEPTION(save) } );

std::cerr << "------- DOWNCAST " << util::demangle(baseInfo.name()) << "->" << util::demangledName<Derived>() << std::endl;

for( auto const * map : mapping )
dptr = map->downcast( dptr );

Expand All @@ -185,9 +217,11 @@ namespace cereal
{
auto const & mapping = lookup( baseInfo, typeid(Derived), [&](){ UNREGISTERED_POLYMORPHIC_CAST_EXCEPTION(load) } );

std::cerr << "------- UPCAST " << util::demangle(baseInfo.name()) << "->" << util::demangledName<Derived>() << std::endl;

void * uptr = dptr;
for( auto const * map : mapping )
uptr = map->upcast( uptr );
for( auto mIter = mapping.rbegin(), mEnd = mapping.rend(); mIter != mEnd; ++mIter )
uptr = (*mIter)->upcast( uptr );

return uptr;
}
Expand All @@ -198,9 +232,11 @@ namespace cereal
{
auto const & mapping = lookup( baseInfo, typeid(Derived), [&](){ UNREGISTERED_POLYMORPHIC_CAST_EXCEPTION(load) } );

std::cerr << "------- SPTR UPCAST " << util::demangle(baseInfo.name()) << "->" << util::demangledName<Derived>() << std::endl;

std::shared_ptr<void> uptr = dptr;
for( auto const * map : mapping )
uptr = map->upcast( uptr );
for( auto mIter = mapping.rbegin(), mEnd = mapping.rend(); mIter != mEnd; ++mIter )
uptr = (*mIter)->upcast( uptr );

return uptr;
}
Expand All @@ -212,117 +248,192 @@ namespace cereal
template <class Base, class Derived>
struct PolymorphicVirtualCaster : PolymorphicCaster
{
template <class T>
static void print( std::string const & msg, T const & t )
{
std::cerr << msg << std::endl;
{
cereal::JSONOutputArchive ar(std::cerr);
ar(t);
}
std::cerr << std::endl;
}

//! Inserts an entry in the polymorphic casting map for this pairing
/*! Creates an explicit mapping between Base and Derived in both upwards and
downwards directions, allowing void pointers to either to be properly cast
assuming dynamic type information is available */
PolymorphicVirtualCaster()
{
const auto baseKey = std::type_index(typeid(Base));
const auto derivedKey = std::type_index(typeid(Derived));

// First insert the relation Base->Derived
const auto lock = StaticObject<PolymorphicCasters>::lock();
auto & baseMap = StaticObject<PolymorphicCasters>::getInstance().map;
auto baseKey = std::type_index(typeid(Base));
auto lb = baseMap.lower_bound(baseKey);

{
auto & derivedMap = baseMap.insert( lb, {baseKey, {}} )->second;
auto derivedKey = std::type_index(typeid(Derived));
auto lbd = derivedMap.lower_bound(derivedKey);
auto & derivedVec = derivedMap.insert( lbd, { std::move(derivedKey), {}} )->second;
derivedVec.push_back( this );
}

auto str = [](std::type_index t){ return util::demangle(t.name()); };

std::cerr << "Register relation " << str(baseKey) << "->" << str(derivedKey) << std::endl;

// Insert reverse relation Derived->Base
auto & reverseMap = StaticObject<PolymorphicCasters>::getInstance().reverseMap;
reverseMap.emplace( derivedKey, baseKey );

print( "baseMap", baseMap );
print( "reverseMap", reverseMap );

// Find all chainable unregistered relations
/* The strategy here is to process only the nodes in the class hierarchy graph that have been
affected by the new insertion. The aglorithm iteratively processes a node an ensures that it
is updated with all new shortest length paths. It then rocesses the parents of the active node,
with the knowledge that all children have already been processed.
Note that for the following, we'll use the nomenclature of parent and child to not confuse with
the inserted base derived relationship */
{
using Relations = std::map<std::type_index, std::pair<std::type_index, std::vector<PolymorphicCaster const *>>>;
auto findChainableRelations = [&baseMap]() -> Relations
// Checks whether there is a path from parent->child and returns a <dist, path> pair
// dist is set to MAX if the path does not exist
auto checkRelation = [](std::type_index const & parentInfo, std::type_index const & childInfo) ->
std::pair<size_t, std::vector<PolymorphicCaster const *>>
{
auto checkRelation = [](std::type_index const & baseInfo, std::type_index const & derivedInfo)
if( PolymorphicCasters::exists( parentInfo, childInfo ) )
{
const bool exists = PolymorphicCasters::exists( baseInfo, derivedInfo );
return std::make_pair( exists, exists ? PolymorphicCasters::lookup( baseInfo, derivedInfo, [](){} ) :
std::vector<PolymorphicCaster const *>{} );
};

Relations unregisteredRelations;
for( auto const & baseIt : baseMap )
for( auto const & derivedIt : baseIt.second )
auto const & path = PolymorphicCasters::lookup( parentInfo, childInfo, [](){} );
return {path.size(), path};
}
else
return {std::numeric_limits<size_t>::max(), {}};
};

std::stack<std::type_index> parentStack; // Holds the parent nodes to be processed
std::set<std::type_index> dirtySet; // Marks child nodes that have been changed
std::set<std::type_index> processedParents; // Marks parent nodes that have been processed

// Begin processing the base key and mark derived as dirty
parentStack.push( baseKey );
dirtySet.insert( derivedKey );

while( !parentStack.empty() )
{
using Relations = std::multimap<std::type_index, std::pair<std::type_index, std::vector<PolymorphicCaster const *>>>;
Relations unregisteredRelations; // Defer insertions until after main loop to prevent iterator invalidation

const auto parent = parentStack.top();
parentStack.pop();

std::cerr << "Processing " << str(parent) << std::endl;
print( "parent stack", parentStack );
print( "dirty set", dirtySet );
print( "processed set", processedParents );

// Update paths to all children marked dirty
for( auto const & childPair : baseMap[parent] )
{
const auto child = childPair.first;
if( dirtySet.count( child ) && baseMap.count( child ) )
{
for( auto const & otherBaseIt : baseMap )
auto parentChildPath = checkRelation( parent, child );

std::cerr << "Child has childreN? " << baseMap.count(child) << std::endl;

// Search all paths from the child to its own children (finalChild),
// looking for a shorter parth from parent to finalChild
for( auto const & finalChildPair : baseMap[child] )
{
if( baseIt.first == otherBaseIt.first ) // only interested in chained relations
continue;
const auto finalChild = finalChildPair.first;

// Check if there exists a mapping otherBase -> base -> derived that is shorter than
// any existing otherBase -> derived direct mapping
auto otherBaseItToDerived = checkRelation( otherBaseIt.first, derivedIt.first );
auto baseToDerived = checkRelation( baseIt.first, derivedIt.first );
auto otherBaseToBase = checkRelation( otherBaseIt.first, baseIt.first );
auto parentFinalChildPath = checkRelation( parent, finalChild );
auto childFinalChildPath = checkRelation( child, finalChild );

const size_t newLength = otherBaseToBase.second.size() + baseToDerived.second.size();
const bool isShorterOrFirstPath = !otherBaseItToDerived.first || (newLength < derivedIt.second.size());
const size_t newLength = 1u + parentChildPath.first;

if( isShorterOrFirstPath &&
baseToDerived.first &&
otherBaseToBase.first )
if( newLength < parentFinalChildPath.first )
{
std::vector<PolymorphicCaster const *> path = otherBaseToBase.second;
path.insert( path.end(), baseToDerived.second.begin(), baseToDerived.second.end() );
std::vector<PolymorphicCaster const *> path = parentChildPath.second;
path.insert( path.end(), childFinalChildPath.second.begin(), childFinalChildPath.second.end() );

// Check to see if we have a previous uncommitted path in unregisteredRelations
// that is shorter. If so, ignore this path
auto hint = unregisteredRelations.find( otherBaseIt.first );
auto hint = unregisteredRelations.find( parent );
const bool uncommittedExists = hint != unregisteredRelations.end();
if( uncommittedExists && (hint->second.second.size() <= newLength) )
continue;

auto newPath = std::pair<std::type_index, std::vector<PolymorphicCaster const *>>{derivedIt.first, std::move(path)};
auto newPath = std::pair<std::type_index, std::vector<PolymorphicCaster const *>>{finalChild, std::move(path)};

// Insert the new path if it doesn't exist, otherwise this will just lookup where to do the
// replacement
#ifdef CEREAL_OLDER_GCC
auto old = unregisteredRelations.insert( hint, std::make_pair(otherBaseIt.first, newPath) );
auto old = unregisteredRelations.insert( hint, std::make_pair(parent, newPath) );
#else // NOT CEREAL_OLDER_GCC
auto old = unregisteredRelations.emplace_hint( hint, otherBaseIt.first, newPath );
auto old = unregisteredRelations.emplace_hint( hint, parent, newPath );
#endif // NOT CEREAL_OLDER_GCC

// If there was an uncommitted path, we need to perform a replacement
if( uncommittedExists )
old->second = newPath;

std::cerr << "New relation " << str(parent) << "->" << str(finalChild) << ", child: " << str(child) << std::endl;
}
} // end otherBaseIt
} // end derivedIt
return unregisteredRelations;
}; // end findChainableRelations
} // end loop over child's children
} // end if dirty and child has children
} // end loop over children

Relations unregisteredRelations;
do
{
unregisteredRelations = findChainableRelations();
// Insert chained relations
for( auto const & it : unregisteredRelations )
{
auto & derivedMap = baseMap.find( it.first )->second;
derivedMap[it.second.first] = it.second.second;
reverseMap.emplace( it.second.first, it.first );

std::cerr << "Chained relation" << str(it.first) << "->" << str(it.second.first) << std::endl;
}
} while ( !unregisteredRelations.empty() );
} // end chain lookup
}

// Mark current parent as modified
dirtySet.insert( parent );

// Insert all parents of the current parent node that haven't yet been processed
auto parentRange = reverseMap.equal_range( parent );
for( auto pIter = parentRange.first; pIter != parentRange.second; ++pIter )
{
const auto pParent = pIter->second;
if( !processedParents.count( pParent ) )
{
parentStack.push( pParent );
processedParents.insert( pParent );
}
}
} // end loop over parent stack
} // end chainable relations
} // end PolymorphicVirtualCaster()

//! Performs the proper downcast with the templated types
void const * downcast( void const * const ptr ) const override
{
std::cerr << "DOWNCAST " << util::demangledName<Base>() << "->" << util::demangledName<Derived>() << std::endl;
return dynamic_cast<Derived const*>( static_cast<Base const*>( ptr ) );
}

//! Performs the proper upcast with the templated types
void * upcast( void * const ptr ) const override
{
std::cerr << "UPCAST " << util::demangledName<Derived>() << "->" << util::demangledName<Base>() << std::endl;
return dynamic_cast<Base*>( static_cast<Derived*>( ptr ) );
}

//! Performs the proper upcast with the templated types (shared_ptr version)
std::shared_ptr<void> upcast( std::shared_ptr<void> const & ptr ) const override
{
std::cerr << "SPTR UPCAST " << util::demangledName<Derived>() << "->" << util::demangledName<Base>() << std::endl;
return std::dynamic_pointer_cast<Base>( std::static_pointer_cast<Derived>( ptr ) );
}
};
Expand Down

0 comments on commit bf0f9ae

Please sign in to comment.