diff --git a/models/eprop_iaf.cpp b/models/eprop_iaf.cpp index 2c79ab86db..b6c23ae645 100644 --- a/models/eprop_iaf.cpp +++ b/models/eprop_iaf.cpp @@ -257,7 +257,7 @@ eprop_iaf::pre_run_hook() V_.RefractoryCounts_ = Time( Time::ms( P_.t_ref_ ) ).get_steps(); - compute_surrogate_gradient = select_surrogate_gradient( P_.surrogate_gradient_function_ ); + compute_surrogate_gradient_ = select_surrogate_gradient( P_.surrogate_gradient_function_ ); // calculate the entries of the propagator matrix for the evolution of the state vector @@ -301,7 +301,7 @@ eprop_iaf::update( Time const& origin, const long from, const long to ) // P_.V_th_ is passed twice to handle models without an adaptive threshold, serving as both v_th_adapt and V_th S_.surrogate_gradient_ = - ( this->*compute_surrogate_gradient )( S_.r_, S_.v_m_, P_.V_th_, P_.V_th_, P_.beta_, P_.gamma_ ); + ( this->*compute_surrogate_gradient_ )( S_.r_, S_.v_m_, P_.V_th_, P_.V_th_, P_.beta_, P_.gamma_ ); emplace_new_eprop_history_entry( t ); diff --git a/models/eprop_iaf.h b/models/eprop_iaf.h index 364288cd64..f92bc572a0 100644 --- a/models/eprop_iaf.h +++ b/models/eprop_iaf.h @@ -342,8 +342,8 @@ class eprop_iaf : public EpropArchivingNodeRecurrent bool is_eprop_recurrent_node() const override; long get_eprop_isi_trace_cutoff() const override; - //! Compute the surrogate gradient. - double ( eprop_iaf::*compute_surrogate_gradient )( double, double, double, double, double, double ); + //! Pointer to member function selected for computing the surrogate gradient + surrogate_gradient_function compute_surrogate_gradient_; //! Map for storing a static set of recordables. friend class RecordablesMap< eprop_iaf >; diff --git a/models/eprop_iaf_adapt.cpp b/models/eprop_iaf_adapt.cpp index b5883cf62e..8a7cd7691f 100644 --- a/models/eprop_iaf_adapt.cpp +++ b/models/eprop_iaf_adapt.cpp @@ -291,7 +291,7 @@ eprop_iaf_adapt::pre_run_hook() V_.RefractoryCounts_ = Time( Time::ms( P_.t_ref_ ) ).get_steps(); - compute_surrogate_gradient = select_surrogate_gradient( P_.surrogate_gradient_function_ ); + compute_surrogate_gradient_ = select_surrogate_gradient( P_.surrogate_gradient_function_ ); // calculate the entries of the propagator matrix for the evolution of the state vector @@ -338,7 +338,7 @@ eprop_iaf_adapt::update( Time const& origin, const long from, const long to ) S_.z_ = 0.0; S_.surrogate_gradient_ = - ( this->*compute_surrogate_gradient )( S_.r_, S_.v_m_, S_.v_th_adapt_, P_.V_th_, P_.beta_, P_.gamma_ ); + ( this->*compute_surrogate_gradient_ )( S_.r_, S_.v_m_, S_.v_th_adapt_, P_.V_th_, P_.beta_, P_.gamma_ ); emplace_new_eprop_history_entry( t ); diff --git a/models/eprop_iaf_adapt.h b/models/eprop_iaf_adapt.h index 92eaa82afb..7ffde82607 100644 --- a/models/eprop_iaf_adapt.h +++ b/models/eprop_iaf_adapt.h @@ -358,8 +358,8 @@ class eprop_iaf_adapt : public EpropArchivingNodeRecurrent bool is_eprop_recurrent_node() const override; long get_eprop_isi_trace_cutoff() const override; - //! Compute the surrogate gradient. - double ( eprop_iaf_adapt::*compute_surrogate_gradient )( double, double, double, double, double, double ); + //! Pointer to member function selected for computing the surrogate gradient + surrogate_gradient_function compute_surrogate_gradient_; //! Map for storing a static set of recordables. friend class RecordablesMap< eprop_iaf_adapt >; diff --git a/models/eprop_iaf_adapt_bsshslm_2020.cpp b/models/eprop_iaf_adapt_bsshslm_2020.cpp index 11d3e63727..bcc18e1b37 100644 --- a/models/eprop_iaf_adapt_bsshslm_2020.cpp +++ b/models/eprop_iaf_adapt_bsshslm_2020.cpp @@ -275,7 +275,7 @@ eprop_iaf_adapt_bsshslm_2020::pre_run_hook() V_.RefractoryCounts_ = Time( Time::ms( P_.t_ref_ ) ).get_steps(); - compute_surrogate_gradient = select_surrogate_gradient( P_.surrogate_gradient_function_ ); + compute_surrogate_gradient_ = select_surrogate_gradient( P_.surrogate_gradient_function_ ); // calculate the entries of the propagator matrix for the evolution of the state vector @@ -342,7 +342,7 @@ eprop_iaf_adapt_bsshslm_2020::update( Time const& origin, const long from, const S_.z_ = 0.0; S_.surrogate_gradient_ = - ( this->*compute_surrogate_gradient )( S_.r_, S_.v_m_, S_.v_th_adapt_, P_.V_th_, P_.beta_, P_.gamma_ ); + ( this->*compute_surrogate_gradient_ )( S_.r_, S_.v_m_, S_.v_th_adapt_, P_.V_th_, P_.beta_, P_.gamma_ ); emplace_new_eprop_history_entry( t ); diff --git a/models/eprop_iaf_adapt_bsshslm_2020.h b/models/eprop_iaf_adapt_bsshslm_2020.h index 767a183346..bad20f2e1e 100644 --- a/models/eprop_iaf_adapt_bsshslm_2020.h +++ b/models/eprop_iaf_adapt_bsshslm_2020.h @@ -326,9 +326,8 @@ class eprop_iaf_adapt_bsshslm_2020 : public EpropArchivingNodeRecurrent long get_shift() const override; bool is_eprop_recurrent_node() const override; - //! Compute the surrogate gradient. - double ( - eprop_iaf_adapt_bsshslm_2020::*compute_surrogate_gradient )( double, double, double, double, double, double ); + //! Pointer to member function selected for computing the surrogate gradient + surrogate_gradient_function compute_surrogate_gradient_; //! Map for storing a static set of recordables. friend class RecordablesMap< eprop_iaf_adapt_bsshslm_2020 >; diff --git a/models/eprop_iaf_bsshslm_2020.cpp b/models/eprop_iaf_bsshslm_2020.cpp index ecc939b1c7..1c36964064 100644 --- a/models/eprop_iaf_bsshslm_2020.cpp +++ b/models/eprop_iaf_bsshslm_2020.cpp @@ -241,7 +241,7 @@ eprop_iaf_bsshslm_2020::pre_run_hook() V_.RefractoryCounts_ = Time( Time::ms( P_.t_ref_ ) ).get_steps(); - compute_surrogate_gradient = select_surrogate_gradient( P_.surrogate_gradient_function_ ); + compute_surrogate_gradient_ = select_surrogate_gradient( P_.surrogate_gradient_function_ ); // calculate the entries of the propagator matrix for the evolution of the state vector @@ -304,7 +304,7 @@ eprop_iaf_bsshslm_2020::update( Time const& origin, const long from, const long // P_.V_th_ is passed twice to handle models without an adaptive threshold, serving as both v_th_adapt and V_th S_.surrogate_gradient_ = - ( this->*compute_surrogate_gradient )( S_.r_, S_.v_m_, P_.V_th_, P_.V_th_, P_.beta_, P_.gamma_ ); + ( this->*compute_surrogate_gradient_ )( S_.r_, S_.v_m_, P_.V_th_, P_.V_th_, P_.beta_, P_.gamma_ ); emplace_new_eprop_history_entry( t ); diff --git a/models/eprop_iaf_bsshslm_2020.h b/models/eprop_iaf_bsshslm_2020.h index 4010004629..d32065858f 100644 --- a/models/eprop_iaf_bsshslm_2020.h +++ b/models/eprop_iaf_bsshslm_2020.h @@ -311,8 +311,8 @@ class eprop_iaf_bsshslm_2020 : public EpropArchivingNodeRecurrent long get_shift() const override; bool is_eprop_recurrent_node() const override; - //! Compute the surrogate gradient. - double ( eprop_iaf_bsshslm_2020::*compute_surrogate_gradient )( double, double, double, double, double, double ); + //! Pointer to member function selected for computing the surrogate gradient + surrogate_gradient_function compute_surrogate_gradient_; //! Map for storing a static set of recordables. friend class RecordablesMap< eprop_iaf_bsshslm_2020 >; diff --git a/models/eprop_iaf_psc_delta.cpp b/models/eprop_iaf_psc_delta.cpp index 53ca50c9a0..21933e5412 100644 --- a/models/eprop_iaf_psc_delta.cpp +++ b/models/eprop_iaf_psc_delta.cpp @@ -289,7 +289,7 @@ nest::eprop_iaf_psc_delta::pre_run_hook() { B_.logger_.init(); - compute_surrogate_gradient = select_surrogate_gradient( P_.surrogate_gradient_function_ ); + compute_surrogate_gradient_ = select_surrogate_gradient( P_.surrogate_gradient_function_ ); const double h = Time::get_resolution().get_ms(); @@ -379,7 +379,7 @@ nest::eprop_iaf_psc_delta::update( Time const& origin, const long from, const lo // P_.V_th_ is passed twice to handle models without an adaptive threshold, serving as both v_th_adapt and V_th S_.surrogate_gradient_ = - ( this->*compute_surrogate_gradient )( S_.r_, S_.y3_, P_.V_th_, P_.V_th_, P_.beta_, P_.gamma_ ); + ( this->*compute_surrogate_gradient_ )( S_.r_, S_.y3_, P_.V_th_, P_.V_th_, P_.beta_, P_.gamma_ ); emplace_new_eprop_history_entry( t ); diff --git a/models/eprop_iaf_psc_delta.h b/models/eprop_iaf_psc_delta.h index 1651cca0c2..2784bb592e 100644 --- a/models/eprop_iaf_psc_delta.h +++ b/models/eprop_iaf_psc_delta.h @@ -276,9 +276,8 @@ class eprop_iaf_psc_delta : public EpropArchivingNodeRecurrent bool is_eprop_recurrent_node() const override; long get_eprop_isi_trace_cutoff() const override; - //! Compute the surrogate gradient. - double ( eprop_iaf_psc_delta::*compute_surrogate_gradient )( double, double, double, double, double, double ); - + //! Pointer to member function selected for computing the surrogate gradient + surrogate_gradient_function compute_surrogate_gradient_; // The next two classes need to be friends to access the State_ class/member friend class RecordablesMap< eprop_iaf_psc_delta >; diff --git a/nestkernel/eprop_archiving_node.cpp b/nestkernel/eprop_archiving_node.cpp index 5099acd322..937efea736 100644 --- a/nestkernel/eprop_archiving_node.cpp +++ b/nestkernel/eprop_archiving_node.cpp @@ -31,6 +31,15 @@ namespace nest { +std::map< std::string, EpropArchivingNodeRecurrent::surrogate_gradient_function > + EpropArchivingNodeRecurrent::surrogate_gradient_funcs_ = { + { "piecewise_linear", &EpropArchivingNodeRecurrent::compute_piecewise_linear_surrogate_gradient }, + { "exponential", &EpropArchivingNodeRecurrent::compute_exponential_surrogate_gradient }, + { "fast_sigmoid_derivative", &EpropArchivingNodeRecurrent::compute_fast_sigmoid_derivative_surrogate_gradient }, + { "arctan", &EpropArchivingNodeRecurrent::compute_arctan_surrogate_gradient } + }; + + EpropArchivingNodeRecurrent::EpropArchivingNodeRecurrent() : EpropArchivingNode() , firing_rate_reg_( 0.0 ) @@ -47,6 +56,28 @@ EpropArchivingNodeRecurrent::EpropArchivingNodeRecurrent( const EpropArchivingNo { } +EpropArchivingNodeRecurrent::surrogate_gradient_function +EpropArchivingNodeRecurrent::select_surrogate_gradient( const std::string& surrogate_gradient_function_name ) +{ + const auto found_entry_it = surrogate_gradient_funcs_.find( surrogate_gradient_function_name ); + + if ( found_entry_it != surrogate_gradient_funcs_.end() ) + { + return found_entry_it->second; + } + + std::string error_message = "Surrogate gradient / pseudo-derivate function surrogate_gradient_function from ["; + for ( const auto& surrogate_gradient_func : surrogate_gradient_funcs_ ) + { + error_message += " \"" + surrogate_gradient_func.first + "\","; + } + error_message.pop_back(); + error_message += " ] required."; + + throw BadProperty( error_message ); +} + + double EpropArchivingNodeRecurrent::compute_piecewise_linear_surrogate_gradient( const double r, const double v_m, diff --git a/nestkernel/eprop_archiving_node.h b/nestkernel/eprop_archiving_node.h index c692c96024..ee0996599c 100644 --- a/nestkernel/eprop_archiving_node.h +++ b/nestkernel/eprop_archiving_node.h @@ -199,36 +199,15 @@ class EpropArchivingNodeRecurrent : public EpropArchivingNode< HistEntryEpropRec //! Copy constructor. EpropArchivingNodeRecurrent( const EpropArchivingNodeRecurrent& ); + /** + * Define pointer-to-member function type for surrogate gradient function. + * @note The typename is `surrogate_gradient_function`. All parentheses in the expression are required. + */ + typedef double ( + EpropArchivingNodeRecurrent::*surrogate_gradient_function )( double, double, double, double, double, double ); + //! Select the surrogate gradient function. - double ( EpropArchivingNodeRecurrent::*select_surrogate_gradient( - std::string surrogate_gradient_function ) )( double, double, double, double, double, double ) - { - const std::map< std::string, - double ( EpropArchivingNodeRecurrent::* )( double, double, double, double, double, double ) > - surrogate_gradient_funcs = { { "piecewise_linear", - &EpropArchivingNodeRecurrent::compute_piecewise_linear_surrogate_gradient }, - { "exponential", &EpropArchivingNodeRecurrent::compute_exponential_surrogate_gradient }, - { "fast_sigmoid_derivative", &EpropArchivingNodeRecurrent::compute_fast_sigmoid_derivative_surrogate_gradient }, - { "arctan", &EpropArchivingNodeRecurrent::compute_arctan_surrogate_gradient } }; - - auto found_entry_it = surrogate_gradient_funcs.find( surrogate_gradient_function ); - if ( found_entry_it == surrogate_gradient_funcs.end() ) - { - std::string error_message = "Surrogate gradient / pseudo-derivate function surrogate_gradient_function from ["; - for ( const auto& surrogate_gradient_func : surrogate_gradient_funcs ) - { - error_message += " \"" + surrogate_gradient_func.first + "\","; - } - error_message.pop_back(); - error_message += " ] required."; - - throw BadProperty( error_message ); - } - else - { - return found_entry_it->second; - } - } + surrogate_gradient_function select_surrogate_gradient( const std::string& surrogate_gradient_function_name ); //! Compute the surrogate gradient with a piecewise linear function around the spike time (used, e.g., in Bellec et //! al., 2020). @@ -315,6 +294,14 @@ class EpropArchivingNodeRecurrent : public EpropArchivingNode< HistEntryEpropRec //! History of the firing rate regularization. std::vector< HistEntryEpropFiringRateReg > firing_rate_reg_history_; + + /** + * Map names of surrogate gradients provided to corresponding pointers to member functions. + * + * @todo In the long run, this map should be handled by a manager with proper registration functions, + * so that external modules can add their own gradient functions. + */ + static std::map< std::string, surrogate_gradient_function > surrogate_gradient_funcs_; }; inline void