diff --git a/.github/workflows/test-everything.yml b/.github/workflows/test-everything.yml index 59ae3c2cd0..5b52ba8fbe 100644 --- a/.github/workflows/test-everything.yml +++ b/.github/workflows/test-everything.yml @@ -151,7 +151,7 @@ jobs: export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH" mkdir build cd build - cmake .. -GNinja -DCMAKE_CXX_COMPILER=$CXX -DCMAKE_C_COMPILER=$CC -DARB_WITH_PYTHON=ON -DARB_VECTORIZE=${{ matrix.config.simd }} -DPython3_EXECUTABLE=`which python` -DARB_WITH_MPI=${{ matrix.config.mpi }} -DARB_USE_BUNDLED_LIBS=ON -DARB_WITH_ASSERTIONS=ON + cmake .. -GNinja -DCMAKE_CXX_COMPILER=$CXX -DCMAKE_C_COMPILER=$CC -DARB_WITH_PYTHON=ON -DARB_VECTORIZE=${{ matrix.config.simd }} -DPython3_EXECUTABLE=`which python` -DARB_WITH_MPI=${{ matrix.config.mpi }} -DARB_USE_BUNDLED_LIBS=ON -DARB_WITH_ASSERTIONS=ON -DARB_WITH_PROFILING=ON ninja -j4 tests examples pyarb cd - - if: ${{ matrix.variant == 'shared' }} @@ -160,7 +160,7 @@ jobs: export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH" mkdir build cd build - cmake .. -GNinja -DCMAKE_CXX_COMPILER=$CXX -DCMAKE_C_COMPILER=$CC -DARB_WITH_PYTHON=ON -DARB_VECTORIZE=${{ matrix.config.simd }} -DPython3_EXECUTABLE=`which python` -DARB_WITH_MPI=${{ matrix.config.mpi }} -DARB_USE_BUNDLED_LIBS=ON -DARB_WITH_ASSERTIONS=ON -DBUILD_SHARED_LIBS=ON + cmake .. -GNinja -DCMAKE_CXX_COMPILER=$CXX -DCMAKE_C_COMPILER=$CC -DARB_WITH_PYTHON=ON -DARB_VECTORIZE=${{ matrix.config.simd }} -DPython3_EXECUTABLE=`which python` -DARB_WITH_MPI=${{ matrix.config.mpi }} -DARB_USE_BUNDLED_LIBS=ON -DARB_WITH_ASSERTIONS=ON -DBUILD_SHARED_LIBS=ON -DARB_WITH_PROFILING=ON ninja -j4 tests examples pyarb cd - - name: Install arbor diff --git a/arbor/backends/common_types.hpp b/arbor/backends/common_types.hpp new file mode 100644 index 0000000000..357975c641 --- /dev/null +++ b/arbor/backends/common_types.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include + +#include "util/range.hpp" +#include "backends/threshold_crossing.hpp" +#include "execution_context.hpp" + +namespace arb { + +struct fvm_integration_result { + util::range crossings; + util::range sample_time; + util::range sample_value; +}; + +struct fvm_detector_info { + arb_size_type count = 0; + std::vector cv; + std::vector threshold; + execution_context ctx; +}; + +} diff --git a/arbor/backends/gpu/diffusion_state.hpp b/arbor/backends/gpu/diffusion_state.hpp index dd75d865ca..7c3adff60b 100644 --- a/arbor/backends/gpu/diffusion_state.hpp +++ b/arbor/backends/gpu/diffusion_state.hpp @@ -442,6 +442,16 @@ struct diffusion_state { packed_to_flat(rhs, to); } + void solve(array& concentration, + const_view dt_intdom, + const_view voltage, + const_view current, + const_view conductivity, + arb_value_type q) { + assemble(dt_intdom, concentration, voltage, current, conductivity, q); + solve(concentration); + } + std::size_t size() const { return matrix_size; } private: diff --git a/arbor/backends/gpu/fvm.hpp b/arbor/backends/gpu/fvm.hpp index d492357852..eaac5ab2bb 100644 --- a/arbor/backends/gpu/fvm.hpp +++ b/arbor/backends/gpu/fvm.hpp @@ -48,23 +48,6 @@ struct backend { using shared_state = arb::gpu::shared_state; using ion_state = arb::gpu::ion_state; - - static threshold_watcher voltage_watcher( - shared_state& state, - const std::vector& detector_cv, - const std::vector& thresholds, - const execution_context& context) - { - return threshold_watcher( - state.cv_to_intdom.data(), - state.src_to_spike.data(), - &state.time, - &state.time_to, - state.voltage.size(), - detector_cv, - thresholds, - context); - } }; } // namespace gpu diff --git a/arbor/backends/gpu/matrix_state_fine.hpp b/arbor/backends/gpu/matrix_state_fine.hpp index 5959513dbe..5636df719f 100644 --- a/arbor/backends/gpu/matrix_state_fine.hpp +++ b/arbor/backends/gpu/matrix_state_fine.hpp @@ -451,6 +451,13 @@ struct matrix_state_fine { packed_to_flat(rhs, to); } + + void solve(array& voltage, + const_view dt_intdom, const_view current, const_view conductivity) { + assemble(dt_intdom, voltage, current, conductivity); + solve(voltage); + } + std::size_t size() const { return matrix_size; } private: diff --git a/arbor/backends/gpu/shared_state.cpp b/arbor/backends/gpu/shared_state.cpp index 4f97251414..400afe361e 100644 --- a/arbor/backends/gpu/shared_state.cpp +++ b/arbor/backends/gpu/shared_state.cpp @@ -53,11 +53,10 @@ std::pair minmax_value_impl(arb_size_type n, con // Ion state methods: -ion_state::ion_state( - int charge, - const fvm_ion_config& ion_data, - unsigned, // alignment/padding ignored. - solver_ptr ptr): +ion_state::ion_state(int charge, + const fvm_ion_config& ion_data, + unsigned, // alignment/padding ignored. + solver_ptr ptr): write_eX_(ion_data.revpot_written), write_Xo_(ion_data.econc_written), write_Xi_(ion_data.iconc_written), @@ -101,7 +100,7 @@ void ion_state::reset() { // istim_state methods: -istim_state::istim_state(const fvm_stimulus_config& stim) { +istim_state::istim_state(const fvm_stimulus_config& stim, unsigned) { using util::assign; // Translate instance-to-CV index from stim to istim_state index vectors. @@ -179,21 +178,19 @@ void istim_state::add_current(const array& time, const iarray& cv_to_intdom, arr // Shared state methods: -shared_state::shared_state( - arb_size_type n_intdom, - arb_size_type n_cell, - arb_size_type n_detector, - const std::vector& cv_to_intdom_vec, - const std::vector& cv_to_cell_vec, - const std::vector& init_membrane_potential, - const std::vector& temperature_K, - const std::vector& diam, - const std::vector& src_to_spike, - unsigned, // alignment parameter ignored. - arb_seed_type cbprng_seed_ - ): +shared_state::shared_state(arb_size_type n_intdom, + arb_size_type n_cell, + const std::vector& cv_to_intdom_vec, + const std::vector& cv_to_cell_vec, + const std::vector& init_membrane_potential, + const std::vector& temperature_K, + const std::vector& diam, + const std::vector& src_to_spike, + const fvm_detector_info& detector, + unsigned, // alignment parameter ignored. + arb_seed_type cbprng_seed_): n_intdom(n_intdom), - n_detector(n_detector), + n_detector(detector.count), n_cv(cv_to_intdom_vec.size()), cv_to_intdom(make_const_view(cv_to_intdom_vec)), cv_to_cell(make_const_view(cv_to_cell_vec)), @@ -210,6 +207,15 @@ shared_state::shared_state( time_since_spike(n_cell*n_detector), src_to_spike(make_const_view(src_to_spike)), cbprng_seed(cbprng_seed_), + sample_events(n_intdom), + watcher{cv_to_intdom.data(), + src_to_spike.data(), + &time, + &time_to, + static_cast(voltage.size()), + detector.cv, + detector.threshold, + detector.ctx}, deliverable_events(n_intdom) { memory::fill(time_since_spike, -1.0); @@ -223,17 +229,6 @@ void shared_state::update_prng_state(mechanism& m) { store.random_numbers_.update(m); } -const arb_value_type* shared_state::mechanism_state_data(const mechanism& m, const std::string& key) { - const auto& store = storage.at(m.mechanism_id()); - - for (arb_size_type i = 0; iassemble(dt_intdom, - data.Xd_, - voltage, - data.iX_, - data.gX_, - data.charge[0]); - data.solver->solve(data.Xd_); - } - } -} - -void shared_state::add_ion( - const std::string& ion_name, - int charge, - const fvm_ion_config& ion_info, - ion_state::solver_ptr ptr) { - ion_data.emplace(std::piecewise_construct, - std::forward_as_tuple(ion_name), - std::forward_as_tuple(charge, ion_info, 1u, std::move(ptr))); -} - -void shared_state::configure_stimulus(const fvm_stimulus_config& stims) { - stim_data = istim_state(stims); -} - void shared_state::reset() { memory::copy(init_voltage, voltage); memory::fill(current_density, 0); @@ -435,12 +397,6 @@ void shared_state::zero_currents() { stim_data.zero_current(); } -void shared_state::ions_init_concentration() { - for (auto& i: ion_data) { - i.second.init_concentration(); - } -} - void shared_state::update_time_to(arb_value_type dt_step, arb_value_type tmax) { update_time_to_impl(n_intdom, time_to.data(), time.data(), dt_step, tmax); } @@ -449,10 +405,6 @@ void shared_state::set_dt() { set_dt_impl(n_intdom, n_cv, dt_intdom.data(), dt_cv.data(), time_to.data(), time.data(), cv_to_intdom.data()); } -void shared_state::add_stimulus_current() { - stim_data.add_current(time, cv_to_intdom, current_density); -} - std::pair shared_state::time_bounds() const { return minmax_value_impl(n_intdom, time.data()); } @@ -461,8 +413,10 @@ std::pair shared_state::voltage_bounds() const { return minmax_value_impl(n_cv, voltage.data()); } -void shared_state::take_samples(const sample_event_stream::state& s, array& sample_time, array& sample_value) { - take_samples_impl(s, time.data(), sample_time.data(), sample_value.data()); +void shared_state::take_samples() { + sample_events.mark_until(time_to); + take_samples_impl(sample_events.marked_events(), time.data(), sample_time.data(), sample_value.data()); + sample_events.drop_marked_events(); } // Debug interface diff --git a/arbor/backends/gpu/shared_state.hpp b/arbor/backends/gpu/shared_state.hpp index e08de3aece..4fd9b36a6e 100644 --- a/arbor/backends/gpu/shared_state.hpp +++ b/arbor/backends/gpu/shared_state.hpp @@ -10,11 +10,14 @@ #include "fvm_layout.hpp" +#include "backends/common_types.hpp" +#include "backends/shared_state_base.hpp" #include "backends/gpu/rand.hpp" #include "backends/gpu/gpu_store_types.hpp" #include "backends/gpu/stimulus.hpp" #include "backends/gpu/diffusion_state.hpp" #include "backends/gpu/matrix_state_fine.hpp" +#include "backends/gpu/threshold_watcher.hpp" namespace arb { namespace gpu { @@ -107,13 +110,12 @@ struct ARB_ARBOR_API istim_state { std::size_t size() const; // Construct state from i_clamp data; references to shared state vectors are used to initialize ppack. - istim_state(const fvm_stimulus_config& stim_data); + istim_state(const fvm_stimulus_config& stim_data, unsigned); istim_state() = default; }; -struct ARB_ARBOR_API shared_state { - +struct ARB_ARBOR_API shared_state: shared_state_base { struct mech_storage { array data_; iarray indices_; @@ -155,6 +157,15 @@ struct ARB_ARBOR_API shared_state { arb_seed_type cbprng_seed; // random number generator seed + sample_event_stream sample_events; + array sample_time; + array sample_value; + threshold_watcher watcher; + + // Host-side views/copies and local state. + memory::host_vector sample_time_host; + memory::host_vector sample_value_host; + istim_state stim_data; std::unordered_map ion_data; deliverable_event_stream deliverable_events; @@ -162,19 +173,17 @@ struct ARB_ARBOR_API shared_state { shared_state() = default; - shared_state( - arb_size_type n_intdom, - arb_size_type n_cell, - arb_size_type n_detector, - const std::vector& cv_to_intdom_vec, - const std::vector& cv_to_cell_vec, - const std::vector& init_membrane_potential, - const std::vector& temperature_K, - const std::vector& diam, - const std::vector& src_to_spike, - unsigned, // align parameter ignored - arb_seed_type cbprng_seed_ = 0u - ); + shared_state(arb_size_type n_intdom, + arb_size_type n_cell, + const std::vector& cv_to_intdom_vec, + const std::vector& cv_to_cell_vec, + const std::vector& init_membrane_potential, + const std::vector& temperature_K, + const std::vector& diam, + const std::vector& src_to_spike, + const fvm_detector_info& detector, + unsigned, // align parameter ignored + arb_seed_type cbprng_seed_ = 0u); // Setup a mechanism and tie its backing store to this object void instantiate(mechanism&, @@ -185,34 +194,14 @@ struct ARB_ARBOR_API shared_state { void update_prng_state(mechanism&); - // Note: returned pointer points to device memory. - const arb_value_type* mechanism_state_data(const mechanism& m, const std::string& key); - - void add_ion( - const std::string& ion_name, - int charge, - const fvm_ion_config& ion_data, - ion_state::solver_ptr solver=nullptr); - - void configure_stimulus(const fvm_stimulus_config&); - void zero_currents(); - void ions_init_concentration(); - // Set time_to to earliest of time+dt_step and tmax. void update_time_to(arb_value_type dt_step, arb_value_type tmax); // Set the per-intdom and per-compartment dt from time_to - time. void set_dt(); - // Update stimulus state and add current contributions. - void add_stimulus_current(); - - // Integrate by matrix solve. - void integrate_voltage(); - void integrate_diffusion(); - // Return minimum and maximum time value [ms] across cells. std::pair time_bounds() const; @@ -221,12 +210,15 @@ struct ARB_ARBOR_API shared_state { std::pair voltage_bounds() const; // Take samples according to marked events in a sample_event_stream. - void take_samples( - const sample_event_stream::state& s, - array& sample_time, - array& sample_value); + void take_samples(); + // Reset internal state void reset(); + + void update_sample_views() { + sample_time_host = memory::on_host(sample_time); + sample_value_host = memory::on_host(sample_value); + } }; // For debugging only diff --git a/arbor/backends/gpu/stack.hpp b/arbor/backends/gpu/stack.hpp index bac8490742..d2eed95b8b 100644 --- a/arbor/backends/gpu/stack.hpp +++ b/arbor/backends/gpu/stack.hpp @@ -39,7 +39,7 @@ class stack { private: // pointer in GPU memory - storage_type* device_storage_; + storage_type* device_storage_ = nullptr; // copy of the device_storage in host storage_type host_storage_; @@ -63,7 +63,7 @@ class stack { stack& operator=(const stack& other) = delete; stack(const stack& other) = delete; - stack() = delete; + stack() = default; stack(gpu_context_handle h): gpu_context_(h) { host_storage_.data = nullptr; diff --git a/arbor/backends/gpu/threshold_watcher.hpp b/arbor/backends/gpu/threshold_watcher.hpp index 35c65c25c7..f7600d37b0 100644 --- a/arbor/backends/gpu/threshold_watcher.hpp +++ b/arbor/backends/gpu/threshold_watcher.hpp @@ -38,7 +38,7 @@ class threshold_watcher { public: using stack_type = stack; - threshold_watcher() = delete; + threshold_watcher() = default; threshold_watcher(threshold_watcher&& other) = default; threshold_watcher& operator=(threshold_watcher&& other) = default; @@ -109,17 +109,17 @@ class threshold_watcher { /// Crossing events are recorded for each threshold that has been /// crossed since current time t, and the last time the test was /// performed. - void test(array* time_since_spike) { + void test(array& time_since_spike) { if (size()>0) { test_thresholds_impl( (int)size(), cv_to_intdom_, t_after_ptr_->data(), t_before_ptr_->data(), - src_to_spike_, time_since_spike->data(), + src_to_spike_, time_since_spike.data(), stack_.storage(), is_crossed_.data(), v_prev_.data(), cv_index_.data(), values_, thresholds_.data(), - !time_since_spike->empty()); + !time_since_spike.empty()); // Check that the number of spikes has not exceeded capacity. arb_assert(!stack_.overflow()); diff --git a/arbor/backends/multicore/diffusion_solver.hpp b/arbor/backends/multicore/diffusion_solver.hpp index 05b5d2f572..eb8ea58dae 100644 --- a/arbor/backends/multicore/diffusion_solver.hpp +++ b/arbor/backends/multicore/diffusion_solver.hpp @@ -73,7 +73,12 @@ struct diffusion_solver { // diff. concentration // * will be overwritten by the solution template - void solve(T& concentration, const_view dt_intdom, const_view voltage, const_view current, const_view conductivity, arb_value_type q) { + void solve(T& concentration, + const_view dt_intdom, + const_view voltage, + const_view current, + const_view conductivity, + arb_value_type q) { auto cell_cv_part = util::partition_view(cell_cv_divs); index_type ncells = cell_cv_part.size(); // loop over submatrices diff --git a/arbor/backends/multicore/fvm.hpp b/arbor/backends/multicore/fvm.hpp index c65997efd4..bcd3c696fb 100644 --- a/arbor/backends/multicore/fvm.hpp +++ b/arbor/backends/multicore/fvm.hpp @@ -49,23 +49,6 @@ struct backend { using shared_state = arb::multicore::shared_state; using ion_state = arb::multicore::ion_state; - static threshold_watcher voltage_watcher( - shared_state& state, - const std::vector& detector_cv, - const std::vector& thresholds, - const execution_context& context) - { - return threshold_watcher( - state.cv_to_intdom.data(), - state.src_to_spike.data(), - &state.time, - &state.time_to, - state.voltage.size(), - detector_cv, - thresholds, - context); - } - static value_type* mechanism_field_data(arb::mechanism* mptr, const std::string& field); }; diff --git a/arbor/backends/multicore/shared_state.cpp b/arbor/backends/multicore/shared_state.cpp index 4a18dacd07..07fd85fed3 100644 --- a/arbor/backends/multicore/shared_state.cpp +++ b/arbor/backends/multicore/shared_state.cpp @@ -194,23 +194,21 @@ void istim_state::add_current(const array& time, const iarray& cv_to_intdom, arr // shared_state methods: -shared_state::shared_state( - arb_size_type n_intdom, - arb_size_type n_cell, - arb_size_type n_detector, - const std::vector& cv_to_intdom_vec, - const std::vector& cv_to_cell_vec, - const std::vector& init_membrane_potential, - const std::vector& temperature_K, - const std::vector& diam, - const std::vector& src_to_spike, - unsigned align, - arb_seed_type cbprng_seed_ -): +shared_state::shared_state(arb_size_type n_intdom, + arb_size_type n_cell, + const std::vector& cv_to_intdom_vec, + const std::vector& cv_to_cell_vec, + const std::vector& init_membrane_potential, + const std::vector& temperature_K, + const std::vector& diam, + const std::vector& src_to_spike_, + const fvm_detector_info& detector, + unsigned align, + arb_seed_type cbprng_seed_): alignment(min_alignment(align)), alloc(alignment), n_intdom(n_intdom), - n_detector(n_detector), + n_detector(detector.count), n_cv(cv_to_intdom_vec.size()), cv_to_intdom(math::round_up(n_cv, alignment), pad(alignment)), cv_to_cell(math::round_up(cv_to_cell_vec.size(), alignment), pad(alignment)), @@ -225,56 +223,34 @@ shared_state::shared_state( temperature_degC(n_cv, pad(alignment)), diam_um(diam.begin(), diam.end(), pad(alignment)), time_since_spike(n_cell*n_detector, pad(alignment)), - src_to_spike(src_to_spike.begin(), src_to_spike.end(), pad(alignment)), + src_to_spike(src_to_spike_.begin(), src_to_spike_.end(), pad(alignment)), cbprng_seed(cbprng_seed_), - deliverable_events(n_intdom) -{ + sample_events(n_intdom), + watcher{cv_to_intdom.data(), + src_to_spike.data(), + &time, + &time_to, + static_cast(voltage.size()), + detector.cv, + detector.threshold, + detector.ctx}, + deliverable_events(n_intdom) { // For indices in the padded tail of cv_to_intdom, set index to last valid intdom index. if (n_cv>0) { std::copy(cv_to_intdom_vec.begin(), cv_to_intdom_vec.end(), cv_to_intdom.begin()); std::fill(cv_to_intdom.begin() + n_cv, cv_to_intdom.end(), cv_to_intdom_vec.back()); } + if (cv_to_cell_vec.size()) { std::copy(cv_to_cell_vec.begin(), cv_to_cell_vec.end(), cv_to_cell.begin()); std::fill(cv_to_cell.begin() + n_cv, cv_to_cell.end(), cv_to_cell_vec.back()); } util::fill(time_since_spike, -1.0); - for (unsigned i = 0; isolve(data.Xd_, - dt_intdom, - voltage, - data.iX_, - data.gX_, - data.charge[0]); - - } - } -} - -void shared_state::add_ion( - const std::string& ion_name, - int charge, - const fvm_ion_config& ion_info, - ion_state::solver_ptr ptr) { - ion_data.emplace(std::piecewise_construct, - std::forward_as_tuple(ion_name), - std::forward_as_tuple(charge, ion_info, alignment, std::move(ptr))); -} - -void shared_state::configure_stimulus(const fvm_stimulus_config& stims) { - stim_data = istim_state(stims, alignment); + std::transform(temperature_K.begin(), temperature_K.end(), + temperature_degC.begin(), + [](auto T) { return T - 273.15; }); + reset_thresholds(); } void shared_state::reset() { @@ -301,12 +277,6 @@ void shared_state::zero_currents() { stim_data.zero_current(); } -void shared_state::ions_init_concentration() { - for (auto& i: ion_data) { - i.second.init_concentration(); - } -} - void shared_state::update_time_to(arb_value_type dt_step, arb_value_type tmax) { using simd::assign; using simd::indirect; @@ -343,10 +313,6 @@ void shared_state::set_dt() { } } -void shared_state::add_stimulus_current() { - stim_data.add_current(time, cv_to_intdom, current_density); -} - std::pair shared_state::time_bounds() const { return util::minmax_value(time); } @@ -355,22 +321,19 @@ std::pair shared_state::voltage_bounds() const { return util::minmax_value(voltage); } -void shared_state::take_samples( - const sample_event_stream::state& s, - array& sample_time, - array& sample_value) -{ - for (arb_size_type i = 0; ioffset] = time[i]; sample_value[p->offset] = p->handle? *p->handle: 0; } } + sample_events.drop_marked_events(); } // (Debug interface only.) @@ -437,15 +400,6 @@ std::size_t extend_width(const arb::mechanism& mech, std::size_t width) { } } // anonymous namespace -const arb_value_type* shared_state::mechanism_state_data(const mechanism& m, const std::string& key) { - for (arb_size_type i = 0; i #include -#include "backends/event.hpp" -#include "backends/rand_fwd.hpp" +#include "fvm_layout.hpp" + #include "util/padded_alloc.hpp" #include "util/rangeutil.hpp" -#include "multi_event_stream.hpp" -#include "threshold_watcher.hpp" -#include "fvm_layout.hpp" -#include "multicore_common.hpp" -#include "partition_by_constraint.hpp" +#include "backends/event.hpp" +#include "backends/common_types.hpp" +#include "backends/rand_fwd.hpp" +#include "backends/shared_state_base.hpp" + +#include "backends/multicore/multi_event_stream.hpp" +#include "backends/multicore/threshold_watcher.hpp" +#include "backends/multicore/multicore_common.hpp" +#include "backends/multicore/partition_by_constraint.hpp" #include "backends/multicore/cable_solver.hpp" #include "backends/multicore/diffusion_solver.hpp" @@ -119,7 +123,7 @@ struct ARB_ARBOR_API istim_state { istim_state() = default; }; -struct ARB_ARBOR_API shared_state { +struct ARB_ARBOR_API shared_state: shared_state_base { struct mech_storage { array data_; iarray indices_; @@ -164,6 +168,15 @@ struct ARB_ARBOR_API shared_state { arb_seed_type cbprng_seed; // random number generator seed + sample_event_stream sample_events; + array sample_time; + array sample_value; + threshold_watcher watcher; + + // Host-side views/copies and local state. + util::range sample_time_host; + util::range sample_value_host; + istim_state stim_data; std::unordered_map ion_data; deliverable_event_stream deliverable_events; @@ -171,19 +184,17 @@ struct ARB_ARBOR_API shared_state { shared_state() = default; - shared_state( - arb_size_type n_intdom, - arb_size_type n_cell, - arb_size_type n_detector, - const std::vector& cv_to_intdom_vec, - const std::vector& cv_to_cell_vec, - const std::vector& init_membrane_potential, - const std::vector& temperature_K, - const std::vector& diam, - const std::vector& src_to_spike, - unsigned align, - arb_seed_type cbprng_seed_ = 0u - ); + shared_state(arb_size_type n_intdom, + arb_size_type n_cell, + const std::vector& cv_to_intdom_vec, + const std::vector& cv_to_cell_vec, + const std::vector& init_membrane_potential, + const std::vector& temperature_K, + const std::vector& diam, + const std::vector& src_to_spike, + const fvm_detector_info& detector_info, + unsigned align, + arb_seed_type cbprng_seed_ = 0u); void instantiate(mechanism&, unsigned, @@ -193,20 +204,8 @@ struct ARB_ARBOR_API shared_state { void update_prng_state(mechanism&); - const arb_value_type* mechanism_state_data(const mechanism&, const std::string&); - - void add_ion( - const std::string& ion_name, - int charge, - const fvm_ion_config& ion_data, - ion_state::solver_ptr solver=nullptr); - - void configure_stimulus(const fvm_stimulus_config&); - void zero_currents(); - void ions_init_concentration(); - void ions_nernst_reversal_potential(arb_value_type temperature_K); // Set time_to to earliest of time+dt_step and tmax. @@ -215,13 +214,6 @@ struct ARB_ARBOR_API shared_state { // Set the per-integration domain and per-compartment dt from time_to - time. void set_dt(); - // Update stimulus state and add current contributions. - void add_stimulus_current(); - - // Integrate by matrix solve. - void integrate_voltage(); - void integrate_diffusion(); - // Return minimum and maximum time value [ms] across cells. std::pair time_bounds() const; @@ -230,12 +222,15 @@ struct ARB_ARBOR_API shared_state { std::pair voltage_bounds() const; // Take samples according to marked events in a sample_event_stream. - void take_samples( - const sample_event_stream::state& s, - array& sample_time, - array& sample_value); + void take_samples(); + // Reset internal state void reset(); + + void update_sample_views() { + sample_time_host = util::range_pointer_view(sample_time); + sample_value_host = util::range_pointer_view(sample_value); + } }; // For debugging only: diff --git a/arbor/backends/multicore/threshold_watcher.hpp b/arbor/backends/multicore/threshold_watcher.hpp index 139f928123..760759bf28 100644 --- a/arbor/backends/multicore/threshold_watcher.hpp +++ b/arbor/backends/multicore/threshold_watcher.hpp @@ -14,19 +14,16 @@ namespace multicore { class threshold_watcher { public: threshold_watcher() = default; - threshold_watcher(const execution_context& ctx) {} - threshold_watcher( - const arb_index_type* cv_to_intdom, - const arb_index_type* src_to_spike, - const array* t_before, - const array* t_after, - const arb_size_type num_cv, - const std::vector& cv_index, - const std::vector& thresholds, - const execution_context& context - ): + threshold_watcher(const arb_index_type* cv_to_intdom, + const arb_index_type* src_to_spike, + const array* t_before, + const array* t_after, + const arb_size_type num_cv, + const std::vector& cv_index, + const std::vector& thresholds, + const execution_context& context): cv_to_intdom_(cv_to_intdom), src_to_spike_(src_to_spike), t_before_ptr_(t_before), @@ -66,7 +63,7 @@ class threshold_watcher { /// Tests each target for changed threshold state /// Crossing events are recorded for each threshold that /// is crossed since the last call to test - void test(array* time_since_spike) { + void test(array& time_since_spike) { // either number of cvs is 0 or values_ is not null arb_assert((n_cv_ == 0) || (bool)values_); @@ -81,9 +78,9 @@ class threshold_watcher { auto thresh = thresholds_[i]; arb_index_type spike_idx = 0; - if (!time_since_spike->empty()) { + if (!time_since_spike.empty()) { spike_idx = src_to_spike_[i]; - (*time_since_spike)[spike_idx] = -1.0; + time_since_spike[spike_idx] = -1.0; } if (!is_crossed_[i]) { @@ -94,8 +91,8 @@ class threshold_watcher { auto crossing_time = math::lerp(t_before[intdom], t_after[intdom], pos); crossings_.push_back({i, crossing_time}); - if (!time_since_spike->empty()) { - (*time_since_spike)[spike_idx] = t_after[intdom] - crossing_time; + if (!time_since_spike.empty()) { + time_since_spike[spike_idx] = t_after[intdom] - crossing_time; } is_crossed_[i] = true; diff --git a/arbor/backends/shared_state_base.hpp b/arbor/backends/shared_state_base.hpp new file mode 100644 index 0000000000..d8314feb58 --- /dev/null +++ b/arbor/backends/shared_state_base.hpp @@ -0,0 +1,137 @@ +#pragma once + +#include +#include + +#include "backends/event.hpp" +#include "backends/common_types.hpp" +#include "fvm_layout.hpp" + +#include "util/rangeutil.hpp" + +namespace arb { + +// Common functionality for CPU/GPU shared state. +template +struct shared_state_base { + + void update_time_step(time_type dt_max, time_type tfinal) { + auto d = static_cast(this); + d->deliverable_events.drop_marked_events(); + d->update_time_to(dt_max, tfinal); + d->deliverable_events.event_time_if_before(d->time_to); + d->set_dt(); + } + + void begin_epoch(std::vector deliverables, + std::vector samples) { + auto d = static_cast(this); + // events + d->deliverable_events.init(std::move(deliverables)); + // samples + auto n_samples = samples.size(); + if (d->sample_time.size() < n_samples) { + d->sample_time = array(n_samples); + d->sample_value = array(n_samples); + } + d->sample_events.init(std::move(samples)); + // thresholds + d->watcher.clear_crossings(); + } + + + void add_ion(const std::string& ion_name, + int charge, + const fvm_ion_config& ion_info, + typename ion_state::solver_ptr ptr=nullptr) { + auto d = static_cast(this); + d->ion_data.emplace(std::piecewise_construct, + std::forward_as_tuple(ion_name), + std::forward_as_tuple(charge, ion_info, d->alignment, std::move(ptr))); + } + + arb_value_type* mechanism_state_data(const mechanism& m, + const std::string& key) { + auto d = static_cast(this); + const auto& store = d->storage.at(m.mechanism_id()); + + for (arb_size_type i = 0; i(this); + d->deliverable_events.mark_until_after(d->time); + auto state = d->deliverable_events.marked_events(); + arb_deliverable_event_stream result; + result.n_streams = state.n; + result.begin = state.begin_offset; + result.end = state.end_offset; + result.events = (arb_deliverable_event_data*) state.ev_data; // FIXME(TH): This relies on bit-castability + return result; + } + + + void next_time_step() { + auto d = static_cast(this); + std::swap(d->time_to, d->time); + } + + void reset_thresholds() { + auto d = static_cast(this); + d->watcher.reset(d->voltage); + } + + void test_thresholds() { + auto d = static_cast(this); + d->watcher.test(d->time_since_spike); + } + + void configure_stimulus(const fvm_stimulus_config& stims) { + auto d = static_cast(this); + d->stim_data = {stims, d->alignment}; + } + + void add_stimulus_current() { + auto d = static_cast(this); + d->stim_data.add_current(d->time, d->cv_to_intdom, d->current_density); + } + + void ions_init_concentration() { + auto d = static_cast(this); + for (auto& i: d->ion_data) { + i.second.init_concentration(); + } + } + + void integrate_cable_state() { + auto d = static_cast(this); + d->solver.solve(d->voltage, d->dt_intdom, d->current_density, d->conductivity); + for (auto& [ion, data]: d->ion_data) { + if (data.solver) { + data.solver->solve(data.Xd_, + d->dt_intdom, + d->voltage, + data.iX_, + data.gX_, + data.charge[0]); + } + } + } + + fvm_integration_result get_integration_result() { + auto d = static_cast(this); + const auto& crossings = d->watcher.crossings(); + d->update_sample_views(); + + return { util::range_pointer_view(crossings), + util::range_pointer_view(d->sample_time_host), + util::range_pointer_view(d->sample_value_host) }; + } +}; + +} diff --git a/arbor/fvm_lowered_cell.hpp b/arbor/fvm_lowered_cell.hpp index a276ebbfd8..cf69ce012d 100644 --- a/arbor/fvm_lowered_cell.hpp +++ b/arbor/fvm_lowered_cell.hpp @@ -17,6 +17,7 @@ #include #include "backends/event.hpp" +#include "backends/common_types.hpp" #include "backends/threshold_crossing.hpp" #include "execution_context.hpp" #include "sampler_map.hpp" @@ -26,12 +27,6 @@ namespace arb { -struct fvm_integration_result { - util::range crossings; - util::range sample_time; - util::range sample_value; -}; - // A sample for a probe may be derived from multiple 'raw' sampled // values from the backend. diff --git a/arbor/fvm_lowered_cell_impl.hpp b/arbor/fvm_lowered_cell_impl.hpp index b3d09765e2..1b1e5e0eb3 100644 --- a/arbor/fvm_lowered_cell_impl.hpp +++ b/arbor/fvm_lowered_cell_impl.hpp @@ -48,7 +48,6 @@ class fvm_lowered_cell_impl: public fvm_lowered_cell { fvm_lowered_cell_impl(execution_context ctx, arb_seed_type seed = 0): context_(ctx), - threshold_watcher_(ctx), seed_{seed} {}; @@ -82,19 +81,11 @@ class fvm_lowered_cell_impl: public fvm_lowered_cell { // Host or GPU-side back-end dependent storage. using array = typename backend::array; using shared_state = typename backend::shared_state; - using sample_event_stream = typename backend::sample_event_stream; - using threshold_watcher = typename backend::threshold_watcher; execution_context context_; std::unique_ptr state_; // Cell state shared across mechanisms. - // TODO: Can we move the backend-dependent data structures below into state_? - sample_event_stream sample_events_; - array sample_time_; - array sample_value_; - threshold_watcher threshold_watcher_; - value_type tmin_ = 0; std::vector mechanisms_; // excludes reversal potential calculators. std::vector revpot_mechanisms_; @@ -109,10 +100,6 @@ class fvm_lowered_cell_impl: public fvm_lowered_cell { // Flag indicating that at least one of the mechanisms implements the post_events procedure bool post_events_ = false; - // Host-side views/copies and local state. - decltype(backend::host_view(sample_time_)) sample_time_host_; - decltype(backend::host_view(sample_value_)) sample_value_host_; - void update_ion_state(); // Throw if absolute value of membrane voltage exceeds bounds. @@ -158,7 +145,7 @@ void fvm_lowered_cell_impl::assert_tmin() { throw arbor_internal_error("fvm_lowered_cell: inconsistent times across cells"); } if (time_minmax.first != tmin_) { - throw arbor_internal_error("fvm_lowered_cell: out of synchronziation with cell state time"); + throw arbor_internal_error("fvm_lowered_cell: out of synchronization with cell state time"); } } @@ -199,7 +186,7 @@ void fvm_lowered_cell_impl::reset() { // NOTE: Threshold watcher reset must come after the voltage values are set, // as voltage is implicitly read by watcher to set initial state. - threshold_watcher_.reset(state_->voltage); + state_->reset_thresholds(); } template @@ -213,86 +200,52 @@ fvm_integration_result fvm_lowered_cell_impl::integrate( // Integration setup PE(advance:integrate:setup); - threshold_watcher_.clear_crossings(); - - auto n_samples = staged_samples.size(); - if (sample_time_.size() < n_samples) { - sample_time_ = array(n_samples); - sample_value_ = array(n_samples); - } - - auto& events = state_->deliverable_events; - events.init(std::move(staged_events)); - sample_events_.init(std::move(staged_samples)); - + // Push samples and events down to the state and reset the spike thresholds. + state_->begin_epoch(std::move(staged_events), std::move(staged_samples)); arb_assert((assert_tmin(), true)); unsigned remaining_steps = dt_steps(tmin_, tfinal, dt_max); PL(); - // TODO: Consider devolving more of this to back-end routines (e.g. - // per-compartment dt probably not a win on GPU), possibly rumbling - // complete fvm state into shared state object. while (remaining_steps) { // Update any required reversal potentials based on ionic concs. - for (auto& m: revpot_mechanisms_) { m->update_current(); } - // Deliver events and accumulate mechanism current contributions. - - PE(advance:integrate:events); - state_->deliverable_events.mark_until_after(state_->time); - PL(); - PE(advance:integrate:current:zero); state_->zero_currents(); PL(); + + // Deliver events and accumulate mechanism current contributions. + + PE(advance:integrate:events:mark); + auto deliverable_events = state_->mark_deliverable_events(); + PL(); for (auto& m: mechanisms_) { - auto state = events.marked_events(); - arb_deliverable_event_stream events; - events.n_streams = state.n; - events.begin = state.begin_offset; - events.end = state.end_offset; - events.events = (arb_deliverable_event_data*) state.ev_data; // FIXME(TH): This relies on bit-castability - m->deliver_events(events); + m->deliver_events(deliverable_events); m->update_current(); } - PE(advance:integrate:events); - events.drop_marked_events(); - // Update event list and integration step times. - - state_->update_time_to(dt_max, tfinal); - events.event_time_if_before(state_->time_to); - state_->set_dt(); + PE(advance:integrate:update_time); + state_->update_time_step(dt_max, tfinal); PL(); // Add stimulus current contributions. - // (Note: performed after dt, time_to calculation, in case we - // want to use mean current contributions as opposed to point - // sample.) - + // NOTE: performed after dt, time_to calculation, in case we want to + // use mean current contributions as opposed to point sample. PE(advance:integrate:stimuli) state_->add_stimulus_current(); PL(); // Take samples at cell time if sample time in this step interval. - PE(advance:integrate:samples); - sample_events_.mark_until(state_->time_to); - state_->take_samples(sample_events_.marked_events(), sample_time_, sample_value_); - sample_events_.drop_marked_events(); + state_->take_samples(); PL(); - // Integrate voltage / solve cable eq - PE(advance:integrate:voltage); - state_->integrate_voltage(); - PL(); - // Compute ionic diffusion effects - PE(advance:integrate:diffusion); - state_->integrate_diffusion(); + // Integrate voltage and diffusion + PE(advance:integrate:cable); + state_->integrate_cable_state(); PL(); // Integrate mechanism state for density @@ -308,6 +261,7 @@ fvm_integration_result fvm_lowered_cell_impl::integrate( // voltage mechs run now; after the cable_solver, but before the // threshold test + PE(advance:integrate:v_mechs); for (auto& m: voltage_mechanisms_) { m->update_current(); } @@ -315,10 +269,11 @@ fvm_integration_result fvm_lowered_cell_impl::integrate( state_->update_prng_state(*m); m->update_state(); } + PL(); // Update time and test for spike threshold crossings. PE(advance:integrate:threshold); - threshold_watcher_.test(&state_->time_since_spike); + state_->test_thresholds(); PL(); PE(advance:integrate:post) @@ -329,18 +284,17 @@ fvm_integration_result fvm_lowered_cell_impl::integrate( } PL(); - std::swap(state_->time_to, state_->time); + // Advance epoch by swapping current and next time. + state_->next_time_step(); // Check for non-physical solutions: - if (check_voltage_mV_) { PE(advance:integrate:physicalcheck); assert_voltage_bounded(check_voltage_mV_.value()); PL(); } - // Check for end of integration. - + // At end of epoch, see whether we need additional steps PE(advance:integrate:stepsupdate); if (!--remaining_steps) { tmin_ = state_->time_bounds().first; @@ -351,15 +305,7 @@ fvm_integration_result fvm_lowered_cell_impl::integrate( set_tmin(tfinal); - const auto& crossings = threshold_watcher_.crossings(); - sample_time_host_ = backend::host_view(sample_time_); - sample_value_host_ = backend::host_view(sample_value_); - - return fvm_integration_result{ - util::range_pointer_view(crossings), - util::range_pointer_view(sample_time_host_), - util::range_pointer_view(sample_value_host_) - }; + return state_->get_integration_result(); } template @@ -383,6 +329,23 @@ void fvm_lowered_cell_impl::assert_voltage_bounded(arb_value_type bound v_minmax.first<-bound? v_minmax.first: v_minmax.second); } +inline +fvm_detector_info get_detector_info(arb_size_type max, + arb_size_type ncell, + const std::vector& cells, + const fvm_cv_discretization& D, + execution_context ctx) { + std::vector cv; + std::vector threshold; + for (auto cell_idx: util::make_span(ncell)) { + for (auto entry: cells[cell_idx].detectors()) { + cv.push_back(D.geometry.location_cv(cell_idx, entry.loc, cv_prefer::cv_empty)); + threshold.push_back(entry.item.threshold); + } + } + return { max, std::move(cv), std::move(threshold), ctx }; +} + template fvm_initialization_data fvm_lowered_cell_impl::initialize( const std::vector& gids, @@ -478,7 +441,6 @@ fvm_initialization_data fvm_lowered_cell_impl::initialize( [&fvm_info](index_type i){ return fvm_info.cell_to_intdom[i]; }); arb_assert(D.n_cell() == ncell); - sample_events_ = sample_event_stream(nintdom); // Discretize and build gap junction info. @@ -520,11 +482,17 @@ fvm_initialization_data fvm_lowered_cell_impl::initialize( util::transform_view(keys(mech_data.mechanisms), [&](const std::string& name) { return mech_instance(name).mech->data_alignment(); })); - state_ = std::make_unique( - nintdom, ncell, max_detector, cv_to_intdom, std::move(cv_to_cell), - D.init_membrane_potential, D.temperature_K, D.diam_um, std::move(src_to_spike), - data_alignment? data_alignment: 1u, seed_); - + auto d_info = get_detector_info(max_detector, ncell, cells, D, context_); + + state_ = std::make_unique(nintdom, + ncell, + cv_to_intdom, + std::move(cv_to_cell), + D.init_membrane_potential, D.temperature_K, D.diam_um, + std::move(src_to_spike), + d_info, + data_alignment? data_alignment: 1u, + seed_); state_->solver = {D.geometry.cv_parent, D.geometry.cell_cv_divs, D.cv_capacitance, D.face_conductance, D.cv_area, fvm_info.cell_to_intdom}; @@ -657,19 +625,10 @@ fvm_initialization_data fvm_lowered_cell_impl::initialize( } - std::vector detector_cv; - std::vector detector_threshold; std::vector probe_data; for (auto cell_idx: make_span(ncell)) { cell_gid_type gid = gids[cell_idx]; - - // Collect detectors, probe handles. - for (auto entry: cells[cell_idx].detectors()) { - detector_cv.push_back(D.geometry.location_cv(cell_idx, entry.loc, cv_prefer::cv_empty)); - detector_threshold.push_back(entry.item.threshold); - } - std::vector rec_probes = rec.get_probes(gid); for (cell_lid_type i: count_along(rec_probes)) { probe_info& pi = rec_probes[i]; @@ -687,10 +646,7 @@ fvm_initialization_data fvm_lowered_cell_impl::initialize( } } - threshold_watcher_ = backend::voltage_watcher(*state_, detector_cv, detector_threshold, context_); - reset(); - return fvm_info; } diff --git a/arbor/util/ordered_forest.hpp b/arbor/util/ordered_forest.hpp index a1346af68a..b07b7e4960 100644 --- a/arbor/util/ordered_forest.hpp +++ b/arbor/util/ordered_forest.hpp @@ -573,8 +573,7 @@ class ordered_forest { auto copy_children = [&](auto& self, const auto& from, auto& to) -> void { sibling_iterator j; for (auto i = other.child_begin(from); i!=other.child_end(from); ++i) { - // TODO: explicit `this` required for g++6; remove when g++6 deprecated. - j = j? this->insert_after(j, *i): sibling_iterator(this->push_child(to, *i)); + j = j? insert_after(j, *i): sibling_iterator(push_child(to, *i)); self(self, i, j); } }; diff --git a/python/morphology.cpp b/python/morphology.cpp index a344b066b1..3c0e3579a9 100644 --- a/python/morphology.cpp +++ b/python/morphology.cpp @@ -8,6 +8,8 @@ #include #include +#include +#include #include #include #include @@ -207,6 +209,27 @@ void register_morphology(py::module& m) { "Find the location on the morphology that is closest to a 3d point. " "Returns the location and its distance from the point."); + // arb::place_pwlin + py::class_ prov(m, "morphology_provider"); + prov + .def(py::init(), + "morphology"_a, + "Construct a morphology provider.") + .def("reify_locset", + [](const arb::mprovider& p, + const std::string& r) { + return thingify(arborio::parse_locset_expression(r).unwrap(), p); + }, + "Turn a locset into a list of locations.") + .def("reify_region", + [](const arb::mprovider& p, + const std::string& r) { + return thingify(arborio::parse_region_expression(r).unwrap(), p); + }, + "Turn a region into an extent."); + + + // // Higher-level data structures (segment_tree, morphology) // diff --git a/test/unit/test_abi.cpp b/test/unit/test_abi.cpp index 7c700ac90a..46ebf7e6b2 100644 --- a/test/unit/test_abi.cpp +++ b/test/unit/test_abi.cpp @@ -7,6 +7,7 @@ #include #include +#include "backends/common_types.hpp" #include "backends/multicore/shared_state.hpp" #ifdef ARB_GPU_ENABLED #include "backends/gpu/shared_state.hpp" @@ -50,9 +51,10 @@ TEST(abi, multicore_initialisation) { std::vector vinit(ncv, -65); std::vector src_to_spike = {}; - arb::multicore::shared_state shared_state(ncell, ncell, 0, + arb::multicore::shared_state shared_state(ncell, ncell, cv_to_intdom, cv_to_intdom, vinit, temp, diam, src_to_spike, + arb::fvm_detector_info(), mech.data_alignment()); arb::mechanism_layout layout; @@ -128,9 +130,10 @@ TEST(abi, multicore_null) { std::vector vinit(ncv, -65); std::vector src_to_spike = {}; - arb::multicore::shared_state shared_state(ncell, ncell, 0, + arb::multicore::shared_state shared_state(ncell, ncell, cv_to_intdom, cv_to_intdom, vinit, temp, diam, src_to_spike, + arb::fvm_detector_info{}, mech.data_alignment()); arb::mechanism_layout layout; @@ -193,9 +196,10 @@ TEST(abi, gpu_initialisation) { std::vector vinit(ncv, -65); std::vector src_to_spike = {}; - arb::gpu::shared_state shared_state(ncell, ncell, 0, + arb::gpu::shared_state shared_state(ncell, ncell, cv_to_intdom, cv_to_intdom, vinit, temp, diam, src_to_spike, + arb::fvm_detector_info{}, 1); arb::mechanism_layout layout; @@ -270,9 +274,10 @@ TEST(abi, gpu_null) { std::vector vinit(ncv, -65); std::vector src_to_spike = {}; - arb::gpu::shared_state shared_state(ncell, ncell, 0, + arb::gpu::shared_state shared_state(ncell, ncell, cv_to_intdom, cv_to_intdom, vinit, temp, diam, src_to_spike, + arb::fvm_detector_info{}, 1); arb::mechanism_layout layout; diff --git a/test/unit/test_fvm_lowered.cpp b/test/unit/test_fvm_lowered.cpp index 351504fce1..c44efc31f6 100644 --- a/test/unit/test_fvm_lowered.cpp +++ b/test/unit/test_fvm_lowered.cpp @@ -604,8 +604,12 @@ TEST(fvm_lowered, ionic_concentrations) { auto& read_cai_mech = read_cai.mech; auto& write_cai_mech = write_cai.mech; - auto shared_state = std::make_unique( - ncell, ncell, 0, cv_to_intdom, cv_to_intdom, vinit, temp, diam, src_to_spike, read_cai_mech->data_alignment()); + auto shared_state = std::make_unique(ncell, ncell, + cv_to_intdom, cv_to_intdom, + vinit, temp, diam, + src_to_spike, + fvm_detector_info{}, + read_cai_mech->data_alignment()); shared_state->add_ion("ca", 2, ion_config); shared_state->instantiate(*read_cai_mech, 0, overrides, layout, {}); diff --git a/test/unit/test_kinetic_linear.cpp b/test/unit/test_kinetic_linear.cpp index bccdf44e28..3566879c08 100644 --- a/test/unit/test_kinetic_linear.cpp +++ b/test/unit/test_kinetic_linear.cpp @@ -45,8 +45,12 @@ void run_test(std::string mech_name, std::vector vinit(ncv, -65); std::vector src_to_spike = {}; - auto shared_state = std::make_unique( - ncell, ncell, 0, cv_to_intdom, cv_to_intdom, vinit, temp, diam, src_to_spike, test->data_alignment()); + auto shared_state = std::make_unique(ncell, ncell, + cv_to_intdom, cv_to_intdom, + vinit, temp, diam, + src_to_spike, + fvm_detector_info{}, + test->data_alignment()); mechanism_layout layout; mechanism_overrides overrides; diff --git a/test/unit/test_mech_temp_diam.cpp b/test/unit/test_mech_temp_diam.cpp index dbaf1b0a12..449cd4e77c 100644 --- a/test/unit/test_mech_temp_diam.cpp +++ b/test/unit/test_mech_temp_diam.cpp @@ -4,6 +4,7 @@ #include #include "backends/multicore/fvm.hpp" +#include "backends/common_types.hpp" #ifdef ARB_GPU_ENABLED #include "backends/gpu/fvm.hpp" #endif @@ -35,8 +36,12 @@ void run_celsius_test() { std::vector vinit(ncv, -65); std::vector src_to_spike = {}; - auto shared_state = std::make_unique( - ncell, ncell, 0, cv_to_intdom, cv_to_intdom, vinit, temp, diam, src_to_spike, celsius_test->data_alignment()); + auto shared_state = std::make_unique(ncell, ncell, + cv_to_intdom, cv_to_intdom, + vinit, temp, diam, + src_to_spike, + fvm_detector_info{}, + celsius_test->data_alignment()); mechanism_layout layout; mechanism_overrides overrides; @@ -92,9 +97,12 @@ void run_diam_test() { layout.cv.push_back(i); } - auto shared_state = std::make_unique( - ncell, ncell, 0, cv_to_intdom, cv_to_intdom, vinit, temp, diam, src_to_spike, celsius_test->data_alignment()); - + auto shared_state = std::make_unique(ncell, ncell, + cv_to_intdom, cv_to_intdom, + vinit, temp, diam, + src_to_spike, + fvm_detector_info{}, + celsius_test->data_alignment()); shared_state->instantiate(*celsius_test, 0, overrides, layout, {}); shared_state->reset(); diff --git a/test/unit/test_spikes.cpp b/test/unit/test_spikes.cpp index 987dd99f93..56d80b89cb 100644 --- a/test/unit/test_spikes.cpp +++ b/test/unit/test_spikes.cpp @@ -90,7 +90,7 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) { // test again at t=1, with unchanged values // - nothing should change memory::fill(time_after, 1.); - watch.test(&time_since_spike); + watch.test(time_since_spike); EXPECT_FALSE(watch.is_crossed(0)); EXPECT_TRUE(watch.is_crossed(1)); EXPECT_FALSE(watch.is_crossed(2)); @@ -106,7 +106,7 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) { memory::fill(values, 0.); memory::copy(time_after, time_before); memory::fill(time_after, 2.); - watch.test(&time_since_spike); + watch.test(time_since_spike); EXPECT_FALSE(watch.is_crossed(0)); EXPECT_FALSE(watch.is_crossed(1)); EXPECT_FALSE(watch.is_crossed(2)); @@ -118,7 +118,7 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) { memory::copy(time_after, time_before); time_after[0] = 2.5; time_after[1] = 3.0; - watch.test(&time_since_spike); + watch.test(time_since_spike); EXPECT_TRUE(watch.is_crossed(0)); EXPECT_TRUE(watch.is_crossed(1)); EXPECT_TRUE(watch.is_crossed(2)); @@ -142,7 +142,7 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) { memory::fill(values, 0.); memory::copy(time_after, time_before); memory::fill(time_after, 4.); - watch.test(&time_since_spike); + watch.test(time_since_spike); EXPECT_FALSE(watch.is_crossed(0)); EXPECT_FALSE(watch.is_crossed(1)); EXPECT_FALSE(watch.is_crossed(2)); @@ -158,7 +158,7 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher) { values[index[2]] = 6.; memory::copy(time_after, time_before); memory::fill(time_after, 5.); - watch.test(&time_since_spike); + watch.test(time_since_spike); EXPECT_FALSE(watch.is_crossed(0)); EXPECT_FALSE(watch.is_crossed(1)); EXPECT_TRUE(watch.is_crossed(2)); diff --git a/test/unit/test_synapses.cpp b/test/unit/test_synapses.cpp index d13aa718b0..a5051647d5 100644 --- a/test/unit/test_synapses.cpp +++ b/test/unit/test_synapses.cpp @@ -88,15 +88,15 @@ TEST(synapses, syn_basic_state) { auto align = std::max(expsyn->data_alignment(), exp2syn->data_alignment()); shared_state state(num_intdom, - num_intdom, - 0, - std::vector(num_comp, 0), - std::vector(num_comp, 0), - std::vector(num_comp, -65), - std::vector(num_comp, temp_K), - std::vector(num_comp, 1.), - std::vector(0), - align); + num_intdom, + std::vector(num_comp, 0), + std::vector(num_comp, 0), + std::vector(num_comp, -65), + std::vector(num_comp, temp_K), + std::vector(num_comp, 1.), + std::vector(0), + fvm_detector_info{}, + align); state.reset(); fill(state.current_density, 1.0);