diff --git a/src/hdmf/common/table.py b/src/hdmf/common/table.py index 5c2501ff5..8278619cb 100644 --- a/src/hdmf/common/table.py +++ b/src/hdmf/common/table.py @@ -85,6 +85,27 @@ class ElementIdentifiers(Data): def __init__(self, **kwargs): call_docval_func(super(ElementIdentifiers, self).__init__, kwargs) + @docval({'name': 'other', 'type': (Data, np.ndarray, list, tuple, int), + 'doc': 'List of ids to search for in this ElementIdentifer object'}, + rtype=np.ndarray, + returns='Array with the list of indices where the elements in the list where found.' + 'Note, the elements in the returned list are ordered in increasing index' + 'of the found elements, rather than in the order in which the elements' + 'where given for the search. Also the length of the result may be different from the length' + 'of the input array. E.g., if our ids are [1,2,3] and we are search for [3,1,5] the ' + 'result would be [0,2] and NOT [2,0,None]') + def __eq__(self, other): + """ + Given a list of ids return the indices in the ElementIdentifiers array where the + indices are found. + """ + # Determine the ids we want to find + search_ids = other if not isinstance(other, Data) else other.data + if isinstance(search_ids, int): + search_ids = [search_ids] + # Find all matching locations + return np.in1d(self.data, search_ids).nonzero()[0] + @register_class('DynamicTable') class DynamicTable(Container): diff --git a/tests/unit/common/test_table.py b/tests/unit/common/test_table.py index 0f47e542a..25f7cc295 100644 --- a/tests/unit/common/test_table.py +++ b/tests/unit/common/test_table.py @@ -198,6 +198,13 @@ def test_not_enforce_unique_id_error(self): except ValueError as e: self.fail("add row with non unique id raised error %s" % str(e)) + def test_bad_id_type_error(self): + table = self.with_spec() + with self.assertRaises(TypeError): + table.add_row(id=10.1, data={'foo': 1, 'bar': 10.0, 'baz': 'cat'}, enforce_unique_id=True) + with self.assertRaises(TypeError): + table.add_row(id='str', data={'foo': 1, 'bar': 10.0, 'baz': 'cat'}, enforce_unique_id=True) + def test_extra_columns(self): table = self.with_spec() @@ -226,6 +233,20 @@ def test_nd_array_to_df(self): index=pd.Index(name='id', data=[0, 1, 2])) pd.testing.assert_frame_equal(df, df2) + def test_id_search(self): + table = self.with_spec() + data = [{'foo': 1, 'bar': 10.0, 'baz': 'cat'}, + {'foo': 2, 'bar': 20.0, 'baz': 'dog'}, + {'foo': 3, 'bar': 30.0, 'baz': 'bird'}, + {'foo': 4, 'bar': 40.0, 'baz': 'fish'}, + {'foo': 5, 'bar': 50.0, 'baz': 'lizard'}] + for i in data: + table.add_row(i) + res = table[table.id == [2, 4]] + self.assertEqual(len(res), 2) + self.assertTupleEqual(res[0], (2, 3, 30.0, 'bird')) + self.assertTupleEqual(res[1], (4, 5, 50.0, 'lizard')) + class TestDynamicTableRoundTrip(base.TestMapRoundTrip): @@ -258,3 +279,50 @@ def test_from_dataframe(self): 'b': ['4', '5', '6'] }), 'test_table', table_description='the expected table', column_descriptions=coldesc) self.assertContainerEqual(expected, received) + + +class TestElementIdentifiers(unittest.TestCase): + + def test_identifier_search_single_list(self): + e = ElementIdentifiers('ids', [0, 1, 2, 3, 4]) + a = (e == [1]) + np.testing.assert_array_equal(a, [1]) + + def test_identifier_search_single_int(self): + e = ElementIdentifiers('ids', [0, 1, 2, 3, 4]) + a = (e == 2) + np.testing.assert_array_equal(a, [2]) + + def test_identifier_search_single_list_not_found(self): + e = ElementIdentifiers('ids', [0, 1, 2, 3, 4]) + a = (e == [10]) + np.testing.assert_array_equal(a, []) + + def test_identifier_search_single_int_not_found(self): + e = ElementIdentifiers('ids', [0, 1, 2, 3, 4]) + a = (e == 10) + np.testing.assert_array_equal(a, []) + + def test_identifier_search_single_list_all_match(self): + e = ElementIdentifiers('ids', [0, 1, 2, 3, 4]) + a = (e == [1, 2, 3]) + np.testing.assert_array_equal(a, [1, 2, 3]) + + def test_identifier_search_single_list_partial_match(self): + e = ElementIdentifiers('ids', [0, 1, 2, 3, 4]) + a = (e == [1, 2, 10]) + np.testing.assert_array_equal(a, [1, 2]) + a = (e == [-1, 2, 10]) + np.testing.assert_array_equal(a, [2, ]) + + def test_identifier_search_with_element_identifier(self): + e = ElementIdentifiers('ids', [0, 1, 2, 3, 4]) + a = (e == ElementIdentifiers('ids', [1, 2, 10])) + np.testing.assert_array_equal(a, [1, 2]) + + def test_identifier_search_with_bad_ids(self): + e = ElementIdentifiers('ids', [0, 1, 2, 3, 4]) + with self.assertRaises(TypeError): + _ = (e == 0.1) + with self.assertRaises(TypeError): + _ = (e == 'test')