Propagate changes from libsonata ElementReport (#62)
* Propagate changes in libsonata ElementReport
* Change test data to include repeated element ids in frame reports
* Remove libsonata hidden API usage and add SonataError handling
* Add node_ids to the API; node_ids are returned as int64
* Fix the slicer for trace plotting
* Remove the fix_libsonata_empty_list workaround
tomdele authored Jul 23, 2020
1 parent 639ac58 commit 1619edd
Showing 14 changed files with 105 additions and 58 deletions.
11 changes: 8 additions & 3 deletions bluepysnap/_plotting.py
@@ -23,6 +23,7 @@
from bluepysnap.sonata_constants import Node
from bluepysnap.utils import roundrobin


L = logging.getLogger(__name__)


@@ -359,9 +360,13 @@ def frame_trace(filtered_report, plot_type='mean', ax=None):  # pragma: no cover
elif plot_type == "all":
max_per_pop = 15
levels = filtered_report.report.columns.levels
slicer = tuple(slice(None) if i != len(levels) - 1 else slice(None, max_per_pop)
for i in range(len(levels)))
data = filtered_report.report.loc[:, slicer].T
slicer = []
# create a slicer that will slice only on the last level of the columns
# that is, node_id for the soma report, element_id for the compartment report
for i, _ in enumerate(levels):
max_ = levels[i][:max_per_pop][-1]
slicer.append(slice(None) if i != len(levels) - 1 else slice(None, max_))
data = filtered_report.report.loc[:, tuple(slicer)].T
# create [[(pop1, id1), (pop1, id2),...], [(pop2, id1), (pop2, id2),...]]
indexes = [[(pop, idx) for idx in data.loc[pop].index] for pop in levels[0]]
# try to keep as many ids as possible from each population
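Note on the new slicer: it builds one slice per column level and caps only the last level (node_id for soma reports, element_id for compartment reports). A minimal, self-contained sketch of the mechanism follows; the toy report and the max_per_pop value are invented for illustration and are not part of this commit.

import pandas as pd

# Toy report with MultiIndex columns (population, node_id, element_id).
columns = pd.MultiIndex.from_tuples(
    [("default", 0, 0), ("default", 0, 1), ("default", 1, 0), ("default", 1, 1),
     ("default2", 0, 0), ("default2", 0, 1)])
report = pd.DataFrame([range(6), range(6)], columns=columns).sort_index(axis=1)

max_per_pop = 1
levels = report.columns.levels
slicer = []
for i, _ in enumerate(levels):
    # label reached after max_per_pop entries of this level
    max_ = levels[i][:max_per_pop][-1]
    # keep everything on the outer levels, cap only the last one
    slicer.append(slice(None) if i != len(levels) - 1 else slice(None, max_))
print(report.loc[:, tuple(slicer)])  # keeps only element_id <= 0 for each (population, node_id)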
37 changes: 24 additions & 13 deletions bluepysnap/frame_report.py
@@ -21,15 +21,14 @@
from pathlib2 import Path
import numpy as np
import pandas as pd
from libsonata import ElementReportReader
from libsonata import ElementReportReader, SonataError

import bluepysnap._plotting
from bluepysnap.exceptions import BluepySnapError
from bluepysnap.utils import fix_libsonata_empty_list, ensure_list
from bluepysnap.utils import ensure_list

L = logging.getLogger(__name__)


FORMAT_TO_EXT = {"ASCII": ".txt", "HDF5": ".h5", "BIN": ".bbp"}


@@ -98,22 +97,36 @@ def get(self, group=None, t_start=None, t_stop=None):
Returns:
pandas.DataFrame: frame as columns indexed by timestamps.
"""
ids = [] if group is None else self._resolve(group).tolist()
t_start = -1 if t_start is None else t_start
t_stop = -1 if t_stop is None else t_stop
ids = self._resolve(group).tolist()
try:
view = self._frame_population.get(node_ids=ids, tstart=t_start, tstop=t_stop)
except SonataError as e:
raise BluepySnapError(e)

view = self._frame_population.get(node_ids=ids, tstart=t_start, tstop=t_stop)
if not view.data:
if len(view.ids) == 0:
return pd.DataFrame()
res = pd.DataFrame(data=view.data, index=view.index)

res = pd.DataFrame(data=view.data,
columns=pd.MultiIndex.from_arrays(np.asarray(view.ids).T),
index=view.times).sort_index(axis=1)

# renaming from a MultiIndex to a flat index cannot easily be done through df.rename
res.columns = self._wrap_columns(res.columns)
res.sort_index(inplace=True)
return res

@cached_property
def node_ids(self):
"""Returns the node ids present in the report.
Returns:
np.ndarray: numpy array containing the node_ids included in the report.
"""
return np.sort(np.asarray(self._frame_population.get_node_ids(), dtype=np.int64))


class FilteredFrameReport(object):
"""Access to filtered FrameReport data."""

def __init__(self, frame_report, group=None, t_start=None, t_stop=None):
"""Initialize a FilteredFrameReport.
@@ -237,7 +250,7 @@ def simulation(self):
@cached_property
def population_names(self):
"""Returns the population names included in this report."""
return sorted(self._frame_reader.get_populations_names())
return sorted(self._frame_reader.get_population_names())

@cached_property
def _population_report(self):
@@ -282,8 +295,6 @@ def nodes(self):

def _resolve(self, group):
"""Transform a group into a node_id array."""
if isinstance(group, (np.ndarray, list, tuple)) and len(group) == 0:
return fix_libsonata_empty_list()
return self.nodes.ids(group=group)


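For context on the rewritten get(): the libsonata ElementReportReader view exposes data, ids ((node_id, element_id) pairs) and times, and the DataFrame is rebuilt from those arrays with a sorted MultiIndex over the columns. A rough sketch with the view replaced by plain numpy arrays; the values below are invented, only the pandas calls mirror the code above.

import numpy as np
import pandas as pd

# Stand-ins for view.ids, view.data and view.times (invented values).
ids = np.array([[0, 0], [0, 1], [1, 0]])   # (node_id, element_id) pairs
data = np.array([[0.0, 0.1, 0.2],
                 [1.0, 1.1, 1.2]])         # one row per timestamp
times = np.array([0.0, 0.5])

res = pd.DataFrame(data=data,
                   columns=pd.MultiIndex.from_arrays(np.asarray(ids).T),
                   index=times).sort_index(axis=1)
print(res)

# node_ids property: the reader's ids, sorted and cast to int64 (stand-in list).
node_ids = np.sort(np.asarray([1, 0, 2], dtype=np.int64))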
2 changes: 1 addition & 1 deletion bluepysnap/nodes.py
@@ -505,7 +505,7 @@ def positions(self, group=None):
return result.astype(float)

def orientations(self, group=None):
"""Node orientation(s) as a pandas Series or DataFrame.
"""Node orientation(s) as a pandas numpy array or pandas Series.
Args:
group (int/sequence/str/mapping/None): Which nodes will have their positions
29 changes: 19 additions & 10 deletions bluepysnap/spike_report.py
@@ -22,14 +22,13 @@
from cached_property import cached_property
import pandas as pd
import numpy as np
from libsonata import SpikeReader, SonataError

from bluepysnap.exceptions import BluepySnapError
from bluepysnap.utils import fix_libsonata_empty_list
import bluepysnap._plotting


def _get_reader(spike_report):
from libsonata import SpikeReader
path = str(Path(spike_report.config["output_dir"]) / spike_report.config["spikes_file"])
return SpikeReader(path)

@@ -86,8 +85,6 @@ def nodes(self):

def _resolve_nodes(self, group):
"""Transform a node group into a node_id array."""
if isinstance(group, (np.ndarray, list, tuple)) and len(group) == 0:
return fix_libsonata_empty_list()
return self.nodes.ids(group=group)

def get(self, group=None, t_start=None, t_stop=None):
@@ -101,23 +98,35 @@ def get(self, group=None, t_start=None, t_stop=None):
Returns:
pandas.Series: return spiking node_ids indexed by sorted spike time.
"""
node_ids = [] if group is None else self._resolve_nodes(group).tolist()
node_ids = self._resolve_nodes(group).tolist()

t_start = -1 if t_start is None else t_start
t_stop = -1 if t_stop is None else t_stop
series_name = "ids"
res = self._spike_population.get(node_ids=node_ids, tstart=t_start, tstop=t_stop)
try:
res = self._spike_population.get(node_ids=node_ids, tstart=t_start, tstop=t_stop)
except SonataError as e:
raise BluepySnapError(e)

if not res:
return pd.Series(data=[], index=pd.Index([], name="times"), name=series_name)

res = pd.DataFrame(data=res, columns=[series_name, "times"]).set_index("times")[series_name]
if self._sorted_by != "by_time":
res.sort_index(inplace=True)
return res
return res.astype(np.int64)

@cached_property
def node_ids(self):
"""Returns the node ids present in the report.
Returns:
np.ndarray: numpy array containing the node_ids included in the report.
"""
return np.unique(self.get())


class FilteredSpikeReport(object):
"""Access to filtered SpikeReport data."""

def __init__(self, spike_report, group=None, t_start=None, t_stop=None):
"""Initialize a FilteredSpikeReport.
@@ -227,7 +236,7 @@ def _spike_reader(self):
@cached_property
def population_names(self):
"""Returns the population names included in this report."""
return sorted(self._spike_reader.get_populations_names())
return sorted(self._spike_reader.get_population_names())

@cached_property
def _population(self):
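The spike get() now builds the Series from (node_id, time) pairs and casts the ids to int64, and node_ids is simply the unique spiking ids. A short sketch of that wrapping, with the libsonata result replaced by a hand-written list of pairs (values invented):

import numpy as np
import pandas as pd

# Stand-in for the (node_id, time) pairs returned by the spike population reader.
res = [(2, 0.1), (0, 0.2), (1, 0.3), (2, 0.7), (0, 1.3)]

series = (pd.DataFrame(data=res, columns=["ids", "times"])
          .set_index("times")["ids"]
          .astype(np.int64))
node_ids = np.unique(series)  # array([0, 1, 2]): the ids present in the report
print(series)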
8 changes: 0 additions & 8 deletions bluepysnap/utils.py
@@ -72,14 +72,6 @@ def roundrobin(*iterables):
nexts = itertools.cycle(itertools.islice(nexts, num_active))


def fix_libsonata_empty_list():
"""Temporary solution to return empty list from libsonata report readers `.get` functions.
see: https://github.com/BlueBrain/libsonata/issues/84
"""
return np.array([-2])


def add_dynamic_prefix(properties):
"""Add the dynamic prefix to a list of properties."""
return [DYNAMICS_PREFIX + name for name in list(properties)]
2 changes: 1 addition & 1 deletion setup.py
@@ -44,7 +44,7 @@ def __init__(self, *args, **kwargs):
'cached_property>=1.0',
'functools32;python_version<"3.2"',
'h5py>=2.2',
'libsonata>=0.1.3',
'libsonata>=0.1.4',
'neurom>=1.3',
'numpy>=1.8',
'pandas>=0.17',
Binary file modified tests/data/reporting/compartment_named.h5
7 changes: 4 additions & 3 deletions tests/data/reporting/create_reports.py
@@ -64,15 +64,16 @@ def write_element_report(filepath):
population_names = ['default', 'default2']
node_ids = np.arange(0, 3)
index_pointers = np.arange(0, 8, 2)
element_ids = np.array([0, 1] * 3)
index_pointers[-1] = index_pointers[-1] + 1
element_ids = np.array([0, 1] * 3 + [1])

times = (0.0, 1, 0.1)

string_dtype = h5py.special_dtype(vlen=str)
with h5py.File(filepath, 'w') as h5f:
h5f.create_group('report')
gpop_element = h5f.create_group('/report/' + population_names[0])
d1 = np.array([np.arange(6) + j*0.1 for j in range(10)])
d1 = np.array([np.arange(7) + j*0.1 for j in range(10)])
ddata = gpop_element.create_dataset('data', data=d1, dtype=np.float32)
ddata.attrs.create('units', data="mV", dtype=string_dtype)
gmapping = h5f.create_group('/report/' + population_names[0] + '/mapping')
@@ -85,7 +86,7 @@
dtimes.attrs.create('units', data="ms", dtype=string_dtype)

gpop_element2 = h5f.create_group('/report/' + population_names[1])
d1 = np.array([np.arange(6) + j * 0.1 for j in range(10)])
d1 = np.array([np.arange(7) + j * 0.1 for j in range(10)])
ddata = gpop_element2.create_dataset('data', data=d1, dtype=np.float32)
ddata.attrs.create('units', data="mR", dtype=string_dtype)
gmapping = h5f.create_group('/report/' + population_names[1] + '/mapping')
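The test-data change gives node 2 a repeated element id: node i owns element_ids[index_pointers[i]:index_pointers[i + 1]], so bumping the last pointer and appending an extra 1 makes node 2 map to [0, 1, 1] (and data gains a seventh column). A quick check of that arithmetic:

import numpy as np

node_ids = np.arange(0, 3)                # [0, 1, 2]
index_pointers = np.arange(0, 8, 2)       # [0, 2, 4, 6]
index_pointers[-1] += 1                   # [0, 2, 4, 7] -> node 2 now owns three elements
element_ids = np.array([0, 1] * 3 + [1])  # [0, 1, 0, 1, 0, 1, 1]

for node, start, stop in zip(node_ids, index_pointers[:-1], index_pointers[1:]):
    print(node, element_ids[start:stop])
# 0 [0 1]
# 1 [0 1]
# 2 [0 1 1]   <- the repeated element id exercised by the updated tests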
Binary file modified tests/data/reporting/soma_report.h5
Binary file modified tests/data/reporting/spikes.h5
35 changes: 25 additions & 10 deletions tests/test_frame_report.py
@@ -102,7 +102,10 @@ def test_filter(self):
assert filtered.report.columns.tolist() == [("default2", 1, 0), ("default2", 1, 1)]

filtered = self.test_obj.filter(group={"population": "default2"}, t_start=0.3, t_stop=0.6)
assert filtered.report.columns.tolist() == [("default2", 0, 0), ("default2", 0, 1), ("default2", 1, 0), ("default2", 1, 1), ("default2", 2, 0), ("default2", 2, 1)]
assert filtered.report.columns.tolist() == [("default2", 0, 0), ("default2", 0, 1),
("default2", 1, 0), ("default2", 1, 1),
("default2", 2, 0), ("default2", 2, 1),
("default2", 2, 1)]

filtered = self.test_obj.filter(group={"population": "default3"}, t_start=0.3, t_stop=0.6)
pdt.assert_frame_equal(filtered.report, pd.DataFrame())
@@ -163,11 +166,9 @@ def setup(self):
self.simulation = Simulation(str(TEST_DATA_DIR / 'simulation_config.json'))
self.test_obj = test_module.CompartmentReport(self.simulation, "section_report")["default"]
timestamps = np.linspace(0, 0.9, 10)
data = np.array([np.arange(6) + j * 0.1 for j in range(10)])

data = {(0, 0): data[:, 0], (0, 1): data[:, 1], (1, 0): data[:, 2], (1, 1): data[:, 3],
(2, 0): data[:, 4], (2, 1): data[:, 5]}
self.df = pd.DataFrame(data=data, index=timestamps)
data = np.array([np.arange(7) + j * 0.1 for j in range(10)], dtype=np.float32)
ids = [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (2, 1)]
self.df = pd.DataFrame(data=data, columns=pd.MultiIndex.from_tuples(ids), index=timestamps)

def test__resolve(self):
npt.assert_array_equal(self.test_obj._resolve({Cell.MTYPE: "L6_Y"}), [1, 2])
@@ -213,8 +214,6 @@ def test_get(self):
pdt.assert_frame_equal(
self.test_obj.get([2, 1], t_start=0.2, t_stop=0.8), self.df.iloc[2:9].loc[:, [1, 2]])

pdt.assert_frame_equal(self.test_obj.get([0, 2], t_start=15), pd.DataFrame())

pdt.assert_frame_equal(
self.test_obj.get(group={Cell.MTYPE: "L6_Y"}, t_start=0.2, t_stop=0.8),
self.df.iloc[2:9].loc[:, [1, 2]])
@@ -225,12 +224,28 @@
pdt.assert_frame_equal(
self.test_obj.get(group="Layer23"), self.df.loc[:, [0]])

with pytest.raises(BluepySnapError):
self.test_obj.get(-1, t_start=0.2)

with pytest.raises(BluepySnapError):
self.test_obj.get(0, t_start=-1)

with pytest.raises(BluepySnapError):
self.test_obj.get([0, 2], t_start=15)

with pytest.raises(BluepySnapError):
self.test_obj.get(4)

def test_get_partially_not_in_report(self):
with patch.object(self.test_obj.__class__, "_resolve", return_value=np.asarray([0, 4])):
pdt.assert_frame_equal(self.test_obj.get([0, 4]), self.df.loc[:, [0]])

def test_get_not_in_report(self):
with patch.object(self.test_obj.__class__, "_resolve", return_value=np.asarray([4])):
pdt.assert_frame_equal(self.test_obj.get(4), pd.DataFrame())
pdt.assert_frame_equal(self.test_obj.get([4]), pd.DataFrame())

def test_node_ids(self):
npt.assert_array_equal(self.test_obj.node_ids, np.array(sorted([0, 1, 2]), dtype=np.int64))


class TestPopulationSomaReport(TestPopulationCompartmentReport):
Expand All @@ -239,4 +254,4 @@ def setup(self):
self.test_obj = test_module.SomaReport(self.simulation, "soma_report")["default"]
timestamps = np.linspace(0, 0.9, 10)
data = {0: timestamps, 1: timestamps + 1, 2: timestamps + 2}
self.df = pd.DataFrame(data=data, index=timestamps, columns=[0, 1, 2])
self.df = pd.DataFrame(data=data, index=timestamps, columns=[0, 1, 2]).astype(np.float32)
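The "not in report" tests above rely on patching _resolve so it returns ids the report does not contain; get() must then drop them (or return an empty frame) instead of raising. A stripped-down illustration of that patching pattern, using a made-up FakeReport class rather than the real report objects:

import numpy as np
from unittest.mock import patch

class FakeReport:
    """Toy stand-in: only ids 0-2 exist in the 'report'."""
    def _resolve(self, group):
        return np.asarray(group)

    def get(self, group):
        known = {0, 1, 2}
        return [int(i) for i in self._resolve(group) if int(i) in known]

report = FakeReport()
with patch.object(FakeReport, "_resolve", return_value=np.asarray([0, 4])):
    print(report.get([0, 4]))  # -> [0]; the unknown id 4 is dropped, no exception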
1 change: 0 additions & 1 deletion tests/test_morph.py
@@ -10,7 +10,6 @@

import bluepysnap.morph as test_module
from bluepysnap.circuit import Circuit
from bluepysnap.nodes import NodeStorage
from bluepysnap.sonata_constants import Node
from bluepysnap.exceptions import BluepySnapError

27 changes: 23 additions & 4 deletions tests/test_spike_report.py
@@ -17,6 +17,7 @@
def _create_series(node_ids, index, name="ids"):
def _get_index(ids):
return pd.Index(ids, name="times")

return pd.Series(node_ids, index=_get_index(index), name=name)


@@ -132,8 +133,10 @@ def test_get(self):
pdt.assert_series_equal(self.test_obj.get(2), _create_series([2, 2], [0.1, 0.7]))
pdt.assert_series_equal(self.test_obj.get(0, t_start=1.), _create_series([0], [1.3]))
pdt.assert_series_equal(self.test_obj.get(0, t_stop=1.), _create_series([0], [0.2]))
pdt.assert_series_equal(self.test_obj.get(0, t_start=1., t_stop=12), _create_series([0], [1.3]))
pdt.assert_series_equal(self.test_obj.get(0, t_start=0.1, t_stop=12), _create_series([0, 0], [0.2, 1.3]))
pdt.assert_series_equal(self.test_obj.get(0, t_start=1., t_stop=12),
_create_series([0], [1.3]))
pdt.assert_series_equal(self.test_obj.get(0, t_start=0.1, t_stop=12),
_create_series([0, 0], [0.2, 1.3]))

pdt.assert_series_equal(self.test_obj.get([2, 0]),
_create_series([2, 0, 2, 0], [0.1, 0.2, 0.7, 1.3]))
@@ -157,8 +160,6 @@ def test_get(self):
pdt.assert_series_equal(self.test_obj.get([2, 1], t_start=0.2, t_stop=0.8),
_create_series([1, 2], [0.3, 0.7]))

pdt.assert_series_equal(self.test_obj.get([0, 2], t_start=12), _create_series([], []))

pdt.assert_series_equal(
self.test_obj.get(group={Cell.MTYPE: "L6_Y"}, t_start=0.2, t_stop=0.8),
_create_series([1, 2], [0.3, 0.7]))
@@ -169,6 +170,15 @@
pdt.assert_series_equal(
self.test_obj.get(group="Layer23"), _create_series([0, 0], [0.2, 1.3]))

with pytest.raises(BluepySnapError):
self.test_obj.get([-1], t_start=0.2)

with pytest.raises(BluepySnapError):
self.test_obj.get([0, 2], t_start=-1)

with pytest.raises(BluepySnapError):
self.test_obj.get([0, 2], t_start=12)

with pytest.raises(BluepySnapError):
self.test_obj.get(4)

@@ -188,3 +198,12 @@ def test_get2(self):
def test_get_not_in_report(self, mock):
pdt.assert_series_equal(self.test_obj.get(4),
_create_series([], []))

@patch(test_module.__name__ + '.PopulationSpikeReport._resolve_nodes',
return_value=np.asarray([0, 4]))
def test_get_partially_not_in_report(self, mock):
pdt.assert_series_equal(self.test_obj.get([0, 4]),
_create_series([0, 0], [0.2, 1.3]))

def test_node_ids(self):
npt.assert_array_equal(self.test_obj.node_ids, np.array(sorted([0, 1, 2]), dtype=np.int64))
4 changes: 0 additions & 4 deletions tests/test_utils.py
@@ -34,10 +34,6 @@ def test_roundrobin():
assert list(test_module.roundrobin(*a)) == [1, 4, 5, 2, 6, 3]


def test_fix_libsonata_empty_list():
npt.assert_array_equal(test_module.fix_libsonata_empty_list(), np.array([-2]))


def test_add_dynamic_prefix():
assert test_module.add_dynamic_prefix(["a", "b"]) == [DYNAMICS_PREFIX + "a",
DYNAMICS_PREFIX + "b"]
