From f5c10c247685822a705b278b5224d7983586b970 Mon Sep 17 00:00:00 2001 From: Aart Stuurman Date: Mon, 13 Jun 2022 15:27:47 +0200 Subject: [PATCH] Improve serialization and its unit tests (#15) * Improve serialization and unit tests * Remove old commented out code * Restore part of old tests --- CMakeLists.txt | 1 + src/Genes.cpp | 54 ++++++ src/Genes.h | 157 ++++++++++++------ src/Genome.cpp | 25 +++ src/Genome.h | 44 ++--- src/Innovation.cpp | 28 ++++ src/Innovation.h | 29 +++- src/Traits.cpp | 33 +++- src/Traits.h | 91 ++++++++++ tests/serialization/CMakeLists.txt | 12 +- tests/serialization/serialize_genome.cpp | 84 +++++----- .../serialize_innovation_database.cpp | 36 ++++ 12 files changed, 469 insertions(+), 125 deletions(-) create mode 100644 src/Genes.cpp create mode 100644 tests/serialization/serialize_innovation_database.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 9e65af26..b01799b3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,6 +31,7 @@ find_package(cereal) set(SOURCE_FILES src/Genome.cpp + src/Genes.cpp src/Innovation.cpp src/NeuralNetwork.cpp src/Parameters.cpp diff --git a/src/Genes.cpp b/src/Genes.cpp new file mode 100644 index 00000000..4a64b6d2 --- /dev/null +++ b/src/Genes.cpp @@ -0,0 +1,54 @@ +#include "Genes.h" + +namespace NEAT +{ + + bool operator==(Gene const &lhs, Gene const &rhs) + { + return lhs.m_Traits == rhs.m_Traits; + } + + std::ostream &operator<<(std::ostream &stream, Gene const &gene) + { + stream << gene.Serialize(); + return stream; + } + + bool operator==(LinkGene const &lhs, LinkGene const &rhs) + { + return static_cast(lhs) == static_cast(rhs) && + lhs.m_FromNeuronID == rhs.m_FromNeuronID && + lhs.m_ToNeuronID == rhs.m_ToNeuronID && + lhs.m_InnovationID == rhs.m_InnovationID && + lhs.m_IsRecurrent == rhs.m_IsRecurrent && + lhs.m_Weight == rhs.m_Weight; + } + + std::ostream &operator<<(std::ostream &stream, LinkGene const &gene) + { + stream << gene.Serialize(); + return stream; + } + + bool operator==(NeuronGene const &lhs, NeuronGene const &rhs) + { + return static_cast(lhs) == static_cast(rhs) && + lhs.m_ID == rhs.m_ID && + lhs.m_Type == rhs.m_Type && + lhs.x == rhs.x && + lhs.y == rhs.y && + lhs.m_SplitY == rhs.m_SplitY && + lhs.m_A == rhs.m_A && + lhs.m_B == rhs.m_B && + lhs.m_TimeConstant == rhs.m_TimeConstant && + lhs.m_Bias == rhs.m_Bias && + lhs.m_ActFunction == rhs.m_ActFunction; + } + + std::ostream &operator<<(std::ostream &stream, NeuronGene const &gene) + { + stream << gene.Serialize(); + return stream; + } + +} diff --git a/src/Genes.h b/src/Genes.h index e68c1c8f..05378fc8 100644 --- a/src/Genes.h +++ b/src/Genes.h @@ -39,6 +39,9 @@ #include "Utils.h" #include +#include +#include +#include namespace NEAT @@ -86,6 +89,9 @@ namespace NEAT ////////////////////////////////// class Gene { + friend bool operator==(Gene const &lhs, Gene const &rhs); + friend std::ostream &operator<<(std::ostream &stream, Gene const &gene); + public: // Arbitrary traits std::map m_Traits; @@ -512,8 +518,34 @@ namespace NEAT return dist; } + + // Serialization + template + void serialize(Archive & ar) + { + ar & m_Traits; + } + + std::string Serialize() const + { + std::ostringstream os; + { + cereal::JSONOutputArchive oa(os); + oa << *this; + } + return os.str(); + } + + void Deserialize(const std::string &text) + { + std::istringstream is (text); + cereal::JSONInputArchive ia(is); + ia >> *this; + } }; + bool operator==(Gene const &lhs, Gene const &rhs); + std::ostream &operator<<(std::ostream &stream, Gene const &gene); ////////////////////////////////// // This class defines a link gene @@ -524,6 +556,9 @@ namespace NEAT // Members ///////////////////// + friend bool operator==(LinkGene const &lhs, LinkGene const &rhs); + friend std::ostream &operator<<(std::ostream &stream, LinkGene const &gene); + public: // These variables are initialized once and cannot be changed @@ -544,20 +579,6 @@ namespace NEAT public: - // Serialization - template - void serialize(Archive & ar) - { - ar & m_FromNeuronID; - ar & m_ToNeuronID; - ar & m_InnovationID; - ar & m_IsRecurrent; - ar & m_Weight; - - // the traits too, TODO - //ar & m_Traits; - } - double GetWeight() const { return m_Weight; @@ -651,12 +672,39 @@ namespace NEAT return (a_lhs.m_InnovationID != a_rhs.m_InnovationID); } - friend bool operator==(const LinkGene &a_lhs, const LinkGene &a_rhs) + // Serialization + template + void serialize(Archive & ar) { - return (a_lhs.m_InnovationID == a_rhs.m_InnovationID); + ar & cereal::base_class(this); + ar & m_FromNeuronID; + ar & m_ToNeuronID; + ar & m_InnovationID; + ar & m_IsRecurrent; + ar & m_Weight; + } + + std::string Serialize() const + { + std::ostringstream os; + { + cereal::JSONOutputArchive oa(os); + oa << *this; + } + return os.str(); + } + + void Deserialize(const std::string &text) + { + std::istringstream is (text); + cereal::JSONInputArchive ia(is); + ia >> *this; } }; + bool operator==(LinkGene const &lhs, LinkGene const &rhs); + std::ostream &operator<<(std::ostream &stream, LinkGene const &gene); + //////////////////////////////////// // This class defines a neuron gene @@ -667,6 +715,9 @@ namespace NEAT // Members ///////////////////// + friend bool operator==(NeuronGene const &lhs, NeuronGene const &rhs); + friend std::ostream &operator<<(std::ostream &stream, NeuronGene const &gene); + public: // These variables are initialized once and cannot be changed // anymore @@ -723,25 +774,6 @@ namespace NEAT // The type of activation function the neuron has ActivationFunction m_ActFunction; - // Serialization - template - void serialize(Archive & ar) - { - ar & m_ID; - ar & m_Type; - ar & m_A; - ar & m_B; - ar & m_TimeConstant; - ar & m_Bias; - ar & x; - ar & y; - ar & m_ActFunction; - ar & m_SplitY; - - // TODO the traits also - //ar & m_Traits; - } - //////////////// // Constructors //////////////// @@ -749,24 +781,6 @@ namespace NEAT { } - - /*friend bool operator!=(const NeuronGene &a_lhs, const NeuronGene &a_rhs) - { - return (a_lhs.m_ID != a_rhs.m_ID); - }*/ - - friend bool operator==(const NeuronGene &a_lhs, const NeuronGene &a_rhs) - { - return (a_lhs.m_ID == a_rhs.m_ID) && - (a_lhs.m_Type == a_rhs.m_Type) - //(a_lhs.m_SplitY == a_rhs.m_SplitY) && - //(a_lhs.m_A == a_rhs.m_A) && - //(a_lhs.m_B == a_rhs.m_B) && - //(a_lhs.m_TimeConstant == a_rhs.m_TimeConstant) && - //(a_lhs.m_Bias == a_rhs.m_Bias) && - //(a_lhs.m_ActFunction == a_rhs.m_ActFunction) - ; - } NeuronGene(NeuronType a_type, int a_id, double a_splity) { @@ -842,8 +856,45 @@ namespace NEAT m_Bias = a_Bias; m_ActFunction = a_ActFunc; } + + // Serialization + template + void serialize(Archive & ar) + { + ar & cereal::base_class(this); + ar & m_ID; + ar & m_Type; + ar & m_A; + ar & m_B; + ar & m_TimeConstant; + ar & m_Bias; + ar & x; + ar & y; + ar & m_ActFunction; + ar & m_SplitY; + } + + std::string Serialize() const + { + std::ostringstream os; + { + cereal::JSONOutputArchive oa(os); + oa << *this; + } + return os.str(); + } + + void Deserialize(const std::string &text) + { + std::istringstream is (text); + cereal::JSONInputArchive ia(is); + ia >> *this; + } }; + bool operator==(NeuronGene const &lhs, NeuronGene const &rhs); + std::ostream &operator<<(std::ostream &stream, NeuronGene const &gene); + } // namespace NEAT diff --git a/src/Genome.cpp b/src/Genome.cpp index d660425b..bd17ead6 100644 --- a/src/Genome.cpp +++ b/src/Genome.cpp @@ -4172,6 +4172,31 @@ namespace NEAT } } + bool operator==(Genome const &lhs, Genome const &rhs) + { + return lhs.m_ID == rhs.m_ID && + lhs.m_NumInputs == rhs.m_NumInputs && + lhs.m_NumOutputs == rhs.m_NumOutputs && + lhs.m_Fitness == rhs.m_Fitness && + lhs.m_AdjustedFitness == rhs.m_AdjustedFitness && + lhs.m_Depth == rhs.m_Depth && + lhs.m_NeuronRecursionLimit == rhs.m_NeuronRecursionLimit && + lhs.m_OffspringAmount == rhs.m_OffspringAmount && + lhs.m_NeuronGenes == rhs.m_NeuronGenes && + lhs.m_LinkGenes == rhs.m_LinkGenes && + lhs.m_GenomeGene == rhs.m_GenomeGene && + lhs.m_Evaluated == rhs.m_Evaluated && + lhs.m_initial_num_neurons == rhs.m_initial_num_neurons && + lhs.m_initial_num_links == rhs.m_initial_num_links && + ((lhs.m_PhenotypeBehavior == nullptr && rhs.m_PhenotypeBehavior == nullptr) || (*lhs.m_PhenotypeBehavior == *rhs.m_PhenotypeBehavior)); + } + + std::ostream &operator<<(std::ostream &stream, Genome const &genome) + { + stream << genome.Serialize(); + return stream; + } + #ifdef PYTHON_BINDINGS pybind11::dict Genome::TraitMap2Dict(const std::map< std::string, Trait>& tmap) const { diff --git a/src/Genome.h b/src/Genome.h index daa77abc..93559b14 100644 --- a/src/Genome.h +++ b/src/Genome.h @@ -122,6 +122,9 @@ namespace NEAT // Returns true is the specified neuron ID is a dead end or isolated bool IsDeadEndNeuron(int a_id) const; + friend bool operator==(Genome const &lhs, Genome const &rhs); + friend std::ostream &operator<<(std::ostream &stream, Genome const &genome); + public: // The two lists of genes @@ -159,13 +162,6 @@ namespace NEAT // assignment operator Genome &operator=(const Genome &a_g); - // comparison operator (nessesary for python bindings) - // todo: implement a better comparison technique - bool operator==(Genome const &other) const - { - return m_ID == other.m_ID; - } - // Builds this genome from a file Genome(const char *a_filename); @@ -643,20 +639,28 @@ namespace NEAT unsigned int output_count, unsigned int hidden_count); // Serialization - template - void serialize(Archive & ar) + template + void serialize(Archive &ar) { - ar & m_ID; - ar & m_NeuronGenes; - ar & m_LinkGenes; - ar & m_NumInputs; - ar & m_NumOutputs; - ar & m_Fitness; - ar & m_AdjustedFitness; - ar & m_Depth; - ar & m_OffspringAmount; - ar & m_Evaluated; - //ar & m_PhenotypeBehavior; // todo: think about how we will handle the behaviors with pickle + ar &m_ID; + ar &m_NeuronGenes; + ar &m_LinkGenes; + ar &m_NumInputs; + ar &m_NumOutputs; + ar &m_Fitness; + ar &m_AdjustedFitness; + ar &m_Depth; + ar &m_OffspringAmount; + ar &m_Evaluated; + ar &m_NeuronRecursionLimit; + ar &m_GenomeGene; + ar &m_initial_num_neurons; + ar &m_initial_num_links; + + if (m_PhenotypeBehavior != nullptr) + { + throw std::runtime_error("m_PhenotypeBehavior not null but serialization not implemented."); + } } std::string Serialize() const diff --git a/src/Innovation.cpp b/src/Innovation.cpp index 637f4f08..8ecf6bbf 100644 --- a/src/Innovation.cpp +++ b/src/Innovation.cpp @@ -295,6 +295,34 @@ void InnovationDatabase::Flush() m_Innovations.clear(); } +bool operator==(Innovation const &lhs, Innovation const &rhs) +{ + return lhs.m_ID == rhs.m_ID && + lhs.m_InnovType == rhs.m_InnovType && + lhs.m_FromNeuronID == rhs.m_FromNeuronID && + lhs.m_ToNeuronID == rhs.m_ToNeuronID && + lhs.m_NeuronType == rhs.m_NeuronType && + lhs.m_NeuronID == rhs.m_NeuronID; +} + +std::ostream &operator<<(std::ostream &stream, Innovation const &innov) +{ + stream << innov.Serialize(); + return stream; +} + +bool operator==(InnovationDatabase const &lhs, InnovationDatabase const &rhs) +{ + return lhs.m_NextNeuronID == rhs.m_NextNeuronID && + lhs.m_NextInnovationNum == rhs.m_NextInnovationNum && + lhs.m_Innovations == rhs.m_Innovations; +} + +std::ostream &operator<<(std::ostream &stream, InnovationDatabase const &db) +{ + stream << db.Serialize(); + return stream; +} diff --git a/src/Innovation.h b/src/Innovation.h index 1e481021..30aa41b0 100644 --- a/src/Innovation.h +++ b/src/Innovation.h @@ -76,6 +76,9 @@ class Innovation NeuronType m_NeuronType; int m_NeuronID; + friend bool operator==(Innovation const &lhs, Innovation const &rhs); + friend std::ostream &operator<<(std::ostream &stream, Innovation const &db); + public: //////////////////////////// @@ -142,8 +145,28 @@ class Innovation ar & m_NeuronType; ar & m_NeuronID; } + + std::string Serialize() const + { + std::ostringstream os; + { + cereal::JSONOutputArchive oa(os); + oa << *this; + } + return os.str(); + } + + void Deserialize(const std::string &text) + { + std::istringstream is(text); + cereal::JSONInputArchive ia(is); + ia >> *this; + } }; +bool operator==(Innovation const &lhs, Innovation const &rhs); +std::ostream &operator<<(std::ostream &stream, Innovation const &db); + // forward class Genome; @@ -165,6 +188,9 @@ class InnovationDatabase int m_NextNeuronID; int m_NextInnovationNum; + friend bool operator==(InnovationDatabase const &lhs, InnovationDatabase const &rhs); + friend std::ostream &operator<<(std::ostream &stream, InnovationDatabase const &db); + public: //////////////////////////// @@ -258,7 +284,8 @@ class InnovationDatabase } }; - +bool operator==(InnovationDatabase const &lhs, InnovationDatabase const &rhs); +std::ostream &operator<<(std::ostream &stream, InnovationDatabase const &db); diff --git a/src/Traits.cpp b/src/Traits.cpp index 41a345c3..4a7d97d2 100644 --- a/src/Traits.cpp +++ b/src/Traits.cpp @@ -1,6 +1,31 @@ -// -// Created by peter on 28.04.17. -// - #include "Traits.h" +namespace NEAT +{ + + bool operator==(Trait const &lhs, Trait const &rhs) + { + return lhs.value == rhs.value && + lhs.dep_key == rhs.dep_key; // && + // lhs.dep_values == rhs.dep_values + } + + std::ostream &operator<<(std::ostream &stream, Trait const &trait) + { + stream << trait.Serialize(); + return stream; + } + + std::ostream &operator<<(std::ostream &stream, intsetelement const &trait) + { + stream << trait.Serialize(); + return stream; + } + + std::ostream &operator<<(std::ostream &stream, floatsetelement const &trait) + { + stream << trait.Serialize(); + return stream; + } + +} diff --git a/src/Traits.h b/src/Traits.h index 14ce077f..05df3d93 100644 --- a/src/Traits.h +++ b/src/Traits.h @@ -9,6 +9,10 @@ #include #include #include +#include +#include +#include +#include #ifdef PYTHON_BINDINGS #include @@ -25,8 +29,36 @@ namespace NEAT return rhs.value == value; } + friend std::ostream &operator<<(std::ostream &stream, intsetelement const &trait); + int value; + + // Serialization + template + void serialize(Archive & ar) + { + ar & value; + } + + std::string Serialize() const + { + std::ostringstream os; + { + cereal::JSONOutputArchive oa(os); + oa << *this; + } + return os.str(); + } + + void Deserialize(const std::string &text) + { + std::istringstream is (text); + cereal::JSONInputArchive ia(is); + ia >> *this; + } }; + std::ostream &operator<<(std::ostream &stream, intsetelement const &trait); + class floatsetelement { public: @@ -36,8 +68,35 @@ namespace NEAT return rhs.value == value; } + friend std::ostream &operator<<(std::ostream &stream, floatsetelement const &trait); + double value; + + // Serialization + template + void serialize(Archive & ar) + { + ar & value; + } + + std::string Serialize() const + { + std::ostringstream os; + { + cereal::JSONOutputArchive oa(os); + oa << *this; + } + return os.str(); + } + + void Deserialize(const std::string &text) + { + std::istringstream is (text); + cereal::JSONInputArchive ia(is); + ia >> *this; + } }; + std::ostream &operator<<(std::ostream &stream, floatsetelement const &trait); typedef std::variant dep_values; // and has this value + + // Serialization + template + void serialize(Archive & ar) + { + // ar & value; + ar & dep_key; + // ar & dep_values; TODO + } + + std::string Serialize() const + { + std::ostringstream os; + { + cereal::JSONOutputArchive oa(os); + oa << *this; + } + return os.str(); + } + + void Deserialize(const std::string &text) + { + std::istringstream is (text); + cereal::JSONInputArchive ia(is); + ia >> *this; + } }; + bool operator==(Trait const &lhs, Trait const &rhs); + std::ostream &operator<<(std::ostream &stream, Trait const &trait); + } #endif //MULTINEAT_TRAITS_H diff --git a/tests/serialization/CMakeLists.txt b/tests/serialization/CMakeLists.txt index 58927713..2c25d6c9 100644 --- a/tests/serialization/CMakeLists.txt +++ b/tests/serialization/CMakeLists.txt @@ -6,7 +6,17 @@ target_link_libraries(serialize_genome MultiNEAT Boost::unit_test_framework) -add_test(serialize_genome serialize_genome) +add_test(genome serialize_genome) + +# serialize_innovation_database +add_executable(serialize_innovation_database +serialize_innovation_database.cpp) + +target_link_libraries(serialize_innovation_database + MultiNEAT + Boost::unit_test_framework) + +add_test(serialize_innovation_database serialize_innovation_database) if(GENERATE_PYTHON_BINDINGS) # python to C++ test diff --git a/tests/serialization/serialize_genome.cpp b/tests/serialization/serialize_genome.cpp index 0ddbfb19..47576877 100644 --- a/tests/serialization/serialize_genome.cpp +++ b/tests/serialization/serialize_genome.cpp @@ -9,55 +9,47 @@ #define BOOST_TEST_MODULE Serialization test #include -std::string serialize(const NEAT::Genome &g) +BOOST_AUTO_TEST_CASE(serialize_genome) { - std::ostringstream output_data; + NEAT::Parameters params; + NEAT::RNG rng; - { - cereal::JSONOutputArchive archive(output_data); - archive << g; - } + NEAT::Genome genome(42, + 2, 1, 3, + NEAT::ActivationFunction::SIGNED_SIGMOID, + NEAT::ActivationFunction::SIGNED_SIGMOID, + params); + std::string serialized = genome.Serialize(); - return output_data.str(); -} + NEAT::Genome genome_loaded; + genome_loaded.Deserialize(serialized); -NEAT::Genome deserialize(const std::string &data) -{ - NEAT::Genome g; - std::istringstream input_data(data); - { - cereal::JSONInputArchive archive(input_data); - archive >> g; - } - return g; -} + BOOST_TEST(genome == genome_loaded); -BOOST_AUTO_TEST_CASE(serialize_genome) -{ - NEAT::Parameters params; + auto innov_db = NEAT::InnovationDatabase(); - NEAT::Genome g(42, - 2, 1, 3, - NEAT::ActivationFunction::SIGNED_SIGMOID, - NEAT::ActivationFunction::SIGNED_SIGMOID, - params - ); - - const std::string serialized = serialize(g); - NEAT::Genome copy = deserialize(serialized); - - BOOST_TEST(g.GetID() == copy.GetID()); - BOOST_TEST(g.m_LinkGenes == copy.m_LinkGenes); - BOOST_TEST(g.m_NeuronGenes == copy.m_NeuronGenes); - BOOST_TEST(g.GetDepth() == copy.GetDepth()); - BOOST_TEST(g.NumInputs() == copy.NumInputs()); - BOOST_TEST(g.NumOutputs() == copy.NumOutputs()); - BOOST_TEST(g.GetFitness() == copy.GetFitness()); - BOOST_TEST(g.GetAdjFitness() == copy.GetAdjFitness()); - BOOST_TEST(g.GetDepth() == copy.GetDepth()); - BOOST_TEST(g.GetOffspringAmount() == copy.GetOffspringAmount()); - BOOST_TEST(g.m_Evaluated == copy.m_Evaluated); - - // This last test should fail, but it does not :) - BOOST_TEST(g.m_PhenotypeBehavior == copy.m_PhenotypeBehavior); -} \ No newline at end of file + do + { + genome.Mutate(false, SearchMode::BLENDED, innov_db, params, rng); + } while (innov_db.m_Innovations.size() < 100); + + serialized = genome.Serialize(); + + NEAT::Genome genome_loaded2; + genome_loaded2.Deserialize(serialized); + + BOOST_TEST(genome == genome_loaded2); + + // below are old tests, keeping them as extra + BOOST_TEST(genome.GetID() == genome_loaded2.GetID()); + BOOST_TEST(genome.m_LinkGenes == genome_loaded2.m_LinkGenes); + BOOST_TEST(genome.m_NeuronGenes == genome_loaded2.m_NeuronGenes); + BOOST_TEST(genome.GetDepth() == genome_loaded2.GetDepth()); + BOOST_TEST(genome.NumInputs() == genome_loaded2.NumInputs()); + BOOST_TEST(genome.NumOutputs() == genome_loaded2.NumOutputs()); + BOOST_TEST(genome.GetFitness() == genome_loaded2.GetFitness()); + BOOST_TEST(genome.GetAdjFitness() == genome_loaded2.GetAdjFitness()); + BOOST_TEST(genome.GetDepth() == genome_loaded2.GetDepth()); + BOOST_TEST(genome.GetOffspringAmount() == genome_loaded2.GetOffspringAmount()); + BOOST_TEST(genome.m_Evaluated == genome_loaded2.m_Evaluated); +} diff --git a/tests/serialization/serialize_innovation_database.cpp b/tests/serialization/serialize_innovation_database.cpp new file mode 100644 index 00000000..c245c0c8 --- /dev/null +++ b/tests/serialization/serialize_innovation_database.cpp @@ -0,0 +1,36 @@ +// +// Aart 2022 jun 8 +// + +#include +#include +#include +#include + +#define BOOST_TEST_MODULE serialize_innovation_database test +#include + +BOOST_AUTO_TEST_CASE(serialize_innovation_database) +{ + auto innov_db = NEAT::InnovationDatabase(); + + NEAT::Parameters params; + NEAT::Genome genome(42, + 2, 1, 3, + NEAT::ActivationFunction::SIGNED_SIGMOID, + NEAT::ActivationFunction::SIGNED_SIGMOID, + params); + + NEAT::RNG rng; + + do + { + genome.Mutate(false, SearchMode::BLENDED, innov_db, params, rng); + } while (innov_db.m_Innovations.size() < 100); + + std::string serialized = innov_db.Serialize(); + NEAT::InnovationDatabase innov_db_loaded = NEAT::InnovationDatabase(); + innov_db_loaded.Deserialize(serialized); + + BOOST_TEST(innov_db == innov_db_loaded); +}