Skip to content

Commit

Permalink
Remove libsonata hidden API and add Sonata error
Browse files Browse the repository at this point in the history
  • Loading branch information
tomdele committed Jul 7, 2020
1 parent 7870430 commit 0dd7b00
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 14 deletions.
12 changes: 7 additions & 5 deletions bluepysnap/frame_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pathlib2 import Path
import numpy as np
import pandas as pd
from libsonata import ElementReportReader
from libsonata import ElementReportReader, SonataError

import bluepysnap._plotting
from bluepysnap.exceptions import BluepySnapError
Expand Down Expand Up @@ -98,11 +98,13 @@ def get(self, group=None, t_start=None, t_stop=None):
Returns:
pandas.DataFrame: frame as columns indexed by timestamps.
"""
ids = [] if group is None else self._resolve(group).tolist()
t_start = -1 if t_start is None else t_start
t_stop = -1 if t_stop is None else t_stop
ids = self._resolve(group).tolist()

try:
view = self._frame_population.get(node_ids=ids, tstart=t_start, tstop=t_stop)
except SonataError as e:
raise BluepySnapError(e)

view = self._frame_population.get(node_ids=ids, tstart=t_start, tstop=t_stop)
if len(view.ids) == 0:
return pd.DataFrame()
res = pd.DataFrame(data=view.data,
Expand Down
12 changes: 7 additions & 5 deletions bluepysnap/spike_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@
from cached_property import cached_property
import pandas as pd
import numpy as np
from libsonata import SpikeReader, SonataError

from bluepysnap.exceptions import BluepySnapError
from bluepysnap.utils import fix_libsonata_empty_list
import bluepysnap._plotting


def _get_reader(spike_report):
from libsonata import SpikeReader
path = str(Path(spike_report.config["output_dir"]) / spike_report.config["spikes_file"])
return SpikeReader(path)

Expand Down Expand Up @@ -101,12 +101,14 @@ def get(self, group=None, t_start=None, t_stop=None):
Returns:
pandas.Series: return spiking node_ids indexed by sorted spike time.
"""
node_ids = [] if group is None else self._resolve_nodes(group).tolist()
node_ids = self._resolve_nodes(group).tolist()

t_start = -1 if t_start is None else t_start
t_stop = -1 if t_stop is None else t_stop
series_name = "ids"
res = self._spike_population.get(node_ids=node_ids, tstart=t_start, tstop=t_stop)
try:
res = self._spike_population.get(node_ids=node_ids, tstart=t_start, tstop=t_stop)
except SonataError as e:
raise BluepySnapError(e)

if not res:
return pd.Series(data=[], index=pd.Index([], name="times"), name=series_name)

Expand Down
11 changes: 9 additions & 2 deletions tests/test_frame_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,6 @@ def test_get(self):
pdt.assert_frame_equal(
self.test_obj.get([2, 1], t_start=0.2, t_stop=0.8), self.df.iloc[2:9].loc[:, [1, 2]])

pdt.assert_frame_equal(self.test_obj.get([0, 2], t_start=15), pd.DataFrame())

pdt.assert_frame_equal(
self.test_obj.get(group={Cell.MTYPE: "L6_Y"}, t_start=0.2, t_stop=0.8),
self.df.iloc[2:9].loc[:, [1, 2]])
Expand All @@ -226,6 +224,15 @@ def test_get(self):
pdt.assert_frame_equal(
self.test_obj.get(group="Layer23"), self.df.loc[:, [0]])

with pytest.raises(BluepySnapError):
self.test_obj.get(-1, t_start=0.2)

with pytest.raises(BluepySnapError):
self.test_obj.get(0, t_start=-1)

with pytest.raises(BluepySnapError):
self.test_obj.get([0, 2], t_start=15)

with pytest.raises(BluepySnapError):
self.test_obj.get(4)

Expand Down
11 changes: 9 additions & 2 deletions tests/test_spike_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,6 @@ def test_get(self):
pdt.assert_series_equal(self.test_obj.get([2, 1], t_start=0.2, t_stop=0.8),
_create_series([1, 2], [0.3, 0.7]))

pdt.assert_series_equal(self.test_obj.get([0, 2], t_start=12), _create_series([], []))

pdt.assert_series_equal(
self.test_obj.get(group={Cell.MTYPE: "L6_Y"}, t_start=0.2, t_stop=0.8),
_create_series([1, 2], [0.3, 0.7]))
Expand All @@ -169,6 +167,15 @@ def test_get(self):
pdt.assert_series_equal(
self.test_obj.get(group="Layer23"), _create_series([0, 0], [0.2, 1.3]))

with pytest.raises(BluepySnapError):
self.test_obj.get([-1], t_start=0.2)

with pytest.raises(BluepySnapError):
self.test_obj.get([0, 2], t_start=-1)

with pytest.raises(BluepySnapError):
self.test_obj.get([0, 2], t_start=12)

with pytest.raises(BluepySnapError):
self.test_obj.get(4)

Expand Down

0 comments on commit 0dd7b00

Please sign in to comment.