Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🧹 Re-factor FVM lowered cell implementation and shared state #2082

Merged
merged 36 commits into from
Feb 8, 2023
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
d1bb28c
Re-factor fvm_lowered_cell.
thorstenhater Jan 20, 2023
9c34a15
Add integration_result.
thorstenhater Jan 20, 2023
d7273e5
Fix includes
thorstenhater Jan 20, 2023
5d4e245
Fix epoch setup.
thorstenhater Jan 20, 2023
a554213
Fix GPU memory.
thorstenhater Jan 20, 2023
ae69e3c
API woes begone.
thorstenhater Jan 20, 2023
eb56669
Consistent namings.
thorstenhater Jan 20, 2023
904f804
Make Stack default constructible.
thorstenhater Jan 24, 2023
eb6f8dc
Remove trailing _.
thorstenhater Jan 24, 2023
438e7cf
Extract detector building.
thorstenhater Jan 24, 2023
6625b1d
Shuffle.
thorstenhater Jan 24, 2023
b951ac2
Squash auto-expand.
thorstenhater Jan 24, 2023
60ed556
Polish a bit more.
thorstenhater Jan 24, 2023
ed73a23
Fix the obvious oversights.
thorstenhater Jan 24, 2023
52936b5
Not const!.
thorstenhater Jan 24, 2023
e9063e4
CRTP to the rescue.
thorstenhater Jan 24, 2023
a132d8b
The attachment.
thorstenhater Jan 24, 2023
5c816d6
The Inheritance.
thorstenhater Jan 24, 2023
def840a
The Deduplication.
thorstenhater Jan 24, 2023
7af4763
The Imitation.
thorstenhater Jan 24, 2023
72752d6
The Inclusion.
thorstenhater Jan 24, 2023
042e23a
The Removal.
thorstenhater Jan 24, 2023
a573524
The Removal, part II.
thorstenhater Jan 24, 2023
9861a72
The Solution.
thorstenhater Jan 24, 2023
8cd76f1
The Renaming.
thorstenhater Jan 24, 2023
f69dcb9
The Confusion.
thorstenhater Jan 24, 2023
17d355a
GPU, CPU tests are fine.
thorstenhater Jan 24, 2023
70ec3c5
The Deinfestation.
thorstenhater Jan 24, 2023
c721997
The Reset.
thorstenhater Jan 24, 2023
418f5a1
The Finalisation.
thorstenhater Jan 25, 2023
43dcf80
Clean-up a print statement.
thorstenhater Jan 25, 2023
1d77202
Merge remote-tracking branch 'origin/master' into qa/refactor-fvm-low…
thorstenhater Feb 3, 2023
9969554
Review: re-insert missing line, remove spurious output.
thorstenhater Feb 3, 2023
553849f
Merge remote-tracking branch 'origin/master' into qa/refactor-fvm-low…
thorstenhater Feb 8, 2023
35b741d
Squash erroneous PE/PL.
thorstenhater Feb 8, 2023
6121d33
Add PROFILING to CI so we do not stumble here again.
thorstenhater Feb 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions arbor/backends/common_types.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once

#include <vector>

#include <arbor/common_types.hpp>

#include "util/range.hpp"
#include "backends/threshold_crossing.hpp"
#include "execution_context.hpp"

