Skip to content

Commit

Permalink
Improve serialization and its unit tests (#15)
Browse files Browse the repository at this point in the history
* Improve serialization and unit tests

* Remove old commented out code

* Restore part of old tests
  • Loading branch information
surgura authored Jun 13, 2022
1 parent 39cccc1 commit f5c10c2
Show file tree
Hide file tree
Showing 12 changed files with 469 additions and 125 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ find_package(cereal)

set(SOURCE_FILES
src/Genome.cpp
src/Genes.cpp
src/Innovation.cpp
src/NeuralNetwork.cpp
src/Parameters.cpp
Expand Down
54 changes: 54 additions & 0 deletions src/Genes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include "Genes.h"

namespace NEAT
{

bool operator==(Gene const &lhs, Gene const &rhs)
{
return lhs.m_Traits == rhs.m_Traits;
}

std::ostream &operator<<(std::ostream &stream, Gene const &gene)
{
stream << gene.Serialize();
return stream;
}

bool operator==(LinkGene const &lhs, LinkGene const &rhs)
{
return static_cast<Gene const &>(lhs) == static_cast<Gene const &>(rhs) &&
lhs.m_FromNeuronID == rhs.m_FromNeuronID &&
lhs.m_ToNeuronID == rhs.m_ToNeuronID &&
lhs.m_InnovationID == rhs.m_InnovationID &&
lhs.m_IsRecurrent == rhs.m_IsRecurrent &&
lhs.m_Weight == rhs.m_Weight;
}

std::ostream &operator<<(std::ostream &stream, LinkGene const &gene)
{
stream << gene.Serialize();
return stream;
}

bool operator==(NeuronGene const &lhs, NeuronGene const &rhs)
{
return static_cast<Gene const &>(lhs) == static_cast<Gene const &>(rhs) &&
lhs.m_ID == rhs.m_ID &&
lhs.m_Type == rhs.m_Type &&
lhs.x == rhs.x &&
lhs.y == rhs.y &&
lhs.m_SplitY == rhs.m_SplitY &&
lhs.m_A == rhs.m_A &&
lhs.m_B == rhs.m_B &&
lhs.m_TimeConstant == rhs.m_TimeConstant &&
lhs.m_Bias == rhs.m_Bias &&
lhs.m_ActFunction == rhs.m_ActFunction;
}

std::ostream &operator<<(std::ostream &stream, NeuronGene const &gene)
{
stream << gene.Serialize();
return stream;
}

}
157 changes: 104 additions & 53 deletions src/Genes.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
#include "Utils.h"

#include <cereal/cereal.hpp>
#include <cereal/types/vector.hpp>
#include <cereal/types/map.hpp>
#include <cereal/archives/json.hpp>


namespace NEAT
Expand Down Expand Up @@ -86,6 +89,9 @@ namespace NEAT
//////////////////////////////////
class Gene
{
friend bool operator==(Gene const &lhs, Gene const &rhs);
friend std::ostream &operator<<(std::ostream &stream, Gene const &gene);

public:
// Arbitrary traits
std::map<std::string, Trait> m_Traits;
Expand Down Expand Up @@ -512,8 +518,34 @@ namespace NEAT

return dist;
}

// Serialization
template<class Archive>
void serialize(Archive & ar)
{
ar & m_Traits;
}

std::string Serialize() const
{
std::ostringstream os;
{
cereal::JSONOutputArchive oa(os);
oa << *this;
}
return os.str();
}

void Deserialize(const std::string &text)
{
std::istringstream is (text);
cereal::JSONInputArchive ia(is);
ia >> *this;
}
};

bool operator==(Gene const &lhs, Gene const &rhs);
std::ostream &operator<<(std::ostream &stream, Gene const &gene);

//////////////////////////////////
// This class defines a link gene
Expand All @@ -524,6 +556,9 @@ namespace NEAT
// Members
/////////////////////

friend bool operator==(LinkGene const &lhs, LinkGene const &rhs);
friend std::ostream &operator<<(std::ostream &stream, LinkGene const &gene);

public:

// These variables are initialized once and cannot be changed
Expand All @@ -544,20 +579,6 @@ namespace NEAT

public:

// Serialization
template<class Archive>
void serialize(Archive & ar)
{
ar & m_FromNeuronID;
ar & m_ToNeuronID;
ar & m_InnovationID;
ar & m_IsRecurrent;
ar & m_Weight;

// the traits too, TODO
//ar & m_Traits;
}

double GetWeight() const
{
return m_Weight;
Expand Down Expand Up @@ -651,12 +672,39 @@ namespace NEAT
return (a_lhs.m_InnovationID != a_rhs.m_InnovationID);
}

friend bool operator==(const LinkGene &a_lhs, const LinkGene &a_rhs)
// Serialization
template<class Archive>
void serialize(Archive & ar)
{
return (a_lhs.m_InnovationID == a_rhs.m_InnovationID);
ar & cereal::base_class<Gene>(this);
ar & m_FromNeuronID;
ar & m_ToNeuronID;
ar & m_InnovationID;
ar & m_IsRecurrent;
ar & m_Weight;
}

std::string Serialize() const
{
std::ostringstream os;
{
cereal::JSONOutputArchive oa(os);
oa << *this;
}
return os.str();
}

void Deserialize(const std::string &text)
{
std::istringstream is (text);
cereal::JSONInputArchive ia(is);
ia >> *this;
}
};

