Propagate changes from libsonata ElementReport #62

Merged · 10 commits · Jul 23, 2020
11 changes: 8 additions & 3 deletions bluepysnap/_plotting.py
@@ -23,6 +23,7 @@
from bluepysnap.sonata_constants import Node
from bluepysnap.utils import roundrobin


L = logging.getLogger(__name__)


@@ -359,9 +360,13 @@ def frame_trace(filtered_report, plot_type='mean', ax=None):  # pragma: no cover
elif plot_type == "all":
max_per_pop = 15
levels = filtered_report.report.columns.levels
slicer = tuple(slice(None) if i != len(levels) - 1 else slice(None, max_per_pop)
for i in range(len(levels)))
data = filtered_report.report.loc[:, slicer].T
slicer = []
# create a slicer that will slice only on the last level of the columns
# that is, node_id for the soma report, element_id for the compartment report
for i, _ in enumerate(levels):
max_ = levels[i][:max_per_pop][-1]
slicer.append(slice(None) if i != len(levels) - 1 else slice(None, max_))
data = filtered_report.report.loc[:, tuple(slicer)].T
# create [[(pop1, id1), (pop1, id2),...], [(pop2, id1), (pop2, id2),...]]
indexes = [[(pop, idx) for idx in data.loc[pop].index] for pop in levels[0]]
# try to keep the maximum of ids from each population
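Note: to make the new slicer in `frame_trace` easier to follow, here is a small standalone sketch (not part of the diff) of per-level slicing on pandas MultiIndex columns. The population names and the `max_per_pop` value are invented for illustration; the loop mirrors the one added above.

```python
import numpy as np
import pandas as pd

# Toy report: columns are (population, node_id) pairs, rows are timestamps.
columns = pd.MultiIndex.from_product([["popA", "popB"], range(5)])
report = pd.DataFrame(np.random.rand(3, 10), columns=columns)

max_per_pop = 2
levels = report.columns.levels

# Slice freely on every level except the last one (node_id here),
# which is capped at the label found max_per_pop entries in.
slicer = []
for i, _ in enumerate(levels):
    max_ = levels[i][:max_per_pop][-1]
    slicer.append(slice(None) if i != len(levels) - 1 else slice(None, max_))

subset = report.loc[:, tuple(slicer)]
print(subset.columns.tolist())  # [('popA', 0), ('popA', 1), ('popB', 0), ('popB', 1)]
```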
37 changes: 24 additions & 13 deletions bluepysnap/frame_report.py
@@ -21,15 +21,14 @@
from pathlib2 import Path
import numpy as np
import pandas as pd
from libsonata import ElementReportReader
from libsonata import ElementReportReader, SonataError

import bluepysnap._plotting
from bluepysnap.exceptions import BluepySnapError
from bluepysnap.utils import fix_libsonata_empty_list, ensure_list
from bluepysnap.utils import ensure_list

L = logging.getLogger(__name__)


FORMAT_TO_EXT = {"ASCII": ".txt", "HDF5": ".h5", "BIN": ".bbp"}


@@ -98,22 +97,36 @@ def get(self, group=None, t_start=None, t_stop=None):
Returns:
pandas.DataFrame: frame as columns indexed by timestamps.
"""
ids = [] if group is None else self._resolve(group).tolist()
t_start = -1 if t_start is None else t_start
t_stop = -1 if t_stop is None else t_stop
ids = self._resolve(group).tolist()
try:
view = self._frame_population.get(node_ids=ids, tstart=t_start, tstop=t_stop)
except SonataError as e:
raise BluepySnapError(e)

view = self._frame_population.get(node_ids=ids, tstart=t_start, tstop=t_stop)
if not view.data:
if len(view.ids) == 0:
return pd.DataFrame()
res = pd.DataFrame(data=view.data, index=view.index)

res = pd.DataFrame(data=view.data,
columns=pd.MultiIndex.from_arrays(np.asarray(view.ids).T),
index=view.times).sort_index(axis=1)
Review comment (Contributor): view.times are always guaranteed to be present?

Reply (Author): Yes, the view always contains the three objects (data, ids and times).


# rename from multi index to index cannot be achieved easily through df.rename
res.columns = self._wrap_columns(res.columns)
res.sort_index(inplace=True)
return res

@cached_property
def node_ids(self):
"""Returns the node ids present in the report.

Returns:
np.Array: Numpy array containing the node_ids included in the report
"""
return np.sort(np.asarray(self._frame_population.get_node_ids(), dtype=np.int64))


class FilteredFrameReport(object):
"""Access to filtered FrameReport data."""

def __init__(self, frame_report, group=None, t_start=None, t_stop=None):
"""Initialize a FilteredFrameReport.

@@ -237,7 +250,7 @@ def simulation(self):
@cached_property
def population_names(self):
"""Returns the population names included in this report."""
return sorted(self._frame_reader.get_populations_names())
return sorted(self._frame_reader.get_population_names())

@cached_property
def _population_report(self):
@@ -282,8 +295,6 @@ def nodes(self):

def _resolve(self, group):
"""Transform a group into a node_id array."""
if isinstance(group, (np.ndarray, list, tuple)) and len(group) == 0:
return fix_libsonata_empty_list()
return self.nodes.ids(group=group)


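Note: the rewritten `PopulationFrameReport.get` assembles its result directly from the three arrays the libsonata view exposes (`data`, `ids`, `times`) and wraps any `SonataError` into `BluepySnapError`. Below is a minimal sketch of the DataFrame construction, with hand-made arrays standing in for an actual `ElementReportReader` view; the values and shapes are assumptions for illustration only.

```python
import numpy as np
import pandas as pd

# Stand-ins for view.times, view.ids and view.data:
# one row per timestamp, one (node_id, element_id) pair per data column.
times = [0.0, 0.1, 0.2]
ids = np.array([[0, 0], [0, 1], [1, 0]])              # (node_id, element_id) pairs
data = np.arange(9, dtype=np.float32).reshape(3, 3)   # 3 timestamps x 3 columns

res = pd.DataFrame(data=data,
                   columns=pd.MultiIndex.from_arrays(ids.T),
                   index=times).sort_index(axis=1)

# Columns end up as (node_id, element_id) pairs; rows are the report timestamps.
# In the real code this construction sits inside a try/except that re-raises
# libsonata's SonataError as BluepySnapError.
print(res)
```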
2 changes: 1 addition & 1 deletion bluepysnap/nodes.py
@@ -505,7 +505,7 @@ def positions(self, group=None):
return result.astype(float)

def orientations(self, group=None):
"""Node orientation(s) as a pandas Series or DataFrame.
"""Node orientation(s) as a pandas numpy array or pandas Series.

Args:
group (int/sequence/str/mapping/None): Which nodes will have their positions
29 changes: 19 additions & 10 deletions bluepysnap/spike_report.py
@@ -22,14 +22,13 @@
from cached_property import cached_property
import pandas as pd
import numpy as np
from libsonata import SpikeReader, SonataError

from bluepysnap.exceptions import BluepySnapError
from bluepysnap.utils import fix_libsonata_empty_list
import bluepysnap._plotting


def _get_reader(spike_report):
from libsonata import SpikeReader
path = str(Path(spike_report.config["output_dir"]) / spike_report.config["spikes_file"])
return SpikeReader(path)

@@ -86,8 +85,6 @@ def nodes(self):

def _resolve_nodes(self, group):
"""Transform a node group into a node_id array."""
if isinstance(group, (np.ndarray, list, tuple)) and len(group) == 0:
return fix_libsonata_empty_list()
return self.nodes.ids(group=group)

def get(self, group=None, t_start=None, t_stop=None):
@@ -101,23 +98,35 @@ def get(self, group=None, t_start=None, t_stop=None):
Returns:
pandas.Series: return spiking node_ids indexed by sorted spike time.
"""
node_ids = [] if group is None else self._resolve_nodes(group).tolist()
node_ids = self._resolve_nodes(group).tolist()

t_start = -1 if t_start is None else t_start
t_stop = -1 if t_stop is None else t_stop
series_name = "ids"
res = self._spike_population.get(node_ids=node_ids, tstart=t_start, tstop=t_stop)
try:
res = self._spike_population.get(node_ids=node_ids, tstart=t_start, tstop=t_stop)
except SonataError as e:
raise BluepySnapError(e)

if not res:
return pd.Series(data=[], index=pd.Index([], name="times"), name=series_name)

res = pd.DataFrame(data=res, columns=[series_name, "times"]).set_index("times")[series_name]
if self._sorted_by != "by_time":
res.sort_index(inplace=True)
return res
return res.astype(np.int64)

@cached_property
def node_ids(self):
"""Returns the node ids present in the report.

Returns:
np.Array: Numpy array containing the node_ids included in the report
"""
return np.unique(self.get())


class FilteredSpikeReport(object):
"""Access to filtered SpikeReport data."""

def __init__(self, spike_report, group=None, t_start=None, t_stop=None):
"""Initialize a FilteredSpikeReport.

@@ -227,7 +236,7 @@ def _spike_reader(self):
@cached_property
def population_names(self):
"""Returns the population names included in this report."""
return sorted(self._spike_reader.get_populations_names())
return sorted(self._spike_reader.get_population_names())

@cached_property
def _population(self):
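Note: the spike-side `get` follows the same pattern: the (node_id, time) pairs returned by libsonata become a Series of ids indexed by time, sorted when the file is not already sorted by time, and cast to int64. A small sketch with a hand-made result list standing in for `self._spike_population.get(...)`; the values are invented for illustration.

```python
import numpy as np
import pandas as pd

# Stand-in for the (node_id, time) pairs returned by SpikePopulation.get(...).
raw = [(2, 0.1), (0, 0.2), (1, 0.3), (0, 1.3)]

series_name = "ids"
res = pd.DataFrame(data=raw, columns=[series_name, "times"]).set_index("times")[series_name]
res.sort_index(inplace=True)  # only needed when the report is not sorted "by_time"
res = res.astype(np.int64)

print(res)
# times
# 0.1    2
# 0.2    0
# 0.3    1
# 1.3    0
# Name: ids, dtype: int64
```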
8 changes: 0 additions & 8 deletions bluepysnap/utils.py
@@ -72,14 +72,6 @@ def roundrobin(*iterables):
nexts = itertools.cycle(itertools.islice(nexts, num_active))


def fix_libsonata_empty_list():
"""Temporary solution to return empty list from libsonata report readers `.get` functions.

see: https://github.com/BlueBrain/libsonata/issues/84
"""
return np.array([-2])


def add_dynamic_prefix(properties):
"""Add the dynamic prefix to a list of properties."""
return [DYNAMICS_PREFIX + name for name in list(properties)]
2 changes: 1 addition & 1 deletion setup.py
@@ -44,7 +44,7 @@ def __init__(self, *args, **kwargs):
'cached_property>=1.0',
'functools32;python_version<"3.2"',
'h5py>=2.2',
'libsonata>=0.1.3',
'libsonata>=0.1.4',
'neurom>=1.3',
'numpy>=1.8',
'pandas>=0.17',
Binary file modified tests/data/reporting/compartment_named.h5
7 changes: 4 additions & 3 deletions tests/data/reporting/create_reports.py
@@ -64,15 +64,16 @@ def write_element_report(filepath):
population_names = ['default', 'default2']
node_ids = np.arange(0, 3)
index_pointers = np.arange(0, 8, 2)
element_ids = np.array([0, 1] * 3)
index_pointers[-1] = index_pointers[-1] + 1
element_ids = np.array([0, 1] * 3 + [1])

times = (0.0, 1, 0.1)

string_dtype = h5py.special_dtype(vlen=str)
with h5py.File(filepath, 'w') as h5f:
h5f.create_group('report')
gpop_element = h5f.create_group('/report/' + population_names[0])
d1 = np.array([np.arange(6) + j*0.1 for j in range(10)])
d1 = np.array([np.arange(7) + j*0.1 for j in range(10)])
ddata = gpop_element.create_dataset('data', data=d1, dtype=np.float32)
ddata.attrs.create('units', data="mV", dtype=string_dtype)
gmapping = h5f.create_group('/report/' + population_names[0] + '/mapping')
@@ -85,7 +86,7 @@
dtimes.attrs.create('units', data="ms", dtype=string_dtype)

gpop_element2 = h5f.create_group('/report/' + population_names[1])
d1 = np.array([np.arange(6) + j * 0.1 for j in range(10)])
d1 = np.array([np.arange(7) + j * 0.1 for j in range(10)])
ddata = gpop_element2.create_dataset('data', data=d1, dtype=np.float32)
ddata.attrs.create('units', data="mR", dtype=string_dtype)
gmapping = h5f.create_group('/report/' + population_names[1] + '/mapping')
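Note: the fixture change above is easier to read with the SONATA element-report mapping in mind: `index_pointers[i]:index_pointers[i + 1]` selects the slice of `element_ids` (and of the data columns) belonging to `node_ids[i]`. The sketch below reuses the fixture's own arrays to show why node 2 now owns three columns, one with a repeated element id, which is also why the data arrays grew from 6 to 7 columns.

```python
import numpy as np

node_ids = np.arange(0, 3)                    # [0, 1, 2]
index_pointers = np.arange(0, 8, 2)           # [0, 2, 4, 6]
index_pointers[-1] = index_pointers[-1] + 1   # node 2 now spans 3 columns -> [0, 2, 4, 7]
element_ids = np.array([0, 1] * 3 + [1])      # [0, 1, 0, 1, 0, 1, 1]

columns = []
for node, start, stop in zip(node_ids, index_pointers[:-1], index_pointers[1:]):
    columns.extend((int(node), int(elem)) for elem in element_ids[start:stop])

print(columns)
# [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (2, 1)]
```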
Binary file modified tests/data/reporting/soma_report.h5
Binary file modified tests/data/reporting/spikes.h5
35 changes: 25 additions & 10 deletions tests/test_frame_report.py
@@ -102,7 +102,10 @@ def test_filter(self):
assert filtered.report.columns.tolist() == [("default2", 1, 0), ("default2", 1, 1)]

filtered = self.test_obj.filter(group={"population": "default2"}, t_start=0.3, t_stop=0.6)
assert filtered.report.columns.tolist() == [("default2", 0, 0), ("default2", 0, 1), ("default2", 1, 0), ("default2", 1, 1), ("default2", 2, 0), ("default2", 2, 1)]
assert filtered.report.columns.tolist() == [("default2", 0, 0), ("default2", 0, 1),
("default2", 1, 0), ("default2", 1, 1),
("default2", 2, 0), ("default2", 2, 1),
("default2", 2, 1)]

filtered = self.test_obj.filter(group={"population": "default3"}, t_start=0.3, t_stop=0.6)
pdt.assert_frame_equal(filtered.report, pd.DataFrame())
@@ -163,11 +166,9 @@ def setup(self):
self.simulation = Simulation(str(TEST_DATA_DIR / 'simulation_config.json'))
self.test_obj = test_module.CompartmentReport(self.simulation, "section_report")["default"]
timestamps = np.linspace(0, 0.9, 10)
data = np.array([np.arange(6) + j * 0.1 for j in range(10)])

data = {(0, 0): data[:, 0], (0, 1): data[:, 1], (1, 0): data[:, 2], (1, 1): data[:, 3],
(2, 0): data[:, 4], (2, 1): data[:, 5]}
self.df = pd.DataFrame(data=data, index=timestamps)
data = np.array([np.arange(7) + j * 0.1 for j in range(10)], dtype=np.float32)
ids = [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (2, 1)]
self.df = pd.DataFrame(data=data, columns=pd.MultiIndex.from_tuples(ids), index=timestamps)

def test__resolve(self):
npt.assert_array_equal(self.test_obj._resolve({Cell.MTYPE: "L6_Y"}), [1, 2])
@@ -213,8 +214,6 @@ def test_get(self):
pdt.assert_frame_equal(
self.test_obj.get([2, 1], t_start=0.2, t_stop=0.8), self.df.iloc[2:9].loc[:, [1, 2]])

pdt.assert_frame_equal(self.test_obj.get([0, 2], t_start=15), pd.DataFrame())

pdt.assert_frame_equal(
self.test_obj.get(group={Cell.MTYPE: "L6_Y"}, t_start=0.2, t_stop=0.8),
self.df.iloc[2:9].loc[:, [1, 2]])
@@ -225,12 +224,28 @@
pdt.assert_frame_equal(
self.test_obj.get(group="Layer23"), self.df.loc[:, [0]])

with pytest.raises(BluepySnapError):
self.test_obj.get(-1, t_start=0.2)

with pytest.raises(BluepySnapError):
self.test_obj.get(0, t_start=-1)

with pytest.raises(BluepySnapError):
self.test_obj.get([0, 2], t_start=15)

with pytest.raises(BluepySnapError):
self.test_obj.get(4)

def test_get_partially_not_in_report(self):
with patch.object(self.test_obj.__class__, "_resolve", return_value=np.asarray([0, 4])):
pdt.assert_frame_equal(self.test_obj.get([0, 4]), self.df.loc[:, [0]])

def test_get_not_in_report(self):
with patch.object(self.test_obj.__class__, "_resolve", return_value=np.asarray([4])):
pdt.assert_frame_equal(self.test_obj.get(4), pd.DataFrame())
pdt.assert_frame_equal(self.test_obj.get([4]), pd.DataFrame())

def test_node_ids(self):
npt.assert_array_equal(self.test_obj.node_ids, np.array(sorted([0, 1, 2]), dtype=np.int64))


class TestPopulationSomaReport(TestPopulationCompartmentReport):
@@ -239,4 +254,4 @@ def setup(self):
self.test_obj = test_module.SomaReport(self.simulation, "soma_report")["default"]
timestamps = np.linspace(0, 0.9, 10)
data = {0: timestamps, 1: timestamps + 1, 2: timestamps + 2}
self.df = pd.DataFrame(data=data, index=timestamps, columns=[0, 1, 2])
self.df = pd.DataFrame(data=data, index=timestamps, columns=[0, 1, 2]).astype(np.float32)
1 change: 0 additions & 1 deletion tests/test_morph.py
@@ -10,7 +10,6 @@

import bluepysnap.morph as test_module
from bluepysnap.circuit import Circuit
from bluepysnap.nodes import NodeStorage
from bluepysnap.sonata_constants import Node
from bluepysnap.exceptions import BluepySnapError

27 changes: 23 additions & 4 deletions tests/test_spike_report.py
@@ -17,6 +17,7 @@
def _create_series(node_ids, index, name="ids"):
def _get_index(ids):
return pd.Index(ids, name="times")

return pd.Series(node_ids, index=_get_index(index), name=name)


@@ -132,8 +133,10 @@ def test_get(self):
pdt.assert_series_equal(self.test_obj.get(2), _create_series([2, 2], [0.1, 0.7]))
pdt.assert_series_equal(self.test_obj.get(0, t_start=1.), _create_series([0], [1.3]))
pdt.assert_series_equal(self.test_obj.get(0, t_stop=1.), _create_series([0], [0.2]))
pdt.assert_series_equal(self.test_obj.get(0, t_start=1., t_stop=12), _create_series([0], [1.3]))
pdt.assert_series_equal(self.test_obj.get(0, t_start=0.1, t_stop=12), _create_series([0, 0], [0.2, 1.3]))
pdt.assert_series_equal(self.test_obj.get(0, t_start=1., t_stop=12),
_create_series([0], [1.3]))
pdt.assert_series_equal(self.test_obj.get(0, t_start=0.1, t_stop=12),
_create_series([0, 0], [0.2, 1.3]))

pdt.assert_series_equal(self.test_obj.get([2, 0]),
_create_series([2, 0, 2, 0], [0.1, 0.2, 0.7, 1.3]))
@@ -157,8 +160,6 @@ def test_get(self):
pdt.assert_series_equal(self.test_obj.get([2, 1], t_start=0.2, t_stop=0.8),
_create_series([1, 2], [0.3, 0.7]))

pdt.assert_series_equal(self.test_obj.get([0, 2], t_start=12), _create_series([], []))

pdt.assert_series_equal(
self.test_obj.get(group={Cell.MTYPE: "L6_Y"}, t_start=0.2, t_stop=0.8),
_create_series([1, 2], [0.3, 0.7]))
Expand All @@ -169,6 +170,15 @@ def test_get(self):
pdt.assert_series_equal(
self.test_obj.get(group="Layer23"), _create_series([0, 0], [0.2, 1.3]))

with pytest.raises(BluepySnapError):
self.test_obj.get([-1], t_start=0.2)

with pytest.raises(BluepySnapError):
self.test_obj.get([0, 2], t_start=-1)

with pytest.raises(BluepySnapError):
self.test_obj.get([0, 2], t_start=12)

with pytest.raises(BluepySnapError):
self.test_obj.get(4)

@@ -188,3 +198,12 @@ def test_get2(self):
def test_get_not_in_report(self, mock):
pdt.assert_series_equal(self.test_obj.get(4),
_create_series([], []))

@patch(test_module.__name__ + '.PopulationSpikeReport._resolve_nodes',
return_value=np.asarray([0, 4]))
def test_get_partially_not_in_report(self, mock):
pdt.assert_series_equal(self.test_obj.get([0, 4]),
_create_series([0, 0], [0.2, 1.3]))

def test_node_ids(self):
npt.assert_array_equal(self.test_obj.node_ids, np.array(sorted([0, 1, 2]), dtype=np.int64))
4 changes: 0 additions & 4 deletions tests/test_utils.py
@@ -34,10 +34,6 @@ def test_roundrobin():
assert list(test_module.roundrobin(*a)) == [1, 4, 5, 2, 6, 3]


def test_fix_libsonata_empty_list():
npt.assert_array_equal(test_module.fix_libsonata_empty_list(), np.array([-2]))


def test_add_dynamic_prefix():
assert test_module.add_dynamic_prefix(["a", "b"]) == [DYNAMICS_PREFIX + "a",
DYNAMICS_PREFIX + "b"]