namespace arb {

// Bundle of read-only ranges describing the outcome of an integration.
// NOTE(review): each member is a range over raw const pointers, so presumably
// these are non-owning views onto storage held elsewhere -- confirm the
// lifetime guarantees with whoever fills this struct in.
struct fvm_integration_result {
// Range over threshold_crossing records (presumably the detected spikes).
util::range<const threshold_crossing*> crossings;
// Range over sample time values.
util::range<const arb_value_type*> sample_time;
// Range over sample values; presumably parallel to sample_time -- TODO confirm.
util::range<const arb_value_type*> sample_value;
};

// Description of the threshold detectors handed to a backend's shared state,
// which uses it to size per-detector storage and to construct its
// threshold_watcher.
struct fvm_detector_info {
// Number of detectors; the GPU shared_state stores this as n_detector.
arb_size_type count = 0;
// CV index of each detector, forwarded to the threshold_watcher.
std::vector<arb_index_type> cv;
// Threshold value of each detector, forwarded to the threshold_watcher.
// Presumably parallel to `cv` -- TODO confirm.
std::vector<arb_value_type> threshold;
// Execution context forwarded to the threshold_watcher; held by value here.
execution_context ctx;
};

}
10 changes: 10 additions & 0 deletions arbor/backends/gpu/diffusion_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,16 @@ struct diffusion_state {
packed_to_flat(rhs, to);
}

// Convenience overload: assemble the system from the supplied inputs and then
// run the single-argument solve on `concentration`.
void solve(array& concentration,
           const_view dt_intdom,
           const_view voltage,
           const_view current,
           const_view conductivity,
           arb_value_type q)
{
    this->assemble(dt_intdom, concentration, voltage, current, conductivity, q);
    this->solve(concentration);
}

std::size_t size() const { return matrix_size; }

private:
Expand Down
17 changes: 0 additions & 17 deletions arbor/backends/gpu/fvm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,6 @@ struct backend {

using shared_state = arb::gpu::shared_state;
using ion_state = arb::gpu::ion_state;

// Construct a threshold_watcher wired to the index maps and time arrays of
// `state`, sized by the number of entries in state.voltage, for the detectors
// given by `detector_cv` and `thresholds`.
static threshold_watcher voltage_watcher(
    shared_state& state,
    const std::vector<index_type>& detector_cv,
    const std::vector<value_type>& thresholds,
    const execution_context& context)
{
    auto intdom_of_cv = state.cv_to_intdom.data();
    auto spike_of_src = state.src_to_spike.data();
    auto n_voltage    = state.voltage.size();

    return threshold_watcher(intdom_of_cv,
                             spike_of_src,
                             &state.time,
                             &state.time_to,
                             n_voltage,
                             detector_cv,
                             thresholds,
                             context);
}
};

} // namespace gpu
Expand Down
7 changes: 7 additions & 0 deletions arbor/backends/gpu/matrix_state_fine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,13 @@ struct matrix_state_fine {
packed_to_flat(rhs, to);
}


// Convenience overload: assemble the matrix from the supplied inputs and then
// run the single-argument solve on `voltage`.
void solve(array& voltage,
           const_view dt_intdom,
           const_view current,
           const_view conductivity)
{
    this->assemble(dt_intdom, voltage, current, conductivity);
    this->solve(voltage);
}

std::size_t size() const { return matrix_size; }

