Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#268: Fix deserialization of polymorphic types when base class is specified as a template parameter. #362

Open
wants to merge 14 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/checkpoint/dispatch/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,20 @@ template <typename Serializer, typename T>
inline void serializeArray(Serializer& s, T* array, SerialSizeType const len);

template <typename T>
buffer::ImplReturnType serializeType(T& target, BufferObtainFnType fn = nullptr);
typename std::enable_if<vrt::VirtualSerializeTraits<T>::has_virtual_serialize, buffer::ImplReturnType>::type
serializeType(T& target, BufferObtainFnType fn = nullptr);

template <typename T>
T* deserializeType(SerialByteType* data, SerialByteType* allocBuf = nullptr);
typename std::enable_if<!vrt::VirtualSerializeTraits<T>::has_virtual_serialize, buffer::ImplReturnType>::type
serializeType(T& target, BufferObtainFnType fn = nullptr);

template <typename T>
typename std::enable_if<vrt::VirtualSerializeTraits<T>::has_virtual_serialize, T*>::type
deserializeType(SerialByteType* data, SerialByteType* allocBuf = nullptr);

template <typename T>
typename std::enable_if<!vrt::VirtualSerializeTraits<T>::has_virtual_serialize, T*>::type
deserializeType(SerialByteType* data, SerialByteType* allocBuf = nullptr);

template <typename T>
void deserializeType(InPlaceTag, SerialByteType* data, T* t);
Expand Down
79 changes: 77 additions & 2 deletions src/checkpoint/dispatch/dispatch.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,14 +266,18 @@ packBuffer(T& target, SerialSizeType size, BufferObtainFnType fn) {
}

