Skip to content

Commit

Permalink
Generalise multiple groups validation of edges and nodes (#79)
Browse files Browse the repository at this point in the history
  • Loading branch information
asanin-epfl authored Jul 17, 2020
1 parent c0c6663 commit e029b6f
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 46 deletions.
93 changes: 52 additions & 41 deletions bluepysnap/circuit_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import h5py
import six

from bluepysnap import BluepySnapError
from bluepysnap.config import Config

MAX_MISSING_FILES_DISPLAY = 10
Expand Down Expand Up @@ -199,6 +200,42 @@ def _get_model_template_file(model_template):
return parts[1] + '.' + parts[0]


def _get_population_groups(population_h5):
    """Collect the numerically named groups of a node or edge population.

    Args:
        population_h5: population object; iterating it yields the names of its
            children and indexing it by name yields the child itself.

    Returns:
        list: the children that are ``h5py.Group`` instances and whose name is
        made of digits only, in iteration order.
    """
    groups = []
    for child_name in population_h5:
        # The type check runs before the name check, so every child is resolved once here.
        if not isinstance(population_h5[child_name], h5py.Group):
            continue
        if child_name.isdigit():
            groups.append(population_h5[child_name])
    return groups


def _get_group_size(group_h5):
    """Return the size of a nodes or edges group.

    The size is taken from the first child of the group that is a dataset.

    Args:
        group_h5: group object; iterating it yields the names of its children
            and indexing it by name yields the child itself.

    Returns:
        The first dimension of the shape of the first ``h5py.Dataset`` child.

    Raises:
        BluepySnapError: if the group has no dataset child.
    """
    dataset_names = (name for name in group_h5
                     if isinstance(group_h5[name], h5py.Dataset))
    first_dataset = next(dataset_names, None)
    if first_dataset is None:
        raise BluepySnapError('Empty group {}'.format(group_h5))
    return group_h5[first_dataset].shape[0]


def _check_multi_groups(group_id_h5, group_index_h5, population):
    """Checks multiple groups of nodes or edges population.

    Three consistency checks are run in order and the first failure is returned:
    the two datasets have the same length, every referenced group id exists as a
    group of the population, and every group is large enough for the highest
    index that refers to it.

    Args:
        group_id_h5: "*_group_id" dataset of the population (read fully with ``[:]``).
        group_index_h5: "*_group_index" dataset of the population (read fully with ``[:]``).
        population: the population that holds the groups; its ``name`` and
            ``file.filename`` are used in the error messages.

    Returns:
        list: a single fatal error describing the first problem found, or an empty
        list when the groups are consistent.

    Raises:
        BluepySnapError: propagated from ``_get_group_size`` when a referenced group
            contains no dataset.
    """
    group_id_h5 = group_id_h5[:]
    group_index_h5 = group_index_h5[:]
    if len(group_id_h5) != len(group_index_h5):
        return [fatal('Population {} of {} has different sizes of "group_id" and "group_index"'.
                      format(population.name, population.file.filename))]
    group_ids = np.unique(group_id_h5)
    group_names = [_get_group_name(group).name for group in _get_population_groups(population)]
    # Plain Python ints keep the reported set rendered as e.g. "{1}" whatever the numpy version.
    referenced_ids = {int(group_id) for group_id in group_ids}
    existing_ids = {int(name) for name in group_names}
    missing_groups = referenced_ids - existing_ids
    if missing_groups:
        return [fatal('Population {} of {} misses group(s): {}'.
                      format(population.name, population.file.filename, missing_groups))]
    for group_id in group_ids:
        group = population[str(group_id)]
        max_id = group_index_h5[group_id_h5 == int(group_id)].max()
        # Dataset rows are addressed from 0, so reaching index `max_id` takes at least
        # `max_id + 1` rows; a group of exactly `max_id` rows is already too small.
        if _get_group_size(group) <= max_id:
            return [fatal('Group {} in file {} should have ids up to {}'.format(
                _get_group_name(group, parents=1), population.file.filename, max_id))]
    return []


def _check_bio_nodes_group(group, config):
"""Checks biophysical nodes group for errors.
Expand Down Expand Up @@ -303,17 +340,20 @@ def _check_nodes_population(nodes_dict, config):
if not nodes or len(nodes) == 0:
errors.append(fatal('No "nodes" in {}.'.format(nodes_file)))
return errors
if len(nodes.keys()) > 1:
required_datasets += ['node_group_id', 'node_group_index']
for population_name in nodes:
population = nodes[population_name]
groups = _get_population_groups(population)
if len(groups) > 1:
required_datasets += ['node_group_id', 'node_group_index']
missing_datasets = sorted(set(required_datasets) - set(population))
if missing_datasets:
errors.append(fatal('Population {} of {} misses datasets {}'.
format(population_name, nodes_file, missing_datasets)))
for name in population:
if isinstance(population[name], h5py.Group):
errors += _check_nodes_group(population[name], config)
elif 'node_group_id' in population:
errors += _check_multi_groups(
population['node_group_id'], population['node_group_index'], population)
for group in groups:
errors += _check_nodes_group(group, config)
return errors


Expand Down Expand Up @@ -435,16 +475,10 @@ def _check(indices, nodes_ds):
return errors


def _get_edge_population_groups(population):
    """List the names of the groups found in an edge population.

    Args:
        population: edge population object; iterating it yields the names of its
            children and indexing it by name yields the child itself.

    Returns:
        list: names of the children that are ``h5py.Group`` instances, leaving out
        the child named "indices".
    """
    group_names = []
    for child_name in population:
        # "indices" is skipped by name before the child is even resolved.
        if child_name == "indices":
            continue
        if isinstance(population[child_name], h5py.Group):
            group_names.append(child_name)
    return group_names


def _check_edge_population_data(population, groups, nodes):
# pylint: disable=too-many-locals,too-many-return-statements,too-many-branches
def _check_edge_population_data(population, nodes):
errors = []
population_name = _get_group_name(population)
groups = _get_population_groups(population)
if len(groups) > 1:
errors.append(BbpError(Error.WARNING, 'Population {} of {} have multiple groups. '
'Cannot be read via bluepysnap or libsonata'.
Expand Down Expand Up @@ -473,38 +507,17 @@ def _check_edge_population_data(population, groups, nodes):
# no "edge_group_id", "edge_group_index" and only one group --> can use implicit ids
return errors

edge_group_ids = population["edge_group_id"][:]
edge_group_index = population["edge_group_index"][:]

if len(edge_group_ids) != len(edge_group_index):
return errors + [fatal('Population {} of {} "edge_group_id" and "edge_'
'group_index" of different sizes'.
format(population_name, population.file.filename))]

group_ids = np.unique(edge_group_ids)
missing_groups = set(group_ids) - set(np.array(groups, dtype=int))

if missing_groups:
return errors + [fatal('Population {} of {} misses group(s): {}'.
format(population_name, population.file.filename, missing_groups))]
for group_id in group_ids:
group = population[str(group_id)]
max_edge_id = edge_group_index[edge_group_ids == int(group_id)].max()
if group[list(group)[0]].shape[0] < max_edge_id:
errors.append(fatal('Group {} in file {} should have ids up to {}'.
format(_get_group_name(group, parents=1), population.file.filename,
max_edge_id)))

errors += _check_multi_groups(
population["edge_group_id"], population["edge_group_index"], population)
if 'source_node_id' in children_object_names:
errors += _check_edges_node_ids(population['source_node_id'], nodes)
if 'target_node_id' in children_object_names:
errors += _check_edges_node_ids(population['target_node_id'], nodes)
if 'indices' in children_object_names:
errors += _check_edges_indices(population)

for name in groups:
if isinstance(population[name], h5py.Group):
errors += _check_edges_group_bbp(population[name])
for group in groups:
errors += _check_edges_group_bbp(group)

return errors

Expand All @@ -528,11 +541,9 @@ def _check_edges_population(edges_dict, nodes):
return errors

for population_name in edges:

population_path = '/edges/' + population_name
population = h5f[population_path]
groups = _get_edge_population_groups(population)
errors += _check_edge_population_data(population, groups, nodes)
errors += _check_edge_population_data(population, nodes)

return errors

Expand Down
24 changes: 19 additions & 5 deletions tests/test_circuit_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
import h5py

import six
from bluepysnap.exceptions import BluepySnapError
import bluepysnap.circuit_validation as test_module
from bluepysnap.circuit_validation import Error, BbpError
import numpy as np
import pytest

from utils import TEST_DATA_DIR, copy_circuit, edit_config

Expand All @@ -22,6 +24,15 @@ def test_error_comparison():
assert (err == 'hello') is False


def test_empty_group_size():
    """A freshly created group without datasets makes ``_get_group_size`` raise."""
    with copy_circuit() as (circuit_dir, _config_path):
        nodes_path = circuit_dir / 'nodes.h5'
        with h5py.File(nodes_path, 'r+') as nodes_h5:
            default_population = nodes_h5['nodes/default/']
            empty_group = default_population.create_group('3')
            with pytest.raises(BluepySnapError):
                test_module._get_group_size(empty_group)


def test_ok_circuit():
errors = test_module.validate(str(TEST_DATA_DIR / 'circuit_config.json'))
assert errors == []
Expand Down Expand Up @@ -139,13 +150,14 @@ def test_no_required_node_single_population_datasets():
format(nodes_file, ['node_type_id']))]


def test_no_required_node_multi_population_datasets():
required_datasets = ['node_type_id', 'node_group_id', 'node_group_index']
def test_no_required_node_multi_group_datasets():
required_datasets = ['node_group_id', 'node_group_index']
for ds in required_datasets:
with copy_circuit() as (circuit_copy_path, config_copy_path):
nodes_file = circuit_copy_path / 'nodes.h5'
with h5py.File(nodes_file, 'r+') as h5f:
del h5f['nodes/default/' + ds]
h5f.copy('nodes/default/0', 'nodes/default/1')
errors = test_module.validate(str(config_copy_path))
assert errors == [Error(Error.FATAL, 'Population default of {} misses datasets {}'.
format(nodes_file, [ds]))]
Expand Down Expand Up @@ -443,8 +455,8 @@ def test_edge_population_edge_group_different_length():
h5f.create_dataset('edges/default/edge_group_index', data=[0, 1, 2, 3, 4])
errors = test_module.validate(str(config_copy_path))
assert errors == [Error(Error.FATAL,
'Population default of {} "edge_group_id" and "edge_group_index" of different sizes'.
format(edges_file))]
'Population {} of {} has different sizes of "group_id" and "group_index"'.
format('/edges/default', edges_file))]


def test_edge_population_wrong_group_id():
Expand All @@ -454,7 +466,7 @@ def test_edge_population_wrong_group_id():
del h5f['edges/default/edge_group_id']
h5f.create_dataset('edges/default/edge_group_id', data=[0, 1, 0, 0])
errors = test_module.validate(str(config_copy_path))
assert errors == [Error(Error.FATAL, 'Population default of {} misses group(s): {}'.
assert errors == [Error(Error.FATAL, 'Population /edges/default of {} misses group(s): {}'.
format(edges_file, {1}))]


Expand Down Expand Up @@ -526,6 +538,8 @@ def test_no_edge_all_node_ids():
del h5f['nodes/default/0']
errors = test_module.validate(str(config_copy_path))
assert errors == [
Error(Error.FATAL, 'Population /nodes/default of {} misses group(s): {}'.
format(nodes_file, {0})),
Error(Error.FATAL,
'/edges/default/source_node_id does not have node ids in its node population'),
Error(Error.FATAL,
Expand Down

0 comments on commit e029b6f

Please sign in to comment.