bool operator==(LinkGene const &lhs, LinkGene const &rhs);
std::ostream &operator<<(std::ostream &stream, LinkGene const &gene);


////////////////////////////////////
// This class defines a neuron gene
Expand All @@ -667,6 +715,9 @@ namespace NEAT
// Members
/////////////////////

friend bool operator==(NeuronGene const &lhs, NeuronGene const &rhs);
friend std::ostream &operator<<(std::ostream &stream, NeuronGene const &gene);

public:
// These variables are initialized once and cannot be changed
// anymore
Expand Down Expand Up @@ -723,50 +774,13 @@ namespace NEAT
// The type of activation function the neuron has
ActivationFunction m_ActFunction;

// Serialization
template<class Archive>
void serialize(Archive & ar)
{
ar & m_ID;
ar & m_Type;
ar & m_A;
ar & m_B;
ar & m_TimeConstant;
ar & m_Bias;
ar & x;
ar & y;
ar & m_ActFunction;
ar & m_SplitY;

// TODO the traits also
//ar & m_Traits;
}

////////////////
// Constructors
////////////////
NeuronGene()
{

}

/*friend bool operator!=(const NeuronGene &a_lhs, const NeuronGene &a_rhs)
{
return (a_lhs.m_ID != a_rhs.m_ID);
}*/

friend bool operator==(const NeuronGene &a_lhs, const NeuronGene &a_rhs)
{
return (a_lhs.m_ID == a_rhs.m_ID) &&
(a_lhs.m_Type == a_rhs.m_Type)
//(a_lhs.m_SplitY == a_rhs.m_SplitY) &&
//(a_lhs.m_A == a_rhs.m_A) &&
//(a_lhs.m_B == a_rhs.m_B) &&
//(a_lhs.m_TimeConstant == a_rhs.m_TimeConstant) &&
//(a_lhs.m_Bias == a_rhs.m_Bias) &&
//(a_lhs.m_ActFunction == a_rhs.m_ActFunction)
;
}

NeuronGene(NeuronType a_type, int a_id, double a_splity)
{
Expand Down Expand Up @@ -842,8 +856,45 @@ namespace NEAT
m_Bias = a_Bias;
m_ActFunction = a_ActFunc;
}

// Serialization
template<class Archive>
void serialize(Archive & ar)
{
ar & cereal::base_class<Gene>(this);
ar & m_ID;
ar & m_Type;
ar & m_A;
ar & m_B;
ar & m_TimeConstant;
ar & m_Bias;
ar & x;
ar & y;
ar & m_ActFunction;
ar & m_SplitY;
}

std::string Serialize() const
{
std::ostringstream os;
{
cereal::JSONOutputArchive oa(os);
oa << *this;
}
return os.str();
}

