Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved implementation of ptr-to-member-fcn for surrogate gradient #32

Merged
merged 3 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions models/eprop_iaf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ eprop_iaf::pre_run_hook()

V_.RefractoryCounts_ = Time( Time::ms( P_.t_ref_ ) ).get_steps();

compute_surrogate_gradient = select_surrogate_gradient( P_.surrogate_gradient_function_ );
compute_surrogate_gradient_ = select_surrogate_gradient( P_.surrogate_gradient_function_ );

// calculate the entries of the propagator matrix for the evolution of the state vector

Expand Down Expand Up @@ -301,7 +301,7 @@ eprop_iaf::update( Time const& origin, const long from, const long to )

// P_.V_th_ is passed twice to handle models without an adaptive threshold, serving as both v_th_adapt and V_th
S_.surrogate_gradient_ =
( this->*compute_surrogate_gradient )( S_.r_, S_.v_m_, P_.V_th_, P_.V_th_, P_.beta_, P_.gamma_ );
( this->*compute_surrogate_gradient_ )( S_.r_, S_.v_m_, P_.V_th_, P_.V_th_, P_.beta_, P_.gamma_ );

emplace_new_eprop_history_entry( t );

Expand Down
4 changes: 2 additions & 2 deletions models/eprop_iaf.h
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ class eprop_iaf : public EpropArchivingNodeRecurrent
bool is_eprop_recurrent_node() const override;
long get_eprop_isi_trace_cutoff() const override;

//! Compute the surrogate gradient.
double ( eprop_iaf::*compute_surrogate_gradient )( double, double, double, double, double, double );
//! Pointer to member function selected for computing the surrogate gradient
surrogate_gradient_function compute_surrogate_gradient_;

//! Map for storing a static set of recordables.
friend class RecordablesMap< eprop_iaf >;
Expand Down
4 changes: 2 additions & 2 deletions models/eprop_iaf_adapt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ eprop_iaf_adapt::pre_run_hook()

V_.RefractoryCounts_ = Time( Time::ms( P_.t_ref_ ) ).get_steps();

compute_surrogate_gradient = select_surrogate_gradient( P_.surrogate_gradient_function_ );
compute_surrogate_gradient_ = select_surrogate_gradient( P_.surrogate_gradient_function_ );

// calculate the entries of the propagator matrix for the evolution of the state vector

Expand Down Expand Up @@ -338,7 +338,7 @@ eprop_iaf_adapt::update( Time const& origin, const long from, const long to )
S_.z_ = 0.0;

S_.surrogate_gradient_ =
( this->*compute_surrogate_gradient )( S_.r_, S_.v_m_, S_.v_th_adapt_, P_.V_th_, P_.beta_, P_.gamma_ );
( this->*compute_surrogate_gradient_ )( S_.r_, S_.v_m_, S_.v_th_adapt_, P_.V_th_, P_.beta_, P_.gamma_ );

emplace_new_eprop_history_entry( t );

Expand Down
4 changes: 2 additions & 2 deletions models/eprop_iaf_adapt.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,8 @@ class eprop_iaf_adapt : public EpropArchivingNodeRecurrent
bool is_eprop_recurrent_node() const override;
long get_eprop_isi_trace_cutoff() const override;

//! Compute the surrogate gradient.
double ( eprop_iaf_adapt::*compute_surrogate_gradient )( double, double, double, double, double, double );
//! Pointer to member function selected for computing the surrogate gradient
surrogate_gradient_function compute_surrogate_gradient_;

//! Map for storing a static set of recordables.
friend class RecordablesMap< eprop_iaf_adapt >;
Expand Down
4 changes: 2 additions & 2 deletions models/eprop_iaf_adapt_bsshslm_2020.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ eprop_iaf_adapt_bsshslm_2020::pre_run_hook()

V_.RefractoryCounts_ = Time( Time::ms( P_.t_ref_ ) ).get_steps();

compute_surrogate_gradient = select_surrogate_gradient( P_.surrogate_gradient_function_ );
compute_surrogate_gradient_ = select_surrogate_gradient( P_.surrogate_gradient_function_ );

// calculate the entries of the propagator matrix for the evolution of the state vector

