Skip to content

Commit

Permalink
Initialize Compare with (a list of) features (bugfix) (#124)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpweytjens authored and J535D165 committed Dec 1, 2019
1 parent 244d0e7 commit 47a0383
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
11 changes: 8 additions & 3 deletions recordlinkage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
index_split,
frame_indexing)
from recordlinkage.types import (is_numpy_like,
is_list_like,
is_pandas_2d_multiindex)
from recordlinkage.measures import max_pairs
from recordlinkage.utils import DeprecationHelper, LearningError
Expand Down Expand Up @@ -285,7 +286,7 @@ def _dedup_index(self, df_a):
pairs = pairs[pairs.codes[0] > pairs.codes[1]]
except AttributeError:
# backwards compat pandas <24
pairs = pairs[pairs.labels[0] > pairs.labels[1]]
pairs = pairs[pairs.labels[0] > pairs.labels[1]]

return pairs

Expand Down Expand Up @@ -543,7 +544,6 @@ def __init__(self, features=[], n_jobs=1, indexing_type='label',
else:
self.n_jobs = n_jobs
self.indexing_type = indexing_type # label of position
self.features = []

# logging
self._i = 1
Expand Down Expand Up @@ -582,7 +582,10 @@ def add(self, model):
A (list of) compare feature(s) from
:mod:`recordlinkage.compare`.
"""
self.features.append(model)
if isinstance(model, list):
self.features.extend(model)
else:
self.features.append(model)

def compare_vectorized(self, comp_func, labels_left, labels_right,
*args, **kwargs):
Expand Down Expand Up @@ -766,6 +769,8 @@ def _union(self, objs, index=None, column_i=0):
if isinstance(feat, tuple):
if label is None:
label = [None] * len(feat)
if not is_list_like(label):
label = (label,)

partial_result = self._union(
zip(feat, label), column_i=column_i)
Expand Down
20 changes: 20 additions & 0 deletions tests/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,26 @@ def test_indexing_types(self):

pdt.assert_frame_equal(result_label, result_position)

def test_pass_list_of_features(self):

from recordlinkage.compare import FrequencyA, VariableA, VariableB

# setup datasets and record pairs
A = DataFrame({'col': ['abc', 'abc', 'abc', 'abc', 'abc']})
B = DataFrame({'col': ['abc', 'abc', 'abc', 'abc', 'abc']})
ix = MultiIndex.from_arrays([np.arange(5), np.arange(5)])

# test with label indexing type

features = [
VariableA('col', label='y1'),
VariableB('col', label='y2'),
FrequencyA('col', label='y3')
]
comp_label = recordlinkage.Compare(features=features)
result_label = comp_label.compute(ix, A, B)

assert list(result_label) == ["y1", "y2", "y3"]

class TestCompareFeatures(TestData):
def test_feature(self):
Expand Down

0 comments on commit 47a0383

Please sign in to comment.