template <typename T>
buffer::ImplReturnType serializeType(T& target, BufferObtainFnType fn) {
typename std::enable_if<
!vrt::VirtualSerializeTraits<T>::has_virtual_serialize,
buffer::ImplReturnType>::type
serializeType(T& target, BufferObtainFnType fn) {
auto len = Standard::size<T, Sizer>(target);
debug_checkpoint("serializeType: len=%ld\n", len);
return packBuffer<T>(target, len, fn);
}

template <typename T>
T* deserializeType(SerialByteType* data, SerialByteType* allocBuf) {
typename std::enable_if<!vrt::VirtualSerializeTraits<T>::has_virtual_serialize, T*>::type
deserializeType(SerialByteType* data, SerialByteType* allocBuf) {
auto mem = allocBuf ? allocBuf : Standard::allocate<T>();
auto t_buf = std::unique_ptr<T>(Standard::construct<T>(mem));
T* traverser =
Expand All @@ -287,6 +291,77 @@ void deserializeType(InPlaceTag, SerialByteType* data, T* t) {
Standard::unpack<T, UnpackerBuffer<buffer::UserBuffer>>(t, data);
}

template <typename T>
struct PrefixedType {
using BaseType = vrt::checkpoint_base_type_t<T>;

// Create PrefixedType for serialization purposes
explicit PrefixedType(BaseType* target) : target_(target) {
prefix_ = target->_checkpointDynamicTypeIndex();
}

// Create PrefixedType for deserialization purposes
explicit PrefixedType(SerialByteType* allocBuf) : unpack_buf_(allocBuf) { }

template <typename SerializerT>
void serialize(SerializerT& s) {
s | prefix_;

// Determine the correct type and allocate memory
if (s.isUnpacking()) {
validatePrefix(prefix_);

auto mem = unpack_buf_ ? unpack_buf_ : vrt::objregistry::allocateConcreteType<BaseType>(prefix_);
target_ = vrt::objregistry::constructConcreteType<BaseType>(prefix_, mem);
}

s | *target_;
}

BaseType* getTarget() const {
return target_;
}

private:
void validatePrefix(vrt::TypeIdx prefix) {
if (!vrt::objregistry::isValidIdx<BaseType>(prefix)) {
std::string const err = std::string("Unpacking invalid prefix type (") +
std::to_string(prefix) + std::string(") from object registry for type=") +
std::string(typeregistry::getTypeName<BaseType>());
throw serialization_error(err);
}
}

vrt::TypeIdx prefix_ = 0;
BaseType* target_ = nullptr;
SerialByteType* unpack_buf_ = nullptr;
};

template <typename T>
typename std::enable_if<
vrt::VirtualSerializeTraits<T>::has_virtual_serialize,
buffer::ImplReturnType>::type
serializeType(T& target, BufferObtainFnType fn) {
using BaseType = vrt::checkpoint_base_type_t<T>;
using PrefixedType = PrefixedType<BaseType>;

auto prefixed = PrefixedType(&target);
auto len = Standard::size<PrefixedType, Sizer>(prefixed);
debug_checkpoint("serializeType: len=%ld\n", len);
return packBuffer<PrefixedType>(prefixed, len, fn);
}

template <typename T>
typename std::enable_if<vrt::VirtualSerializeTraits<T>::has_virtual_serialize, T*>::type
deserializeType(SerialByteType* data, SerialByteType* allocBuf) {
using BaseType = vrt::checkpoint_base_type_t<T>;
using PrefixedType = PrefixedType<BaseType>;

auto prefixed = PrefixedType(allocBuf);
auto* traverser = Standard::unpack<PrefixedType, UnpackerBuffer<buffer::UserBuffer>>(&prefixed, data);
return static_cast<T*>(traverser->getTarget());
}

}} /* end namespace checkpoint::dispatch */

#endif /*INCLUDED_SRC_CHECKPOINT_DISPATCH_DISPATCH_IMPL_H*/
5 changes: 5 additions & 0 deletions src/checkpoint/dispatch/vrt/object_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ inline auto getObjIdx(TypeIdx han) {
return getRegistry<T>().at(han).idx_;
}

template <typename T>
inline auto isValidIdx(TypeIdx han) {
return getRegistry<T>().size() > static_cast<std::size_t>(han);
}

template <typename T>
inline auto getSizeConcreteType(TypeIdx han) {
return getRegistry<T>().at(han).size_;
Expand Down
129 changes: 129 additions & 0 deletions tests/unit/test_polymorphic.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
//@HEADER
// *****************************************************************************
//
// test_polymorphic.cc
// DARMA/magistrate => Serialization Library
//
// Copyright 2019 National Technology & Engineering Solutions of Sandia, LLC
// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S.
// Government retains certain rights in this software.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from this
// software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact darma@sandia.gov
//
// *****************************************************************************
//@HEADER
*/
#include <gtest/gtest.h>

#include "test_harness.h"

#include <checkpoint/checkpoint.h>
#include <checkpoint/dispatch/vrt/base.h>

namespace checkpoint { namespace tests { namespace unit {

using TestPolymorphic = TestHarness;

struct Base {
explicit Base() = default;
explicit Base(int val_in): base_val_(val_in) {};
virtual ~Base() = default;

checkpoint_virtual_serialize_root()

int base_val_;
virtual int getVal() {
return base_val_;
}

template <typename Serializer>
void serialize(Serializer& s) {
s | base_val_;
}
};

struct Derived1: public Base {
explicit Derived1() = default;
explicit Derived1(int val_in): Base(0), derived_val_(val_in) {};
virtual ~Derived1() = default;

checkpoint_virtual_serialize_derived_from(Base)

int derived_val_;
int getVal() override {
return derived_val_;
}

template <typename Serializer>
void serialize(Serializer& s) {
s | derived_val_;
}
};

struct Derived2: public Derived1 {
explicit Derived2() = default;
explicit Derived2(int val_in): Derived1(0), derived_val_2_(val_in) {};
virtual ~Derived2() = default;

checkpoint_virtual_serialize_derived_from(Derived1)

int derived_val_2_;
int getVal() override {
return derived_val_2_;
}

template <typename Serializer>
void serialize(Serializer& s) {
s | derived_val_2_;
}
};

template<typename Base, typename Derived>
void testPolymorphicTypes(int val) {
std::unique_ptr<Base> task(new Derived(val));
auto ret = checkpoint::serialize(*task);
auto out = checkpoint::deserialize<Base>(std::move(ret));

EXPECT_TRUE(nullptr != out);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should probably include an ID check here to ensure we are getting the right type instead of just checking that it's not null and the output value is correct.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the check for typeid.

EXPECT_EQ(typeid(*task), typeid(*out));
EXPECT_TRUE(nullptr != dynamic_cast<Derived*>(out.get()));
EXPECT_EQ(val, out->getVal());
}

TEST_F(TestPolymorphic, test_polymorphic_type) {
testPolymorphicTypes<Derived2, Derived2>(5);
testPolymorphicTypes<Derived1, Derived2>(50);
testPolymorphicTypes<Base, Derived2>(500);
testPolymorphicTypes<Derived1, Derived1>(10);
testPolymorphicTypes<Base, Derived1>(100);
testPolymorphicTypes<Base, Base>(1);
}

}}} // end namespace checkpoint::tests::unit
24 changes: 24 additions & 0 deletions tests/unit/test_virtual_serialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,30 @@ INSTANTIATE_TYPED_TEST_CASE_P(
test_virtual_serialize_inst, TestVirtualSerialize, ConstructTypes,
);

/*
* Test for deserialization when using the base class type
*/

using TestDeserializationFromBase = TestHarness;

template<typename Base, typename Derived>
void testDeserializationFromBase(TestEnum expected_id) {
std::unique_ptr<Base> task(new Derived(TEST_CONSTRUCT{}));
auto ret = checkpoint::serialize<Base>(*task);
auto out = checkpoint::deserialize<Base>(std::move(ret));

EXPECT_TRUE(nullptr != out);
thearusable marked this conversation as resolved.
Show resolved Hide resolved
EXPECT_EQ(expected_id, out->getID());
out->check();
}

TEST_F(TestDeserializationFromBase, test_deserialization_from_base) {
testDeserializationFromBase<test_2::TestBase, test_2::TestDerived3>(
TestEnum::Derived3);
testDeserializationFromBase<test_3::TestBase, test_3::TestDerived2>(
TestEnum::Derived2);
}

////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
Expand Down