Expand Down Expand Up @@ -342,7 +342,7 @@ eprop_iaf_adapt_bsshslm_2020::update( Time const& origin, const long from, const
S_.z_ = 0.0;

S_.surrogate_gradient_ =
( this->*compute_surrogate_gradient )( S_.r_, S_.v_m_, S_.v_th_adapt_, P_.V_th_, P_.beta_, P_.gamma_ );
( this->*compute_surrogate_gradient_ )( S_.r_, S_.v_m_, S_.v_th_adapt_, P_.V_th_, P_.beta_, P_.gamma_ );

emplace_new_eprop_history_entry( t );

Expand Down
5 changes: 2 additions & 3 deletions models/eprop_iaf_adapt_bsshslm_2020.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,8 @@ class eprop_iaf_adapt_bsshslm_2020 : public EpropArchivingNodeRecurrent
long get_shift() const override;
bool is_eprop_recurrent_node() const override;

//! Compute the surrogate gradient.
double (
eprop_iaf_adapt_bsshslm_2020::*compute_surrogate_gradient )( double, double, double, double, double, double );
//! Pointer to member function selected for computing the surrogate gradient
surrogate_gradient_function compute_surrogate_gradient_;

//! Map for storing a static set of recordables.
friend class RecordablesMap< eprop_iaf_adapt_bsshslm_2020 >;
Expand Down
4 changes: 2 additions & 2 deletions models/eprop_iaf_bsshslm_2020.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ eprop_iaf_bsshslm_2020::pre_run_hook()

V_.RefractoryCounts_ = Time( Time::ms( P_.t_ref_ ) ).get_steps();

compute_surrogate_gradient = select_surrogate_gradient( P_.surrogate_gradient_function_ );
compute_surrogate_gradient_ = select_surrogate_gradient( P_.surrogate_gradient_function_ );

// calculate the entries of the propagator matrix for the evolution of the state vector

Expand Down Expand Up @@ -304,7 +304,7 @@ eprop_iaf_bsshslm_2020::update( Time const& origin, const long from, const long

// P_.V_th_ is passed twice to handle models without an adaptive threshold, serving as both v_th_adapt and V_th
S_.surrogate_gradient_ =
( this->*compute_surrogate_gradient )( S_.r_, S_.v_m_, P_.V_th_, P_.V_th_, P_.beta_, P_.gamma_ );
( this->*compute_surrogate_gradient_ )( S_.r_, S_.v_m_, P_.V_th_, P_.V_th_, P_.beta_, P_.gamma_ );

emplace_new_eprop_history_entry( t );

Expand Down
4 changes: 2 additions & 2 deletions models/eprop_iaf_bsshslm_2020.h
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,8 @@ class eprop_iaf_bsshslm_2020 : public EpropArchivingNodeRecurrent
long get_shift() const override;
bool is_eprop_recurrent_node() const override;

//! Compute the surrogate gradient.
double ( eprop_iaf_bsshslm_2020::*compute_surrogate_gradient )( double, double, double, double, double, double );
//! Pointer to member function selected for computing the surrogate gradient
surrogate_gradient_function compute_surrogate_gradient_;

//! Map for storing a static set of recordables.
friend class RecordablesMap< eprop_iaf_bsshslm_2020 >;
Expand Down
4 changes: 2 additions & 2 deletions models/eprop_iaf_psc_delta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ nest::eprop_iaf_psc_delta::pre_run_hook()
{
B_.logger_.init();

compute_surrogate_gradient = select_surrogate_gradient( P_.surrogate_gradient_function_ );
compute_surrogate_gradient_ = select_surrogate_gradient( P_.surrogate_gradient_function_ );

const double h = Time::get_resolution().get_ms();

Expand Down Expand Up @@ -379,7 +379,7 @@ nest::eprop_iaf_psc_delta::update( Time const& origin, const long from, const lo

// P_.V_th_ is passed twice to handle models without an adaptive threshold, serving as both v_th_adapt and V_th
S_.surrogate_gradient_ =
( this->*compute_surrogate_gradient )( S_.r_, S_.y3_, P_.V_th_, P_.V_th_, P_.beta_, P_.gamma_ );
( this->*compute_surrogate_gradient_ )( S_.r_, S_.y3_, P_.V_th_, P_.V_th_, P_.beta_, P_.gamma_ );

emplace_new_eprop_history_entry( t );

Expand Down
5 changes: 2 additions & 3 deletions models/eprop_iaf_psc_delta.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,8 @@ class eprop_iaf_psc_delta : public EpropArchivingNodeRecurrent
bool is_eprop_recurrent_node() const override;
long get_eprop_isi_trace_cutoff() const override;

//! Compute the surrogate gradient.
double ( eprop_iaf_psc_delta::*compute_surrogate_gradient )( double, double, double, double, double, double );

//! Pointer to member function selected for computing the surrogate gradient
surrogate_gradient_function compute_surrogate_gradient_;

// The next two classes need to be friends to access the State_ class/member
friend class RecordablesMap< eprop_iaf_psc_delta >;
Expand Down
31 changes: 31 additions & 0 deletions nestkernel/eprop_archiving_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@
namespace nest
{

std::map< std::string, EpropArchivingNodeRecurrent::surrogate_gradient_function >
EpropArchivingNodeRecurrent::surrogate_gradient_funcs_ = {
{ "piecewise_linear", &EpropArchivingNodeRecurrent::compute_piecewise_linear_surrogate_gradient },
{ "exponential", &EpropArchivingNodeRecurrent::compute_exponential_surrogate_gradient },
{ "fast_sigmoid_derivative", &EpropArchivingNodeRecurrent::compute_fast_sigmoid_derivative_surrogate_gradient },
{ "arctan", &EpropArchivingNodeRecurrent::compute_arctan_surrogate_gradient }
};


EpropArchivingNodeRecurrent::EpropArchivingNodeRecurrent()
: EpropArchivingNode()
, firing_rate_reg_( 0.0 )
Expand All @@ -47,6 +56,28 @@ EpropArchivingNodeRecurrent::EpropArchivingNodeRecurrent( const EpropArchivingNo
{
}

EpropArchivingNodeRecurrent::surrogate_gradient_function
EpropArchivingNodeRecurrent::select_surrogate_gradient( const std::string& surrogate_gradient_function_name )
{
const auto found_entry_it = surrogate_gradient_funcs_.find( surrogate_gradient_function_name );

if ( found_entry_it != surrogate_gradient_funcs_.end() )
{
return found_entry_it->second;
}

std::string error_message = "Surrogate gradient / pseudo-derivate function surrogate_gradient_function from [";
for ( const auto& surrogate_gradient_func : surrogate_gradient_funcs_ )
{
error_message += " \"" + surrogate_gradient_func.first + "\",";
}
error_message.pop_back();
error_message += " ] required.";

throw BadProperty( error_message );
}


double
EpropArchivingNodeRecurrent::compute_piecewise_linear_surrogate_gradient( const double r,
const double v_m,
Expand Down
45 changes: 16 additions & 29 deletions nestkernel/eprop_archiving_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,36 +199,15 @@ class EpropArchivingNodeRecurrent : public EpropArchivingNode< HistEntryEpropRec
//! Copy constructor.
EpropArchivingNodeRecurrent( const EpropArchivingNodeRecurrent& );

/**
* Define pointer-to-member function type for surrogate gradient function.
* @note The typename is `surrogate_gradient_function`. All parentheses in the expression are required.
*/
typedef double (
EpropArchivingNodeRecurrent::*surrogate_gradient_function )( double, double, double, double, double, double );

//! Select the surrogate gradient function.
double ( EpropArchivingNodeRecurrent::*select_surrogate_gradient(
std::string surrogate_gradient_function ) )( double, double, double, double, double, double )
{
const std::map< std::string,
double ( EpropArchivingNodeRecurrent::* )( double, double, double, double, double, double ) >
surrogate_gradient_funcs = { { "piecewise_linear",
&EpropArchivingNodeRecurrent::compute_piecewise_linear_surrogate_gradient },
{ "exponential", &EpropArchivingNodeRecurrent::compute_exponential_surrogate_gradient },
{ "fast_sigmoid_derivative", &EpropArchivingNodeRecurrent::compute_fast_sigmoid_derivative_surrogate_gradient },
{ "arctan", &EpropArchivingNodeRecurrent::compute_arctan_surrogate_gradient } };

auto found_entry_it = surrogate_gradient_funcs.find( surrogate_gradient_function );
if ( found_entry_it == surrogate_gradient_funcs.end() )
{
std::string error_message = "Surrogate gradient / pseudo-derivate function surrogate_gradient_function from [";
for ( const auto& surrogate_gradient_func : surrogate_gradient_funcs )
{
error_message += " \"" + surrogate_gradient_func.first + "\",";
}
error_message.pop_back();
error_message += " ] required.";

throw BadProperty( error_message );
}
else
{
return found_entry_it->second;
}
}
surrogate_gradient_function select_surrogate_gradient( const std::string& surrogate_gradient_function_name );

//! Compute the surrogate gradient with a piecewise linear function around the spike time (used, e.g., in Bellec et
//! al., 2020).
Expand Down Expand Up @@ -315,6 +294,14 @@ class EpropArchivingNodeRecurrent : public EpropArchivingNode< HistEntryEpropRec

//! History of the firing rate regularization.
std::vector< HistEntryEpropFiringRateReg > firing_rate_reg_history_;

/**
* Map names of surrogate gradients provided to corresponding pointers to member functions.
*
* @todo In the long run, this map should be handled by a manager with proper registration functions,
* so that external modules can add their own gradient functions.
*/
static std::map< std::string, surrogate_gradient_function > surrogate_gradient_funcs_;
};

inline void
Expand Down
Loading