diff --git a/glue/core/data.py b/glue/core/data.py
index 745e8872a..f46a9d338 100644
--- a/glue/core/data.py
+++ b/glue/core/data.py
@@ -173,45 +173,170 @@ def remove_component(self, component_id):
               cid_other='cid_like')
     def join_on_key(self, other, cid, cid_other):
         """
-        Create an *element* mapping to another dataset, by
-        joining on values of ComponentIDs in both datasets.
-
-        This join allows any subsets defined on `other` to be
-        propagated to self.
-
-        :param other: :class:`~glue.core.data.Data` to join with
-        :param cid: str or :class:`glue.core.component_id.ComponentID` in this dataset to use as a key
-        :param cid_other: ComponentID in the other dataset to use as a key
-
-        :example:
-
-        >>> d1 = Data(x=[1, 2, 3, 4, 5], k1=[0, 0, 1, 1, 2], label='d1')
-        >>> d2 = Data(y=[2, 4, 5, 8, 4], k2=[1, 3, 1, 2, 3], label='d2')
-        >>> d2.join_on_key(d1, 'k2', 'k1')
-
-        >>> s = d1.new_subset()
-        >>> s.subset_state = d1.id['x'] > 2
-        >>> s.to_mask()
-        array([False, False, True, True, True], dtype=bool)
-
-        >>> s = d2.new_subset()
-        >>> s.subset_state = d1.id['x'] > 2
-        >>> s.to_mask()
-        array([ True, False, True, True, False], dtype=bool)
-
-        The subset state selects the last 3 items in d1. These have
-        key values k1 of 1 and 2. Thus, the selected items in d2
-        are the elements where k2 = 1 or 2.
-        """
-        _i1, _i2 = cid, cid_other
-        cid = self.find_component_id(cid)
-        cid_other = other.find_component_id(cid_other)
-        if cid is None:
-            raise ValueError("ComponentID not found in %s: %s" %
-                             (self.label, _i1))
-        if cid_other is None:
-            raise ValueError("ComponentID not found in %s: %s" %
-                             (other.label, _i2))
+        Create an *element* mapping to another dataset, by joining on values of
+        ComponentIDs in both datasets.
+
+        This join allows any subsets defined on `other` to be propagated to
+        self. The different ways to call this method are described in the
+        **Examples** section below.
+
+        Parameters
+        ----------
+        other : :class:`~glue.core.data.Data`
+            Data object to join with
+        cid : str or :class:`~glue.core.component_id.ComponentID` or iterable
+            Component(s) in this dataset to use as a key
+        cid_other : str or :class:`~glue.core.component_id.ComponentID` or iterable
+            Component(s) in the other dataset to use as a key
+
+        Examples
+        --------
+
+        There are several ways to use this function, depending on how many
+        components are passed to ``cid`` and ``cid_other``.
+
+        **Joining on single components**
+
+        First, one can specify a single component ID for both ``cid`` and
+        ``cid_other``: this is the standard mode, and joins one component from
+        one dataset to the other:
+
+        >>> d1 = Data(x=[1, 2, 3, 4, 5], k1=[0, 0, 1, 1, 2], label='d1')
+        >>> d2 = Data(y=[2, 4, 5, 8, 4], k2=[1, 3, 1, 2, 3], label='d2')
+        >>> d2.join_on_key(d1, 'k2', 'k1')
+
+        Selecting all values in ``d1`` where x is greater than 2 returns
+        the last three items as expected:
+
+        >>> s = d1.new_subset()
+        >>> s.subset_state = d1.id['x'] > 2
+        >>> s.to_mask()
+        array([False, False, True, True, True], dtype=bool)
+
+        The linking was done between k1 and k2, and the values of
+        k1 for the last three items are 1 and 2 - this means that the
+        first, third, and fourth item in ``d2`` will then get selected,
+        since k2 has a value of either 1 or 2 for these items.
+
+        >>> s = d2.new_subset()
+        >>> s.subset_state = d1.id['x'] > 2
+        >>> s.to_mask()
+        array([ True, False, True, True, False], dtype=bool)
+
+        **Joining on multiple components**
+
+        .. note:: This mode is currently slow, and will be optimized
+                  significantly in future.
+
+        Next, one can specify several components for each dataset: in this
+        case, the number of components given should match for both datasets.
+        This causes items in both datasets to be linked when (and only when)
+        the set of keys match between the two datasets:
+
+        >>> d1 = Data(x=[1, 2, 3, 5, 5],
+        ...           y=[0, 0, 1, 1, 2], label='d1')
+        >>> d2 = Data(a=[2, 5, 5, 8, 4],
+        ...           b=[1, 3, 2, 2, 3], label='d2')
+        >>> d2.join_on_key(d1, ('a', 'b'), ('x', 'y'))
+
+        Selecting all items where x is 5 in ``d1`` in which x is a
+        component works as expected and selects the two last items:
+
+        >>> s = d1.new_subset()
+        >>> s.subset_state = d1.id['x'] == 5
+        >>> s.to_mask()
+        array([False, False, False, True, True], dtype=bool)
+
+        If we apply this selection to ``d2``, only items where a is 5
+        and b is 2 will be selected:
+
+        >>> s = d2.new_subset()
+        >>> s.subset_state = d1.id['x'] == 5
+        >>> s.to_mask()
+        array([False, False, True, False, False], dtype=bool)
+
+        and in particular, the second item (where a is 5 and b is 3) is not
+        selected.
+
+        **One-to-many and many-to-one joining**
+
+        Finally, you can specify one component in one dataset and multiple ones
+        in the other. In the case where one component is specified for this
+        dataset and multiple ones for the other dataset, then when an item
+        is selected in the other dataset, it will cause any item in the present
+        dataset which matches any of the keys in the other data to be selected:
+
+        >>> d1 = Data(x=[1, 2, 3], label='d1')
+        >>> d2 = Data(a=[1, 1, 2],
+        ...           b=[2, 3, 3], label='d2')
+        >>> d1.join_on_key(d2, 'x', ('a', 'b'))
+
+        In this case, if we select all items in ``d2`` where a is 2, this
+        will select the third item:
+
+        >>> s = d2.new_subset()
+        >>> s.subset_state = d2.id['a'] == 2
+        >>> s.to_mask()
+        array([False, False, True], dtype=bool)
+
+        Since we have joined the datasets using both a and b, we select
+        all items in ``d1`` where x is either the value of a or b
+        (2 or 3) which means we select the second and third item:
+
+        >>> s = d1.new_subset()
+        >>> s.subset_state = d2.id['a'] == 2
+        >>> s.to_mask()
+        array([False, True, True], dtype=bool)
+
+        We can also join the datasets the other way around:
+
+        >>> d1 = Data(x=[1, 2, 3], label='d1')
+        >>> d2 = Data(a=[1, 1, 2],
+        ...           b=[2, 3, 3], label='d2')
+        >>> d2.join_on_key(d1, ('a', 'b'), 'x')
+
+        In this case, selecting items in ``d1`` where x is 1 selects the
+        first item, as expected:
+
+        >>> s = d1.new_subset()
+        >>> s.subset_state = d1.id['x'] == 1
+        >>> s.to_mask()
+        array([ True, False, False], dtype=bool)
+
+        This then causes any item in ``d2`` where either a or b are 1
+        to be selected, i.e. the first two items:
+
+        >>> s = d2.new_subset()
+        >>> s.subset_state = d1.id['x'] == 1
+        >>> s.to_mask()
+        array([ True, True, False], dtype=bool)
+        """
+
+        # To make things easier, we transform all component inputs to a tuple
+        if isinstance(cid, six.string_types) or isinstance(cid, ComponentID):
+            cid = (cid,)
+        else:
+            # Materialize lists, generators, etc. so that len() below is valid
+            cid = tuple(cid)
+        if isinstance(cid_other, six.string_types) or isinstance(cid_other, ComponentID):
+            cid_other = (cid_other,)
+        else:
+            cid_other = tuple(cid_other)
+
+        if len(cid) > 1 and len(cid_other) > 1 and len(cid) != len(cid_other):
+            raise Exception("Either the number of components in the key join "
+                            "sets should match, or one of the component sets "
+                            "should contain a single component.")
+
+        def get_component_id(data, name):
+            cid = data.find_component_id(name)
+            if cid is None:
+                raise ValueError("ComponentID not found in %s: %s" %
+                                 (data.label, name))
+            return cid
+
+        cid = tuple(get_component_id(self, name) for name in cid)
+        cid_other = tuple(get_component_id(other, name) for name in cid_other)
 
         self._key_joins[other] = (cid, cid_other)
         other._key_joins[self] = (cid_other, cid)
diff --git a/glue/core/state.py b/glue/core/state.py
index 2b2ae244b..f8f1f01bb 100644
--- a/glue/core/state.py
+++ b/glue/core/state.py
@@ -706,6 +706,25 @@ def _load_data_3(rec, context):
                              for k, v0, v1 in rec['_key_joins'])
 
 
+@saver(Data, version=4)
+def _save_data_4(data, context):
+    result = _save_data_2(data, context)
+    def save_cid_tuple(cids):
+        return tuple(context.id(cid) for cid in cids)
+    result['_key_joins'] = [[context.id(k), save_cid_tuple(v0), save_cid_tuple(v1)]
+                            for k, (v0, v1) in data._key_joins.items()]
+    return result
+
+
+@loader(Data, version=4)
+def _load_data_4(rec, context):
+    result = _load_data_2(rec, context)
+    yield result
+    def load_cid_tuple(cids):
+        return tuple(context.object(cid) for cid in cids)
+    result._key_joins = dict((context.object(k), (load_cid_tuple(v0), load_cid_tuple(v1)))
+                             for k, v0, v1 in rec['_key_joins'])
+
 @saver(ComponentID)
 def _save_component_id(cid, context):
     return dict(label=cid.label, hidden=cid.hidden)
diff --git a/glue/core/subset.py b/glue/core/subset.py
index cf1bf8791..32fddc180 100644
--- a/glue/core/subset.py
+++ b/glue/core/subset.py
@@ -168,9 +168,12 @@ def _to_index_list_join(self):
         return np.where(self._to_mask_join(None).flat)[0]
 
     def _to_mask_join(self, view):
-        """Conver the subset to a mask through an entity join
-        to another dataset. """
+        """
+        Convert the subset to a mask through an entity join to another
+        dataset.
+        """
         for other, (cid1, cid2) in self.data._key_joins.items():
+
             if getattr(other, '_recursing', False):
                 continue
 
@@ -178,17 +181,66 @@ def _to_mask_join(self, view):
                 self.data._recursing = True
                 s2 = Subset(other)
                 s2.subset_state = self.subset_state
-                key_right = s2.to_mask()
+                mask_right = s2.to_mask()
             except IncompatibleAttribute:
                 continue
             finally:
                 self.data._recursing = False
 
-            key_left = self.data[cid1, view]
-            result = np.in1d(key_left.ravel(),
-                             other[cid2, key_right])
+            if len(cid1) == 1 and len(cid2) == 1:
+
+                key_left = self.data[cid1[0], view]
+                key_right = other[cid2[0], mask_right]
+                mask = np.in1d(key_left.ravel(), key_right.ravel())
+
+                return mask.reshape(key_left.shape)
+
+            elif len(cid1) == len(cid2):
+
+                key_left_all = []
+                key_right_all = []
+
+                for cid1_i, cid2_i in zip(cid1, cid2):
+                    key_left_all.append(self.data[cid1_i, view].ravel())
+                    key_right_all.append(other[cid2_i, mask_right].ravel())
+
+                # TODO: The following is slow because we are looping in Python.
+                #       This could be made significantly faster by switching to
+                #       C/Cython.
+
+                key_left_all = zip(*key_left_all)
+                key_right_all = set(zip(*key_right_all))
+
+                result = [key in key_right_all for key in key_left_all]
+                result = np.array(result, dtype=bool)
 
-            return result.reshape(key_left.shape)
+                return result.reshape(self.data[cid1_i, view].shape)
+
+            elif len(cid1) == 1:
+
+                key_left = self.data[cid1[0], view].ravel()
+                mask = np.zeros_like(key_left, dtype=bool)
+                for cid2_i in cid2:
+                    key_right = other[cid2_i, mask_right].ravel()
+                    mask |= np.in1d(key_left, key_right)
+
+                return mask.reshape(self.data[cid1[0], view].shape)
+
+            elif len(cid2) == 1:
+
+                key_right = other[cid2[0], mask_right].ravel()
+                mask = np.zeros_like(self.data[cid1[0], view].ravel(), dtype=bool)
+                for cid1_i in cid1:
+                    key_left = self.data[cid1_i, view].ravel()
+                    mask |= np.in1d(key_left, key_right)
+
+                return mask.reshape(self.data[cid1[0], view].shape)
+
+            else:
+
+                raise Exception("Either the number of components in the key join sets "
+                                "should match, or one of the component sets should "
+                                "contain a single component.")
 
         raise IncompatibleAttribute
 
@@ -773,6 +825,7 @@ def copy(self):
         return MaskSubsetState(self.mask, self.cids)
 
     def to_mask(self, data, view=None):
+
         view = view or slice(None)
 
         # shortcut for data on the same pixel grid
diff --git a/glue/core/tests/test_joins.py b/glue/core/tests/test_joins.py
index 7f40af5a6..8f6a373c3 100644
--- a/glue/core/tests/test_joins.py
+++ b/glue/core/tests/test_joins.py
@@ -125,3 +125,73 @@ def test_clone(self):
         s.subset_state = y.id['x'] > 1
 
         assert_array_equal(s.to_mask(), [False, True, True])
+
+
+def test_many_to_many():
+    """
+    Test the use of multiple keys to denote that combinations of components
+    have to match.
+    """
+
+    d1 = Data(x=[1, 2, 3, 5, 5],
+              y=[0, 0, 1, 1, 2], label='d1')
+    d2 = Data(a=[2, 5, 5, 8, 4],
+              b=[1, 3, 2, 2, 3], label='d2')
+    d2.join_on_key(d1, ('a', 'b'), ('x', 'y'))
+
+    s = d1.new_subset()
+    s.subset_state = d1.id['x'] == 5
+    assert_array_equal(s.to_mask(), [0, 0, 0, 1, 1])
+
+    s = d2.new_subset()
+    s.subset_state = d1.id['x'] == 5
+    assert_array_equal(s.to_mask(), [0, 0, 1, 0, 0])
+
+
+def test_one_and_many():
+    """
+    Test the use of one-to-many keys or many-to-one key to indicate that any of
+    the components can match the other.
+    """
+
+    d1 = Data(x=[1, 2, 3], label='d1')
+    d2 = Data(a=[1, 1, 2],
+              b=[2, 3, 3], label='d2')
+    d1.join_on_key(d2, 'x', ('a', 'b'))
+
+    s = d2.new_subset()
+    s.subset_state = d2.id['a'] == 2
+    assert_array_equal(s.to_mask(), [0, 0, 1])
+
+    s = d1.new_subset()
+    s.subset_state = d2.id['a'] == 2
+    assert_array_equal(s.to_mask(), [0, 1, 1])
+
+    d1 = Data(x=[1, 2, 3], label='d1')
+    d2 = Data(a=[1, 1, 2],
+              b=[2, 3, 3], label='d2')
+    d2.join_on_key(d1, ('a', 'b'), 'x')
+
+    s = d1.new_subset()
+    s.subset_state = d1.id['x'] == 1
+    assert_array_equal(s.to_mask(), [1, 0, 0])
+
+    s = d2.new_subset()
+    s.subset_state = d1.id['x'] == 1
+    assert_array_equal(s.to_mask(), [1, 1, 0])
+
+
+def test_mismatch():
+
+    d1 = Data(x=[1, 1, 2],
+              y=[2, 3, 3],
+              z=[2, 3, 3], label='d1')
+    d2 = Data(a=[1, 1, 2],
+              b=[2, 3, 3], label='d2')
+
+    with pytest.raises(Exception) as exc:
+        d1.join_on_key(d2, ('x', 'y', 'z'), ('a', 'b'))
+    assert exc.value.args[0] == ("Either the number of components in the key "
+                                 "join sets should match, or one of the "
+                                 "component sets should contain a single "
+                                 "component.")