From 47a0383775cc0799ba0b1f7f0e9a658c05c6ca44 Mon Sep 17 00:00:00 2001 From: Johannes Weytjens <34494702+jpweytjens@users.noreply.github.com> Date: Sun, 1 Dec 2019 15:56:25 +0100 Subject: [PATCH] Initialize Compare with (a list of) features (bugfix) (#124) --- recordlinkage/base.py | 11 ++++++++--- tests/test_compare.py | 20 ++++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/recordlinkage/base.py b/recordlinkage/base.py index aae7576c..7deca0d0 100755 --- a/recordlinkage/base.py +++ b/recordlinkage/base.py @@ -23,6 +23,7 @@ index_split, frame_indexing) from recordlinkage.types import (is_numpy_like, + is_list_like, is_pandas_2d_multiindex) from recordlinkage.measures import max_pairs from recordlinkage.utils import DeprecationHelper, LearningError @@ -285,7 +286,7 @@ def _dedup_index(self, df_a): pairs = pairs[pairs.codes[0] > pairs.codes[1]] except AttributeError: # backwards compat pandas <24 - pairs = pairs[pairs.labels[0] > pairs.labels[1]] + pairs = pairs[pairs.labels[0] > pairs.labels[1]] return pairs @@ -543,7 +544,6 @@ def __init__(self, features=[], n_jobs=1, indexing_type='label', else: self.n_jobs = n_jobs self.indexing_type = indexing_type # label of position - self.features = [] # logging self._i = 1 @@ -582,7 +582,10 @@ def add(self, model): A (list of) compare feature(s) from :mod:`recordlinkage.compare`. """ - self.features.append(model) + if isinstance(model, list): + self.features.extend(model) + else: + self.features.append(model) def compare_vectorized(self, comp_func, labels_left, labels_right, *args, **kwargs): @@ -766,6 +769,8 @@ def _union(self, objs, index=None, column_i=0): if isinstance(feat, tuple): if label is None: label = [None] * len(feat) + if not is_list_like(label): + label = (label,) partial_result = self._union( zip(feat, label), column_i=column_i) diff --git a/tests/test_compare.py b/tests/test_compare.py index 8938f3cf..91274d11 100644 --- a/tests/test_compare.py +++ b/tests/test_compare.py @@ -566,6 +566,26 @@ def test_indexing_types(self): pdt.assert_frame_equal(result_label, result_position) + def test_pass_list_of_features(self): + + from recordlinkage.compare import FrequencyA, VariableA, VariableB + + # setup datasets and record pairs + A = DataFrame({'col': ['abc', 'abc', 'abc', 'abc', 'abc']}) + B = DataFrame({'col': ['abc', 'abc', 'abc', 'abc', 'abc']}) + ix = MultiIndex.from_arrays([np.arange(5), np.arange(5)]) + + # test with label indexing type + + features = [ + VariableA('col', label='y1'), + VariableB('col', label='y2'), + FrequencyA('col', label='y3') + ] + comp_label = recordlinkage.Compare(features=features) + result_label = comp_label.compute(ix, A, B) + + assert list(result_label) == ["y1", "y2", "y3"] class TestCompareFeatures(TestData): def test_feature(self):