From 1619edd243d5edd6544a17206aa2f997604b7382 Mon Sep 17 00:00:00 2001 From: tomdele <42337286+tomdele@users.noreply.github.com> Date: Thu, 23 Jul 2020 14:48:01 +0200 Subject: [PATCH] Propagate changes from libsonata ElementReport (#62) * Propagate changes in libsonata ElementReport * Change tests data for the repeated element ids in frame reports * Remove libsonata hidden API and add Sonata error * Adding node_ids to the api, node_ids returned as int64 * Fix the slicer for the trace plotting. * Remove the fix_libsonata_empty_list --- bluepysnap/_plotting.py | 11 +++++-- bluepysnap/frame_report.py | 37 ++++++++++++++-------- bluepysnap/nodes.py | 2 +- bluepysnap/spike_report.py | 29 +++++++++++------ bluepysnap/utils.py | 8 ----- setup.py | 2 +- tests/data/reporting/compartment_named.h5 | Bin 15216 -> 15216 bytes tests/data/reporting/create_reports.py | 7 ++-- tests/data/reporting/soma_report.h5 | Bin 15216 -> 15216 bytes tests/data/reporting/spikes.h5 | Bin 7648 -> 7648 bytes tests/test_frame_report.py | 35 ++++++++++++++------ tests/test_morph.py | 1 - tests/test_spike_report.py | 27 +++++++++++++--- tests/test_utils.py | 4 --- 14 files changed, 105 insertions(+), 58 deletions(-) diff --git a/bluepysnap/_plotting.py b/bluepysnap/_plotting.py index 5bb5fc22..a5e0c0b3 100644 --- a/bluepysnap/_plotting.py +++ b/bluepysnap/_plotting.py @@ -23,6 +23,7 @@ from bluepysnap.sonata_constants import Node from bluepysnap.utils import roundrobin + L = logging.getLogger(__name__) @@ -359,9 +360,13 @@ def frame_trace(filtered_report, plot_type='mean', ax=None): # pragma: no cover elif plot_type == "all": max_per_pop = 15 levels = filtered_report.report.columns.levels - slicer = tuple(slice(None) if i != len(levels) - 1 else slice(None, max_per_pop) - for i in range(len(levels))) - data = filtered_report.report.loc[:, slicer].T + slicer = [] + # create a slicer that will slice only on the last level of the columns + # that is, node_id for the soma report, element_id for the compartment report + for i, _ in enumerate(levels): + max_ = levels[i][:max_per_pop][-1] + slicer.append(slice(None) if i != len(levels) - 1 else slice(None, max_)) + data = filtered_report.report.loc[:, tuple(slicer)].T # create [[(pop1, id1), (pop1, id2),...], [(pop2, id1), (pop2, id2),...]] indexes = [[(pop, idx) for idx in data.loc[pop].index] for pop in levels[0]] # try to keep the maximum of ids from each population diff --git a/bluepysnap/frame_report.py b/bluepysnap/frame_report.py index 35f1743f..6d40a044 100644 --- a/bluepysnap/frame_report.py +++ b/bluepysnap/frame_report.py @@ -21,15 +21,14 @@ from pathlib2 import Path import numpy as np import pandas as pd -from libsonata import ElementReportReader +from libsonata import ElementReportReader, SonataError import bluepysnap._plotting from bluepysnap.exceptions import BluepySnapError -from bluepysnap.utils import fix_libsonata_empty_list, ensure_list +from bluepysnap.utils import ensure_list L = logging.getLogger(__name__) - FORMAT_TO_EXT = {"ASCII": ".txt", "HDF5": ".h5", "BIN": ".bbp"} @@ -98,22 +97,36 @@ def get(self, group=None, t_start=None, t_stop=None): Returns: pandas.DataFrame: frame as columns indexed by timestamps. """ - ids = [] if group is None else self._resolve(group).tolist() - t_start = -1 if t_start is None else t_start - t_stop = -1 if t_stop is None else t_stop + ids = self._resolve(group).tolist() + try: + view = self._frame_population.get(node_ids=ids, tstart=t_start, tstop=t_stop) + except SonataError as e: + raise BluepySnapError(e) - view = self._frame_population.get(node_ids=ids, tstart=t_start, tstop=t_stop) - if not view.data: + if len(view.ids) == 0: return pd.DataFrame() - res = pd.DataFrame(data=view.data, index=view.index) + + res = pd.DataFrame(data=view.data, + columns=pd.MultiIndex.from_arrays(np.asarray(view.ids).T), + index=view.times).sort_index(axis=1) + # rename from multi index to index cannot be achieved easily through df.rename res.columns = self._wrap_columns(res.columns) - res.sort_index(inplace=True) return res + @cached_property + def node_ids(self): + """Returns the node ids present in the report. + + Returns: + np.Array: Numpy array containing the node_ids included in the report + """ + return np.sort(np.asarray(self._frame_population.get_node_ids(), dtype=np.int64)) + class FilteredFrameReport(object): """Access to filtered FrameReport data.""" + def __init__(self, frame_report, group=None, t_start=None, t_stop=None): """Initialize a FilteredFrameReport. @@ -237,7 +250,7 @@ def simulation(self): @cached_property def population_names(self): """Returns the population names included in this report.""" - return sorted(self._frame_reader.get_populations_names()) + return sorted(self._frame_reader.get_population_names()) @cached_property def _population_report(self): @@ -282,8 +295,6 @@ def nodes(self): def _resolve(self, group): """Transform a group into a node_id array.""" - if isinstance(group, (np.ndarray, list, tuple)) and len(group) == 0: - return fix_libsonata_empty_list() return self.nodes.ids(group=group) diff --git a/bluepysnap/nodes.py b/bluepysnap/nodes.py index 5dfb8f9b..10401866 100644 --- a/bluepysnap/nodes.py +++ b/bluepysnap/nodes.py @@ -505,7 +505,7 @@ def positions(self, group=None): return result.astype(float) def orientations(self, group=None): - """Node orientation(s) as a pandas Series or DataFrame. + """Node orientation(s) as a pandas numpy array or pandas Series. Args: group (int/sequence/str/mapping/None): Which nodes will have their positions diff --git a/bluepysnap/spike_report.py b/bluepysnap/spike_report.py index 71a35871..165e4edc 100644 --- a/bluepysnap/spike_report.py +++ b/bluepysnap/spike_report.py @@ -22,14 +22,13 @@ from cached_property import cached_property import pandas as pd import numpy as np +from libsonata import SpikeReader, SonataError from bluepysnap.exceptions import BluepySnapError -from bluepysnap.utils import fix_libsonata_empty_list import bluepysnap._plotting def _get_reader(spike_report): - from libsonata import SpikeReader path = str(Path(spike_report.config["output_dir"]) / spike_report.config["spikes_file"]) return SpikeReader(path) @@ -86,8 +85,6 @@ def nodes(self): def _resolve_nodes(self, group): """Transform a node group into a node_id array.""" - if isinstance(group, (np.ndarray, list, tuple)) and len(group) == 0: - return fix_libsonata_empty_list() return self.nodes.ids(group=group) def get(self, group=None, t_start=None, t_stop=None): @@ -101,23 +98,35 @@ def get(self, group=None, t_start=None, t_stop=None): Returns: pandas.Series: return spiking node_ids indexed by sorted spike time. """ - node_ids = [] if group is None else self._resolve_nodes(group).tolist() + node_ids = self._resolve_nodes(group).tolist() - t_start = -1 if t_start is None else t_start - t_stop = -1 if t_stop is None else t_stop series_name = "ids" - res = self._spike_population.get(node_ids=node_ids, tstart=t_start, tstop=t_stop) + try: + res = self._spike_population.get(node_ids=node_ids, tstart=t_start, tstop=t_stop) + except SonataError as e: + raise BluepySnapError(e) + if not res: return pd.Series(data=[], index=pd.Index([], name="times"), name=series_name) res = pd.DataFrame(data=res, columns=[series_name, "times"]).set_index("times")[series_name] if self._sorted_by != "by_time": res.sort_index(inplace=True) - return res + return res.astype(np.int64) + + @cached_property + def node_ids(self): + """Returns the node ids present in the report. + + Returns: + np.Array: Numpy array containing the node_ids included in the report + """ + return np.unique(self.get()) class FilteredSpikeReport(object): """Access to filtered SpikeReport data.""" + def __init__(self, spike_report, group=None, t_start=None, t_stop=None): """Initialize a FilteredSpikeReport. @@ -227,7 +236,7 @@ def _spike_reader(self): @cached_property def population_names(self): """Returns the population names included in this report.""" - return sorted(self._spike_reader.get_populations_names()) + return sorted(self._spike_reader.get_population_names()) @cached_property def _population(self): diff --git a/bluepysnap/utils.py b/bluepysnap/utils.py index c3de98c8..1f6fa155 100644 --- a/bluepysnap/utils.py +++ b/bluepysnap/utils.py @@ -72,14 +72,6 @@ def roundrobin(*iterables): nexts = itertools.cycle(itertools.islice(nexts, num_active)) -def fix_libsonata_empty_list(): - """Temporary solution to return empty list from libsonata report readers `.get` functions. - - see: https://github.com/BlueBrain/libsonata/issues/84 - """ - return np.array([-2]) - - def add_dynamic_prefix(properties): """Add the dynamic prefix to a list of properties.""" return [DYNAMICS_PREFIX + name for name in list(properties)] diff --git a/setup.py b/setup.py index 55fb7650..d9bb179e 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,7 @@ def __init__(self, *args, **kwargs): 'cached_property>=1.0', 'functools32;python_version<"3.2"', 'h5py>=2.2', - 'libsonata>=0.1.3', + 'libsonata>=0.1.4', 'neurom>=1.3', 'numpy>=1.8', 'pandas>=0.17', diff --git a/tests/data/reporting/compartment_named.h5 b/tests/data/reporting/compartment_named.h5 index adfab9ffea0f649c72625d7a765631a9a8f8a019..1354e61e9bd8bb85d855003502af8aaf2c3cd7e5 100644 GIT binary patch delta 487 zcmexR_MvP;1h*hN0|anEY4*v6hK^hgJP@t~<76Ib31+>GFD9So?qp_QI52rVuOzdv z@!`q&Jd(_5X~!nt2l8jlJUMwjkbm~fnaTWol0bDACZFe#X1qMv9>_~eyEgehPN(vx7($3!}m0 z)e?#zFKs?AVa&wnF_~FfksZj5V_;~QxUqHf2I&m~oUl+~n_MU%Jh?_kjWGkH4N2YP z0NpIc2Dn-Gbsd>;gyjrg%&>Ibe8(`4g{gyY@)`*Ruw(MgjG34g@J(isRzS98fs_KK WW0($r^&zVRx`ycn+_cH>tr`J$36aGB diff --git a/tests/data/reporting/create_reports.py b/tests/data/reporting/create_reports.py index bf9d9e28..e10961b9 100644 --- a/tests/data/reporting/create_reports.py +++ b/tests/data/reporting/create_reports.py @@ -64,7 +64,8 @@ def write_element_report(filepath): population_names = ['default', 'default2'] node_ids = np.arange(0, 3) index_pointers = np.arange(0, 8, 2) - element_ids = np.array([0, 1] * 3) + index_pointers[-1] = index_pointers[-1] + 1 + element_ids = np.array([0, 1] * 3 + [1]) times = (0.0, 1, 0.1) @@ -72,7 +73,7 @@ def write_element_report(filepath): with h5py.File(filepath, 'w') as h5f: h5f.create_group('report') gpop_element = h5f.create_group('/report/' + population_names[0]) - d1 = np.array([np.arange(6) + j*0.1 for j in range(10)]) + d1 = np.array([np.arange(7) + j*0.1 for j in range(10)]) ddata = gpop_element.create_dataset('data', data=d1, dtype=np.float32) ddata.attrs.create('units', data="mV", dtype=string_dtype) gmapping = h5f.create_group('/report/' + population_names[0] + '/mapping') @@ -85,7 +86,7 @@ def write_element_report(filepath): dtimes.attrs.create('units', data="ms", dtype=string_dtype) gpop_element2 = h5f.create_group('/report/' + population_names[1]) - d1 = np.array([np.arange(6) + j * 0.1 for j in range(10)]) + d1 = np.array([np.arange(7) + j * 0.1 for j in range(10)]) ddata = gpop_element2.create_dataset('data', data=d1, dtype=np.float32) ddata.attrs.create('units', data="mR", dtype=string_dtype) gmapping = h5f.create_group('/report/' + population_names[1] + '/mapping') diff --git a/tests/data/reporting/soma_report.h5 b/tests/data/reporting/soma_report.h5 index 750be1a9002c1275c3c8143283d87e6f9be7b116..17ccc6d0be9868b9639da4129b050bb0dd0db7b6 100644 GIT binary patch delta 125 zcmexR_MvRU8E$61jW0Hz=l;tBVMv5Bu>(co7#JERZfxDWL0Uu*i9b0&w+F25zV23L k5M%OrZr9Cs3@5OFMf1&2bWaYjY(dtXUpF delta 125 zcmexR_MvRU8E$6&J?WdzbN}UmFeE~m*nuK(3=9nuH@0rxAT1(@#Gf3X+XGg2Uw11r kh%xy*x9jFRh7(x8qWNYhx+e!%wjgUxumWqI{NAb&05AzKk^lez diff --git a/tests/data/reporting/spikes.h5 b/tests/data/reporting/spikes.h5 index 966f78a2647c11d89ab93f78bdd667874a1e90bd..a644b0989b4abe396a4eb5046447c9dee798b4d7 100644 GIT binary patch delta 50 ycmaE0{lI#I4-d26#uuCYd3?FRjQL`zEFi{Yf1bw47eoX$&ye232o{(x)(HR#=oE(l delta 50 ycmaE0{lI#I4-Yf{p7hQBJic6D#(c3<77$~yKTqT23nBuWXGrg11Pja;>jVJN0}?X; diff --git a/tests/test_frame_report.py b/tests/test_frame_report.py index 0b37879c..c17afd68 100644 --- a/tests/test_frame_report.py +++ b/tests/test_frame_report.py @@ -102,7 +102,10 @@ def test_filter(self): assert filtered.report.columns.tolist() == [("default2", 1, 0), ("default2", 1, 1)] filtered = self.test_obj.filter(group={"population": "default2"}, t_start=0.3, t_stop=0.6) - assert filtered.report.columns.tolist() == [("default2", 0, 0), ("default2", 0, 1), ("default2", 1, 0), ("default2", 1, 1), ("default2", 2, 0), ("default2", 2, 1)] + assert filtered.report.columns.tolist() == [("default2", 0, 0), ("default2", 0, 1), + ("default2", 1, 0), ("default2", 1, 1), + ("default2", 2, 0), ("default2", 2, 1), + ("default2", 2, 1)] filtered = self.test_obj.filter(group={"population": "default3"}, t_start=0.3, t_stop=0.6) pdt.assert_frame_equal(filtered.report, pd.DataFrame()) @@ -163,11 +166,9 @@ def setup(self): self.simulation = Simulation(str(TEST_DATA_DIR / 'simulation_config.json')) self.test_obj = test_module.CompartmentReport(self.simulation, "section_report")["default"] timestamps = np.linspace(0, 0.9, 10) - data = np.array([np.arange(6) + j * 0.1 for j in range(10)]) - - data = {(0, 0): data[:, 0], (0, 1): data[:, 1], (1, 0): data[:, 2], (1, 1): data[:, 3], - (2, 0): data[:, 4], (2, 1): data[:, 5]} - self.df = pd.DataFrame(data=data, index=timestamps) + data = np.array([np.arange(7) + j * 0.1 for j in range(10)], dtype=np.float32) + ids = [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (2, 1)] + self.df = pd.DataFrame(data=data, columns=pd.MultiIndex.from_tuples(ids), index=timestamps) def test__resolve(self): npt.assert_array_equal(self.test_obj._resolve({Cell.MTYPE: "L6_Y"}), [1, 2]) @@ -213,8 +214,6 @@ def test_get(self): pdt.assert_frame_equal( self.test_obj.get([2, 1], t_start=0.2, t_stop=0.8), self.df.iloc[2:9].loc[:, [1, 2]]) - pdt.assert_frame_equal(self.test_obj.get([0, 2], t_start=15), pd.DataFrame()) - pdt.assert_frame_equal( self.test_obj.get(group={Cell.MTYPE: "L6_Y"}, t_start=0.2, t_stop=0.8), self.df.iloc[2:9].loc[:, [1, 2]]) @@ -225,12 +224,28 @@ def test_get(self): pdt.assert_frame_equal( self.test_obj.get(group="Layer23"), self.df.loc[:, [0]]) + with pytest.raises(BluepySnapError): + self.test_obj.get(-1, t_start=0.2) + + with pytest.raises(BluepySnapError): + self.test_obj.get(0, t_start=-1) + + with pytest.raises(BluepySnapError): + self.test_obj.get([0, 2], t_start=15) + with pytest.raises(BluepySnapError): self.test_obj.get(4) + def test_get_partially_not_in_report(self): + with patch.object(self.test_obj.__class__, "_resolve", return_value=np.asarray([0, 4])): + pdt.assert_frame_equal(self.test_obj.get([0, 4]), self.df.loc[:, [0]]) + def test_get_not_in_report(self): with patch.object(self.test_obj.__class__, "_resolve", return_value=np.asarray([4])): - pdt.assert_frame_equal(self.test_obj.get(4), pd.DataFrame()) + pdt.assert_frame_equal(self.test_obj.get([4]), pd.DataFrame()) + + def test_node_ids(self): + npt.assert_array_equal(self.test_obj.node_ids, np.array(sorted([0, 1, 2]), dtype=np.int64)) class TestPopulationSomaReport(TestPopulationCompartmentReport): @@ -239,4 +254,4 @@ def setup(self): self.test_obj = test_module.SomaReport(self.simulation, "soma_report")["default"] timestamps = np.linspace(0, 0.9, 10) data = {0: timestamps, 1: timestamps + 1, 2: timestamps + 2} - self.df = pd.DataFrame(data=data, index=timestamps, columns=[0, 1, 2]) + self.df = pd.DataFrame(data=data, index=timestamps, columns=[0, 1, 2]).astype(np.float32) diff --git a/tests/test_morph.py b/tests/test_morph.py index a88b8672..a18ce01d 100644 --- a/tests/test_morph.py +++ b/tests/test_morph.py @@ -10,7 +10,6 @@ import bluepysnap.morph as test_module from bluepysnap.circuit import Circuit -from bluepysnap.nodes import NodeStorage from bluepysnap.sonata_constants import Node from bluepysnap.exceptions import BluepySnapError diff --git a/tests/test_spike_report.py b/tests/test_spike_report.py index b4985241..982c3897 100644 --- a/tests/test_spike_report.py +++ b/tests/test_spike_report.py @@ -17,6 +17,7 @@ def _create_series(node_ids, index, name="ids"): def _get_index(ids): return pd.Index(ids, name="times") + return pd.Series(node_ids, index=_get_index(index), name=name) @@ -132,8 +133,10 @@ def test_get(self): pdt.assert_series_equal(self.test_obj.get(2), _create_series([2, 2], [0.1, 0.7])) pdt.assert_series_equal(self.test_obj.get(0, t_start=1.), _create_series([0], [1.3])) pdt.assert_series_equal(self.test_obj.get(0, t_stop=1.), _create_series([0], [0.2])) - pdt.assert_series_equal(self.test_obj.get(0, t_start=1., t_stop=12), _create_series([0], [1.3])) - pdt.assert_series_equal(self.test_obj.get(0, t_start=0.1, t_stop=12), _create_series([0, 0], [0.2, 1.3])) + pdt.assert_series_equal(self.test_obj.get(0, t_start=1., t_stop=12), + _create_series([0], [1.3])) + pdt.assert_series_equal(self.test_obj.get(0, t_start=0.1, t_stop=12), + _create_series([0, 0], [0.2, 1.3])) pdt.assert_series_equal(self.test_obj.get([2, 0]), _create_series([2, 0, 2, 0], [0.1, 0.2, 0.7, 1.3])) @@ -157,8 +160,6 @@ def test_get(self): pdt.assert_series_equal(self.test_obj.get([2, 1], t_start=0.2, t_stop=0.8), _create_series([1, 2], [0.3, 0.7])) - pdt.assert_series_equal(self.test_obj.get([0, 2], t_start=12), _create_series([], [])) - pdt.assert_series_equal( self.test_obj.get(group={Cell.MTYPE: "L6_Y"}, t_start=0.2, t_stop=0.8), _create_series([1, 2], [0.3, 0.7])) @@ -169,6 +170,15 @@ def test_get(self): pdt.assert_series_equal( self.test_obj.get(group="Layer23"), _create_series([0, 0], [0.2, 1.3])) + with pytest.raises(BluepySnapError): + self.test_obj.get([-1], t_start=0.2) + + with pytest.raises(BluepySnapError): + self.test_obj.get([0, 2], t_start=-1) + + with pytest.raises(BluepySnapError): + self.test_obj.get([0, 2], t_start=12) + with pytest.raises(BluepySnapError): self.test_obj.get(4) @@ -188,3 +198,12 @@ def test_get2(self): def test_get_not_in_report(self, mock): pdt.assert_series_equal(self.test_obj.get(4), _create_series([], [])) + + @patch(test_module.__name__ + '.PopulationSpikeReport._resolve_nodes', + return_value=np.asarray([0, 4])) + def test_get_not_in_report(self, mock): + pdt.assert_series_equal(self.test_obj.get([0, 4]), + _create_series([0, 0], [0.2, 1.3])) + + def test_node_ids(self): + npt.assert_array_equal(self.test_obj.node_ids, np.array(sorted([0, 1, 2]), dtype=np.int64)) diff --git a/tests/test_utils.py b/tests/test_utils.py index 35a27e7b..744d7548 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -34,10 +34,6 @@ def test_roundrobin(): assert list(test_module.roundrobin(*a)) == [1, 4, 5, 2, 6, 3] -def test_fix_libsonata_empty_list(): - npt.assert_array_equal(test_module.fix_libsonata_empty_list(), np.array([-2])) - - def test_add_dynamic_prefix(): assert test_module.add_dynamic_prefix(["a", "b"]) == [DYNAMICS_PREFIX + "a", DYNAMICS_PREFIX + "b"]