Skip to content

Commit

Permalink
Return a structured array for the spikes in python
Browse files Browse the repository at this point in the history
 - Modify Spike object from std::pair to struct
  • Loading branch information
jorblancoa committed Jul 26, 2023
1 parent c05694d commit 88a67a9
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 36 deletions.
10 changes: 9 additions & 1 deletion include/bbp/sonata/report_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,15 @@ struct SONATA_API DataFrame {
std::vector<float> data;
};

using Spike = std::pair<NodeID, double>;
typedef struct {
NodeID node_id;
double timestamp;
} Spike;

bool operator==(const Spike& lhs, const Spike& rhs) {
return lhs.node_id == rhs.node_id && lhs.timestamp == rhs.timestamp;
}

using Spikes = std::vector<Spike>;

/// Used to read spike files
Expand Down
24 changes: 18 additions & 6 deletions python/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1173,13 +1173,25 @@ PYBIND11_MODULE(_libsonata, m) {

bindStorageClass<EdgeStorage>(m, "EdgeStorage", "EdgePopulation");

PYBIND11_NUMPY_DTYPE(Spike, node_id, timestamp);

py::class_<SpikeReader::Population>(m, "SpikePopulation", "A population inside a SpikeReader")
.def("get",
&SpikeReader::Population::get,
DOC_SPIKEREADER_POP(get),
"node_ids"_a = nonstd::nullopt,
"tstart"_a = nonstd::nullopt,
"tstop"_a = nonstd::nullopt)
.def(
"get",
[](const SpikeReader::Population& self,
const py::object& node_ids = py::none(),
const py::object& tstart = py::none(),
const py::object& tstop = py::none()) {
auto spikes = self.get(
node_ids.is_none() ? nonstd::nullopt
: node_ids.cast<nonstd::optional<Selection>>(),
tstart.is_none() ? nonstd::nullopt : tstart.cast<nonstd::optional<double>>(),
tstop.is_none() ? nonstd::nullopt : tstop.cast<nonstd::optional<double>>());
return py::array_t<Spike>(spikes.size(), spikes.data());
},
"node_ids"_a = nonstd::nullopt,
"tstart"_a = nonstd::nullopt,
"tstop"_a = nonstd::nullopt)
.def_property_readonly(
"sorting",
[](const SpikeReader::Population& self) {
Expand Down
23 changes: 13 additions & 10 deletions python/tests/test_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,22 @@ def test_get_inexistant_population(self):
self.assertRaises(RuntimeError, self.test_obj.__getitem__, 'foobar')

def test_get_spikes_from_population(self):
self.assertEqual(self.test_obj['All'].get(), [(5, 0.1), (2, 0.2), (3, 0.3), (2, 0.7), (3, 1.3)])
self.assertEqual(self.test_obj['All'].get(tstart=0.2, tstop=1.0), [(2, 0.2), (3, 0.3), (2, 0.7)])
self.assertEqual(self.test_obj['spikes2'].get(tstart=0.2, tstop=1.0), [(3, 0.3), (2, 0.2), (2, 0.7)])
self.assertEqual(self.test_obj['spikes1'].get((3,)), [(3, 0.3), (3, 1.3)])
self.assertEqual(self.test_obj['spikes2'].get((3,)), [(3, 0.3), (3, 1.3)])
self.assertEqual(self.test_obj['spikes2'].get((10,)), [])
self.assertEqual(self.test_obj['spikes2'].get((2,), 0., 0.5), [(2, 0.2)])
self.assertEqual(self.test_obj['spikes1'].get((2, 5)), [(2, 0.2), (2, 0.7), (5, 0.1)])
self.assertEqual(self.test_obj['spikes2'].get((2, 5)), [(5, 0.1), (2, 0.2), (2, 0.7)])
# Define the structured data type
dtype = np.dtype([('node_id', '<u8'), ('timestamp', '<f8')])

self.assertTrue(np.array_equal(self.test_obj['All'].get(), np.array([(5, 0.1), (2, 0.2), (3, 0.3), (2, 0.7), (3, 1.3)], dtype=dtype)))
self.assertTrue(np.array_equal(self.test_obj['All'].get(tstart=0.2, tstop=1.0), np.array([(2, 0.2), (3, 0.3), (2, 0.7)], dtype=dtype)))
self.assertTrue(np.array_equal(self.test_obj['spikes2'].get(tstart=0.2, tstop=1.0), np.array([(3, 0.3), (2, 0.2), (2, 0.7)], dtype=dtype)))
self.assertTrue(np.array_equal(self.test_obj['spikes1'].get((3,)), np.array([(3, 0.3), (3, 1.3)], dtype=dtype)))
self.assertTrue(np.array_equal(self.test_obj['spikes2'].get((3,)), np.array([(3, 0.3), (3, 1.3)], dtype=dtype)))
self.assertTrue(np.array_equal(self.test_obj['spikes2'].get((10,)), np.array([], dtype=dtype)))
self.assertTrue(np.array_equal(self.test_obj['spikes2'].get((2,), 0., 0.5), np.array([(2, 0.2)], dtype=dtype)))
self.assertTrue(np.array_equal(self.test_obj['spikes1'].get((2, 5)), np.array([(2, 0.2), (2, 0.7), (5, 0.1)], dtype=dtype)))
self.assertTrue(np.array_equal(self.test_obj['spikes2'].get((2, 5)), np.array([(5, 0.1), (2, 0.2), (2, 0.7)], dtype=dtype)))
self.assertEqual(self.test_obj['All'].sorting, "by_time")
self.assertEqual(self.test_obj['spikes1'].sorting, "by_id")
self.assertEqual(self.test_obj['spikes2'].sorting, "none")
self.assertEqual(self.test_obj['empty'].get(), [])
self.assertTrue(np.array_equal(self.test_obj['empty'].get(), np.array([], dtype=dtype)))

self.assertEqual(len(self.test_obj['All'].get(node_ids=[])), 0)

Expand Down
28 changes: 14 additions & 14 deletions src/report_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ void filterNodeIDUnsorted(Spikes& spikes, const Selection& node_ids) {
const auto values = node_ids.flatten();
const auto new_end =
std::remove_if(spikes.begin(), spikes.end(), [&values](const Spike& spike) {
return std::find(values.cbegin(), values.cend(), spike.first) == values.cend();
return std::find(values.cbegin(), values.cend(), spike.node_id) == values.cend();
});
spikes.erase(new_end, spikes.end());
}
Expand All @@ -39,15 +39,15 @@ void filterNodeIDSorted(Spikes& spikes, const Selection& node_ids) {
for (const auto& range : node_ids.ranges()) {
const auto begin = std::lower_bound(spikes.begin(),
spikes.end(),
std::make_pair(range.first, 0.),
Spike{range.first, 0.},
[](const Spike& spike1, const Spike& spike2) {
return spike1.first < spike2.first;
return spike1.node_id < spike2.node_id;
});
const auto end = std::upper_bound(spikes.begin(),
spikes.end(),
std::make_pair(range.second - 1, 0.),
Spike{range.second - 1, 0.},
[](const Spike& spike1, const Spike& spike2) {
return spike1.first < spike2.first;
return spike1.node_id < spike2.node_id;
});

std::move(begin, end, std::back_inserter(_spikes));
Expand All @@ -59,24 +59,24 @@ void filterNodeIDSorted(Spikes& spikes, const Selection& node_ids) {
void filterTimestampUnsorted(Spikes& spikes, double tstart, double tstop) {
auto new_end =
std::remove_if(spikes.begin(), spikes.end(), [&tstart, &tstop](const Spike& spike) {
return (spike.second < tstart - EPSILON) || (spike.second > tstop + EPSILON);
return (spike.timestamp < tstart - EPSILON) || (spike.timestamp > tstop + EPSILON);
});
spikes.erase(new_end, spikes.end());
}

void filterTimestampSorted(Spikes& spikes, double tstart, double tstop) {
const auto end = std::upper_bound(spikes.begin(),
spikes.end(),
std::make_pair(0UL, tstop + EPSILON),
Spike{0UL, tstop + EPSILON},
[](const Spike& spike1, const Spike& spike2) {
return spike1.second < spike2.second;
return spike1.timestamp < spike2.timestamp;
});
spikes.erase(end, spikes.end());
const auto begin = std::lower_bound(spikes.begin(),
spikes.end(),
std::make_pair(0UL, tstart - EPSILON),
Spike{0UL, tstart - EPSILON},
[](const Spike& spike1, const Spike& spike2) {
return spike1.second < spike2.second;
return spike1.timestamp < spike2.timestamp;
});
spikes.erase(spikes.begin(), begin);
}
Expand Down Expand Up @@ -153,10 +153,10 @@ SpikeReader::Population::Population(const std::string& filename,
const auto pop_path = std::string("/spikes/") + populationName;
const auto pop = file.getGroup(pop_path);

std::vector<Spike::first_type> node_ids;
std::vector<NodeID> node_ids;
pop.getDataSet("node_ids").read(node_ids);

std::vector<Spike::second_type> timestamps;
std::vector<double> timestamps;
pop.getDataSet("timestamps").read(timestamps);

if (node_ids.size() != timestamps.size()) {
Expand All @@ -168,8 +168,8 @@ SpikeReader::Population::Population(const std::string& filename,
std::make_move_iterator(node_ids.end()),
std::make_move_iterator(timestamps.begin()),
std::back_inserter(spikes_),
[](Spike::first_type&& node_id, Spike::second_type&& timestamp) {
return std::make_pair(std::move(node_id), std::move(timestamp));
[](NodeID&& node_id, double&& timestamp) {
return Spike{std::move(node_id), std::move(timestamp)};
});

if (pop.hasAttribute("sorting")) {
Expand Down
10 changes: 5 additions & 5 deletions tests/test_report_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ TEST_CASE("SpikeReader", "[base]") {
std::vector<std::string>{"All", "empty", "spikes1", "spikes2"});

REQUIRE(reader.openPopulation("All").get(Selection({{3, 4}})) ==
std::vector<std::pair<uint64_t, double>>{{3UL, 0.3}, {3UL, 1.3}});
std::vector<Spike>{{3UL, 0.3}, {3UL, 1.3}});
REQUIRE(reader.openPopulation("spikes1").get(Selection({{3, 4}})) ==
std::vector<std::pair<uint64_t, double>>{{3UL, 0.3}, {3UL, 1.3}});
std::vector<Spike>{{3UL, 0.3}, {3UL, 1.3}});
REQUIRE(reader.openPopulation("spikes2").get(Selection({{3, 4}})) ==
std::vector<std::pair<uint64_t, double>>{{3UL, 0.3}, {3UL, 1.3}});
std::vector<Spike>{{3UL, 0.3}, {3UL, 1.3}});

REQUIRE(reader.openPopulation("All").getSorting() == SpikeReader::Population::Sorting::by_time);
REQUIRE(reader.openPopulation("spikes1").getSorting() ==
Expand All @@ -31,8 +31,8 @@ TEST_CASE("SpikeReader", "[base]") {
SpikeReader::Population::Sorting::none);

REQUIRE(reader.openPopulation("All").get(Selection({{5, 6}}), 0.1, 0.1) ==
std::vector<std::pair<uint64_t, double>>{{5, 0.1}});
REQUIRE(reader.openPopulation("empty").get() == std::vector<std::pair<uint64_t, double>>{});
std::vector<Spike>{{5, 0.1}});
REQUIRE(reader.openPopulation("empty").get() == std::vector<Spike>{});

REQUIRE(reader.openPopulation("All").getTimes() == std::make_tuple(0.1, 1.3));
}
Expand Down

0 comments on commit 88a67a9

Please sign in to comment.