void Deserialize(const std::string &text)
{
std::istringstream is (text);
cereal::JSONInputArchive ia(is);
ia >> *this;
}
};

bool operator==(NeuronGene const &lhs, NeuronGene const &rhs);
std::ostream &operator<<(std::ostream &stream, NeuronGene const &gene);


} // namespace NEAT

Expand Down
25 changes: 25 additions & 0 deletions src/Genome.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4172,6 +4172,31 @@ namespace NEAT
}
}

bool operator==(Genome const &lhs, Genome const &rhs)
{
return lhs.m_ID == rhs.m_ID &&
lhs.m_NumInputs == rhs.m_NumInputs &&
lhs.m_NumOutputs == rhs.m_NumOutputs &&
lhs.m_Fitness == rhs.m_Fitness &&
lhs.m_AdjustedFitness == rhs.m_AdjustedFitness &&
lhs.m_Depth == rhs.m_Depth &&
lhs.m_NeuronRecursionLimit == rhs.m_NeuronRecursionLimit &&
lhs.m_OffspringAmount == rhs.m_OffspringAmount &&
lhs.m_NeuronGenes == rhs.m_NeuronGenes &&
lhs.m_LinkGenes == rhs.m_LinkGenes &&
lhs.m_GenomeGene == rhs.m_GenomeGene &&
lhs.m_Evaluated == rhs.m_Evaluated &&
lhs.m_initial_num_neurons == rhs.m_initial_num_neurons &&
lhs.m_initial_num_links == rhs.m_initial_num_links &&
((lhs.m_PhenotypeBehavior == nullptr && rhs.m_PhenotypeBehavior == nullptr) || (*lhs.m_PhenotypeBehavior == *rhs.m_PhenotypeBehavior));
}

std::ostream &operator<<(std::ostream &stream, Genome const &genome)
{
stream << genome.Serialize();
return stream;
}

#ifdef PYTHON_BINDINGS
pybind11::dict Genome::TraitMap2Dict(const std::map< std::string, Trait>& tmap) const
{
Expand Down
44 changes: 24 additions & 20 deletions src/Genome.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ namespace NEAT
// Returns true is the specified neuron ID is a dead end or isolated
bool IsDeadEndNeuron(int a_id) const;

friend bool operator==(Genome const &lhs, Genome const &rhs);
friend std::ostream &operator<<(std::ostream &stream, Genome const &genome);

public:

// The two lists of genes
Expand Down Expand Up @@ -159,13 +162,6 @@ namespace NEAT
// assignment operator
Genome &operator=(const Genome &a_g);

// comparison operator (nessesary for python bindings)
// todo: implement a better comparison technique
bool operator==(Genome const &other) const
{
return m_ID == other.m_ID;
}

// Builds this genome from a file
Genome(const char *a_filename);

Expand Down Expand Up @@ -643,20 +639,28 @@ namespace NEAT
unsigned int output_count, unsigned int hidden_count);

// Serialization
template<class Archive>
void serialize(Archive & ar)
template <class Archive>
void serialize(Archive &ar)
{
ar & m_ID;
ar & m_NeuronGenes;
ar & m_LinkGenes;
ar & m_NumInputs;
ar & m_NumOutputs;
ar & m_Fitness;
ar & m_AdjustedFitness;
ar & m_Depth;
ar & m_OffspringAmount;
ar & m_Evaluated;
//ar & m_PhenotypeBehavior; // todo: think about how we will handle the behaviors with pickle
ar &m_ID;
ar &m_NeuronGenes;
ar &m_LinkGenes;
ar &m_NumInputs;
ar &m_NumOutputs;
ar &m_Fitness;
ar &m_AdjustedFitness;
ar &m_Depth;
ar &m_OffspringAmount;
ar &m_Evaluated;
ar &m_NeuronRecursionLimit;
ar &m_GenomeGene;
ar &m_initial_num_neurons;
ar &m_initial_num_links;

if (m_PhenotypeBehavior != nullptr)
{
throw std::runtime_error("m_PhenotypeBehavior not null but serialization not implemented.");
}
}

std::string Serialize() const
Expand Down
Loading

0 comments on commit f5c10c2

Please sign in to comment.