private:
Expand Down
106 changes: 30 additions & 76 deletions arbor/backends/gpu/shared_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,10 @@ std::pair<arb_value_type, arb_value_type> minmax_value_impl(arb_size_type n, con

// Ion state methods:

ion_state::ion_state(
int charge,
const fvm_ion_config& ion_data,
unsigned, // alignment/padding ignored.
solver_ptr ptr):
ion_state::ion_state(int charge,
const fvm_ion_config& ion_data,
unsigned, // alignment/padding ignored.
solver_ptr ptr):
write_eX_(ion_data.revpot_written),
write_Xo_(ion_data.econc_written),
write_Xi_(ion_data.iconc_written),
Expand Down Expand Up @@ -101,7 +100,7 @@ void ion_state::reset() {

// istim_state methods:

istim_state::istim_state(const fvm_stimulus_config& stim) {
istim_state::istim_state(const fvm_stimulus_config& stim, unsigned) {
using util::assign;

// Translate instance-to-CV index from stim to istim_state index vectors.
Expand Down Expand Up @@ -179,21 +178,19 @@ void istim_state::add_current(const array& time, const iarray& cv_to_intdom, arr

// Shared state methods:

shared_state::shared_state(
arb_size_type n_intdom,
arb_size_type n_cell,
arb_size_type n_detector,
const std::vector<arb_index_type>& cv_to_intdom_vec,
const std::vector<arb_index_type>& cv_to_cell_vec,
const std::vector<arb_value_type>& init_membrane_potential,
const std::vector<arb_value_type>& temperature_K,
const std::vector<arb_value_type>& diam,
const std::vector<arb_index_type>& src_to_spike,
unsigned, // alignment parameter ignored.
arb_seed_type cbprng_seed_
):
shared_state::shared_state(arb_size_type n_intdom,
arb_size_type n_cell,
const std::vector<arb_index_type>& cv_to_intdom_vec,
const std::vector<arb_index_type>& cv_to_cell_vec,
const std::vector<arb_value_type>& init_membrane_potential,
const std::vector<arb_value_type>& temperature_K,
const std::vector<arb_value_type>& diam,
const std::vector<arb_index_type>& src_to_spike,
const fvm_detector_info& detector,
unsigned, // alignment parameter ignored.
arb_seed_type cbprng_seed_):
n_intdom(n_intdom),
n_detector(n_detector),
n_detector(detector.count),
n_cv(cv_to_intdom_vec.size()),
cv_to_intdom(make_const_view(cv_to_intdom_vec)),
cv_to_cell(make_const_view(cv_to_cell_vec)),
Expand All @@ -210,6 +207,15 @@ shared_state::shared_state(
time_since_spike(n_cell*n_detector),
src_to_spike(make_const_view(src_to_spike)),
cbprng_seed(cbprng_seed_),
sample_events(n_intdom),
// The constructor parameter `src_to_spike` (a std::vector) hides the member of
// the same name inside this initializer list. Qualify with `this->` so the
// watcher receives the pointer of the member array, which was initialised
// earlier in this list, rather than a pointer into the caller's vector.
watcher{cv_to_intdom.data(),
        this->src_to_spike.data(),
        &time,
        &time_to,
        static_cast<arb_size_type>(voltage.size()),
        detector.cv,
        detector.threshold,
        detector.ctx},
deliverable_events(n_intdom)
{
memory::fill(time_since_spike, -1.0);
Expand All @@ -223,17 +229,6 @@ void shared_state::update_prng_state(mechanism& m) {
store.random_numbers_.update(m);
}

// Look up the state variable called `key` on mechanism `m` and return the
// pointer stored for it in the mechanism's backing store; returns nullptr when
// the mechanism has no state variable of that name.
// storage.at() throws if the mechanism has no entry in `storage`.
const arb_value_type* shared_state::mechanism_state_data(const mechanism& m, const std::string& key) {
    const auto& store = storage.at(m.mechanism_id());
    const auto n_vars = m.mech_.n_state_vars;

    for (arb_size_type var = 0; var<n_vars; ++var) {
        if (key!=m.mech_.state_vars[var].name) continue;
        return store.state_vars_[var];
    }
    return nullptr;
}

void shared_state::instantiate(mechanism& m,
unsigned id,
const mechanism_overrides& overrides,
Expand Down Expand Up @@ -380,39 +375,6 @@ void shared_state::instantiate(mechanism& m,
store.random_numbers_.instantiate(m, width_padded, pos_data, cbprng_seed);
}

// Integrate the membrane voltage by matrix solve: assemble the system from the
// current state, then solve it into `voltage`.
void shared_state::integrate_voltage()
{
    solver.assemble(dt_intdom, voltage, current_density, conductivity);
    solver.solve(voltage);
}

// For every ion that has a diffusion solver attached, assemble and solve the
// diffusion system for its Xd_ array. Ions without a solver are skipped.
void shared_state::integrate_diffusion() {
    for (auto& entry: ion_data) {
        auto& ion = entry.second;
        if (!ion.solver) continue;

        ion.solver->assemble(dt_intdom, ion.Xd_, voltage, ion.iX_, ion.gX_, ion.charge[0]);
        ion.solver->solve(ion.Xd_);
    }
}

// Register an ion under `ion_name`, constructing its ion_state in place inside
// ion_data from (charge, ion_info, 1u, ptr). Ownership of the solver pointer
// is moved into the new ion_state.
// If ion_data already has an entry for `ion_name`, emplace leaves it untouched.
void shared_state::add_ion(
const std::string& ion_name,
int charge,
const fvm_ion_config& ion_info,
ion_state::solver_ptr ptr) {
// piecewise_construct builds key and value directly from the argument tuples.
// The 1u fills ion_state's alignment/padding parameter, which the GPU
// ion_state constructor ignores.
ion_data.emplace(std::piecewise_construct,
std::forward_as_tuple(ion_name),
std::forward_as_tuple(charge, ion_info, 1u, std::move(ptr)));
}

// Replace the stimulus state with one freshly built from `stims`.
void shared_state::configure_stimulus(const fvm_stimulus_config& stims)
{
    stim_data = istim_state(stims);
}

void shared_state::reset() {
memory::copy(init_voltage, voltage);
memory::fill(current_density, 0);
Expand All @@ -436,12 +398,6 @@ void shared_state::zero_currents() {
stim_data.zero_current();
}

// Call init_concentration() on every ion held in ion_data.
void shared_state::ions_init_concentration() {
    for (auto& [name, ion]: ion_data) {
        ion.init_concentration();
    }
}

// Update time_to for all n_intdom integration domains from the current time,
// the step `dt_step` and the upper bound `tmax`; the work is delegated to
// update_time_to_impl.
void shared_state::update_time_to(arb_value_type dt_step, arb_value_type tmax)
{
    auto to   = time_to.data();
    auto from = time.data();
    update_time_to_impl(n_intdom, to, from, dt_step, tmax);
}
Expand All @@ -450,10 +406,6 @@ void shared_state::set_dt() {
set_dt_impl(n_intdom, n_cv, dt_intdom.data(), dt_cv.data(), time_to.data(), time.data(), cv_to_intdom.data());
}

// Let the stimulus state add its current contributions to current_density,
// based on the current time and the CV-to-intdom map.
void shared_state::add_stimulus_current()
{
    stim_data.add_current(time, cv_to_intdom, current_density);
}

std::pair<arb_value_type, arb_value_type> shared_state::time_bounds() const {
return minmax_value_impl(n_intdom, time.data());
}
Expand All @@ -462,8 +414,10 @@ std::pair<arb_value_type, arb_value_type> shared_state::voltage_bounds() const {
return minmax_value_impl(n_cv, voltage.data());
}

void shared_state::take_samples(const sample_event_stream::state& s, array& sample_time, array& sample_value) {
take_samples_impl(s, time.data(), sample_time.data(), sample_value.data());
// Mark the sample events due up to time_to, record samples for the marked
// events into sample_time / sample_value, then drop those events from the
// stream.
void shared_state::take_samples()
{
    sample_events.mark_until(time_to);

    const auto& marked = sample_events.marked_events();
    take_samples_impl(marked, time.data(), sample_time.data(), sample_value.data());

    sample_events.drop_marked_events();
}

// Debug interface
Expand Down
72 changes: 32 additions & 40 deletions arbor/backends/gpu/shared_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@

#include "fvm_layout.hpp"

#include "backends/common_types.hpp"
#include "backends/shared_state_base.hpp"
#include "backends/gpu/rand.hpp"
#include "backends/gpu/gpu_store_types.hpp"
#include "backends/gpu/stimulus.hpp"
#include "backends/gpu/diffusion_state.hpp"
#include "backends/gpu/matrix_state_fine.hpp"
#include "backends/gpu/threshold_watcher.hpp"

namespace arb {
namespace gpu {
Expand Down Expand Up @@ -107,13 +110,12 @@ struct ARB_ARBOR_API istim_state {
std::size_t size() const;

// Construct state from i_clamp data; references to shared state vectors are used to initialize ppack.
istim_state(const fvm_stimulus_config& stim_data);
istim_state(const fvm_stimulus_config& stim_data, unsigned);

istim_state() = default;
};

struct ARB_ARBOR_API shared_state {

struct ARB_ARBOR_API shared_state: shared_state_base<shared_state, array, ion_state> {
struct mech_storage {
array data_;
iarray indices_;
Expand Down Expand Up @@ -155,26 +157,33 @@ struct ARB_ARBOR_API shared_state {

arb_seed_type cbprng_seed; // random number generator seed

sample_event_stream sample_events;
array sample_time;
array sample_value;
threshold_watcher watcher;

// Host-side views/copies and local state.
memory::host_vector<arb_value_type> sample_time_host;
memory::host_vector<arb_value_type> sample_value_host;

istim_state stim_data;
std::unordered_map<std::string, ion_state> ion_data;
deliverable_event_stream deliverable_events;
std::unordered_map<unsigned, mech_storage> storage;

shared_state() = default;

shared_state(
arb_size_type n_intdom,
arb_size_type n_cell,
arb_size_type n_detector,
const std::vector<arb_index_type>& cv_to_intdom_vec,
const std::vector<arb_index_type>& cv_to_cell_vec,
const std::vector<arb_value_type>& init_membrane_potential,
const std::vector<arb_value_type>& temperature_K,
const std::vector<arb_value_type>& diam,
const std::vector<arb_index_type>& src_to_spike,
unsigned, // align parameter ignored
arb_seed_type cbprng_seed_ = 0u
);
shared_state(arb_size_type n_intdom,
arb_size_type n_cell,
const std::vector<arb_index_type>& cv_to_intdom_vec,
const std::vector<arb_index_type>& cv_to_cell_vec,
const std::vector<arb_value_type>& init_membrane_potential,
const std::vector<arb_value_type>& temperature_K,
const std::vector<arb_value_type>& diam,
const std::vector<arb_index_type>& src_to_spike,
const fvm_detector_info& detector,
unsigned, // align parameter ignored
arb_seed_type cbprng_seed_ = 0u);

// Setup a mechanism and tie its backing store to this object
void instantiate(mechanism&,
Expand All @@ -185,34 +194,14 @@ struct ARB_ARBOR_API shared_state {

void update_prng_state(mechanism&);

// Note: returned pointer points to device memory.
const arb_value_type* mechanism_state_data(const mechanism& m, const std::string& key);

void add_ion(
const std::string& ion_name,
int charge,
const fvm_ion_config& ion_data,
ion_state::solver_ptr solver=nullptr);

void configure_stimulus(const fvm_stimulus_config&);

void zero_currents();

void ions_init_concentration();

// Set time_to to earliest of time+dt_step and tmax.
void update_time_to(arb_value_type dt_step, arb_value_type tmax);

// Set the per-intdom and per-compartment dt from time_to - time.
void set_dt();

// Update stimulus state and add current contributions.
void add_stimulus_current();

// Integrate by matrix solve.
void integrate_voltage();
void integrate_diffusion();

// Return minimum and maximum time value [ms] across cells.
std::pair<arb_value_type, arb_value_type> time_bounds() const;

Expand All @@ -221,12 +210,15 @@ struct ARB_ARBOR_API shared_state {
std::pair<arb_value_type, arb_value_type> voltage_bounds() const;

// Mark sample events up to time_to, take samples for the marked events, then drop them.
void take_samples(
const sample_event_stream::state& s,
array& sample_time,
array& sample_value);
void take_samples();

// Reset internal state
void reset();

// Refresh the host-side sample buffers from sample_time and sample_value
// via memory::on_host.
void update_sample_views()
{
    sample_time_host  = memory::on_host(sample_time);
    sample_value_host = memory::on_host(sample_value);
}
};

// For debugging only
Expand Down
Loading